in serving/reagent/serving/core/PytorchActionValueScorer.cpp [11:91]
/// Scores the candidate actions of a decision request with a TorchScript model.
///
/// The model is looked up in `models_` under the key
/// "/tmp/<modelId>/<snapshotId>". On a cache miss it is loaded from that same
/// path and stored in `models_` for subsequent calls.
///
/// @param request    Decision request. Its `context_features` keys are parsed
///                   with std::stoi and used as column indices of the model
///                   input; `actions.names` must be non-empty.
/// @param modelId    First component of the model path.
/// @param snapshotId Second component of the model path.
/// @return Map from action name to score for every action the model scored
///         that is also returned by Operator::getActionNamesFromRequest().
///         An empty map is returned when loading the model raises c10::Error.
///
/// Every other failure (torch error during inference, malformed feature key,
/// empty `actions.names`) terminates the process through LOG(FATAL).
StringDoubleMap PytorchActionValueScorer::predict(
    const DecisionRequest& request,
    int modelId,
    int snapshotId) {
  try {
    std::string path =
        "/tmp/" + std::to_string(modelId) + "/" + std::to_string(snapshotId);
    if (models_.find(path) == models_.end()) {
      try {
#ifdef FB_INTERNAL
        // First load predictor container, then extract module
        auto pytorchPredictor =
            std::make_shared<caffe2::PyTorchPredictorContainer>(path);
        auto module = pytorchPredictor->getPredictor()->get_module();
#else
        // Deserialize the ScriptModule from a file using torch::jit::load().
        torch::jit::script::Module module = torch::jit::load(path);
#endif
        models_[path] = std::move(module);
      } catch (const c10::Error& e) {
        // A model that cannot be loaded is reported, not fatal: the caller
        // receives an empty score map.
        LOG(ERROR) << "Error loading the model: " << e.what();
        return StringDoubleMap();
      }
    }
    auto model = models_.find(path)->second;
    const bool discreteActions = !request.actions.names.empty();
    const StringList actionNames = Operator::getActionNamesFromRequest(request);
    const std::set<std::string> actionNameSet(
        actionNames.begin(), actionNames.end());
    StringDoubleMap retval;
    if (discreteActions) {
      // The input width is one past the largest feature index (minimum 1).
      // NOTE(review): a negative feature key is not rejected here and would
      // presumably index from the end of the tensor below — confirm whether
      // such keys can occur.
      int input_size = 1;
      for (const auto& feature : request.context_features) {
        input_size = std::max(input_size, 1 + std::stoi(feature.first));
      }
      auto input = torch::zeros({1, input_size});
      auto inputMask = torch::zeros({1, input_size});
      for (const auto& feature : request.context_features) {
        VLOG(1) << "FEATURE SCORE: " << feature.second;
        const int featureIndex = std::stoi(feature.first);
        input[0][featureIndex] = feature.second;
        // The mask marks which columns carry a value supplied by the request.
        inputMask[0][featureIndex] = 1.0;
      }
      // Create a vector of inputs.
      std::vector<torch::jit::IValue> inputs;
      auto stateWithPresence = c10::ivalue::Tuple::create({input, inputMask});
      inputs.push_back(stateWithPresence);
      auto result = model.forward(inputs);
      // The output is read as a tuple: element 0 is the list of action names,
      // element 1 the tensor of scores indexed as [0][action position].
      auto tupleResult = result.toTuple();
      auto outputActionNameList = tupleResult->elements()[0].toList();
      auto actionScoresTensor = tupleResult->elements()[1].toTensor();
      const auto numScoredActions =
          static_cast<int64_t>(outputActionNameList.size());
      for (int64_t a = 0; a < numScoredActions; ++a) {
        std::string scoredActionName =
            outputActionNameList.get(a).toStringRef();
        if (actionNameSet.find(scoredActionName) == actionNameSet.end()) {
          VLOG(1) << "Skipping action that wasn't possible";
          continue;
        }
        // Extract the scalar once; it is both logged and stored.
        const double score = actionScoresTensor[0][a].item().to<double>();
        VLOG(1) << "SCORING: " << scoredActionName << " -> " << score;
        retval[scoredActionName] = score;
      }
    } else {
      LOG(FATAL) << "Not supported yet";
    }
    VLOG(1) << "SCORED " << retval.size() << " ITEMS";
    return retval;
  } catch (const c10::Error& e) {
    LOG(FATAL) << "TORCH ERROR: " << e.what();
  } catch (const std::exception& e) {
    // E.g. std::invalid_argument / std::out_of_range thrown by std::stoi on a
    // malformed feature key; include the message so the abort is diagnosable.
    LOG(FATAL) << "UNKNOWN ERROR: " << e.what();
  } catch (...) {
    LOG(FATAL) << "UNKNOWN ERROR";
  }
  LOG(FATAL) << "Should never get here";
}