napi_value TFJSBackend::RunSavedModel()

in tfjs-node/binding/tfjs_backend.cc [1011:1181]


// Runs the session of a previously loaded SavedModel.
//
// - `savedmodel_id_value`: id used to look up (session, graph) in
//   tf_savedmodel_map_.
// - `input_tensor_ids`: JS array of tensor ids looked up in tfe_handle_map_.
// - `input_op_names_value` / `output_op_names_value`: comma-separated lists of
//   "op_name:index" specs (e.g. "serving_default_x:0"). A spec with no ':' or
//   an empty index suffix refers to index 0. Input specs are paired
//   positionally with `input_tensor_ids`.
//
// Returns a JS array holding one GenerateOutputTensorInfo() value per output
// spec. On failure it throws a JS error and returns nullptr. Every
// intermediate TF_Tensor and every C string produced by splitStringByComma is
// released on all return paths.
napi_value TFJSBackend::RunSavedModel(napi_env env,
                                      napi_value savedmodel_id_value,
                                      napi_value input_tensor_ids,
                                      napi_value input_op_names_value,
                                      napi_value output_op_names_value) {
  // Owns every resource acquired below, so each early `return nullptr` frees
  // them. This includes the returns hidden inside ENSURE_NAPI_OK_RETVAL.
  struct SessionRunResources {
    std::vector<const char *> input_op_names;
    std::vector<const char *> output_op_names;
    std::vector<TF_Tensor *> input_values;
    std::vector<TF_Tensor *> output_values;

    SessionRunResources() = default;
    SessionRunResources(const SessionRunResources &) = delete;
    SessionRunResources &operator=(const SessionRunResources &) = delete;

    ~SessionRunResources() {
      for (TF_Tensor *tensor : input_values) {
        if (tensor != nullptr) TF_DeleteTensor(tensor);
      }
      for (TF_Tensor *tensor : output_values) {
        if (tensor != nullptr) TF_DeleteTensor(tensor);
      }
      // NOTE(review): kept as plain `delete` to match the previous code. If
      // splitStringByComma allocates with new[] (or strdup), this must be
      // delete[] (or free) -- confirm against its implementation.
      for (const char *name : input_op_names) delete name;
      for (const char *name : output_op_names) delete name;
    }
  };

  // Splits an "op_name:index" spec into the op name and its output index.
  // If there is no ':' or nothing follows it, the index is 0.
  auto parse_op_spec = [](const char *spec, std::string &op_name,
                          int &op_index) {
    const std::string full(spec);
    const std::string::size_type colon = full.find(':');
    op_name = full.substr(0, colon);  // whole string when colon == npos
    if (colon == std::string::npos || colon + 1 == full.size()) {
      op_index = 0;
    } else {
      op_index = atoi(full.c_str() + colon + 1);
    }
  };

  napi_status nstatus;
  TF_AutoStatus tf_status;

  int32_t savedmodel_id;
  nstatus = napi_get_value_int32(env, savedmodel_id_value, &savedmodel_id);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

  // Get corresponding SavedModel session and graph.
  auto savedmodel_entry = tf_savedmodel_map_.find(savedmodel_id);
  if (savedmodel_entry == tf_savedmodel_map_.end()) {
    NAPI_THROW_ERROR(env, "SavedModel ID not found (savedmodel_id: %d)",
                     savedmodel_id);
    return nullptr;
  }
  auto *session = savedmodel_entry->second.first;
  auto *graph = savedmodel_entry->second.second;

  std::string input_op_names;
  nstatus = GetStringParam(env, input_op_names_value, input_op_names);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
  std::string output_op_names;
  nstatus = GetStringParam(env, output_op_names_value, output_op_names);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

  SessionRunResources res;
  res.input_op_names = splitStringByComma(input_op_names);
  res.output_op_names = splitStringByComma(output_op_names);

  uint32_t num_input_ids;
  nstatus = napi_get_array_length(env, input_tensor_ids, &num_input_ids);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

  if (res.input_op_names.size() != num_input_ids) {
    // Cast explicitly: %d must not receive a size_t / uint32_t.
    NAPI_THROW_ERROR(env,
                     "Length of input op names (%d) does not match the length "
                     "of input tensors (%d).",
                     static_cast<int>(res.input_op_names.size()),
                     static_cast<int>(num_input_ids));
    return nullptr;
  }

  std::vector<TF_Output> inputs;
  inputs.reserve(num_input_ids);
  // Reserved so push_back below cannot reallocate after a tensor is resolved.
  res.input_values.reserve(num_input_ids);

  for (uint32_t i = 0; i < num_input_ids; i++) {
    napi_value cur_input_id;
    nstatus = napi_get_element(env, input_tensor_ids, i, &cur_input_id);
    ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

    int32_t cur_input_tensor_id;
    nstatus = napi_get_value_int32(env, cur_input_id, &cur_input_tensor_id);
    ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

    // Find input tensor based on tensor id.
    auto tensor_entry = tfe_handle_map_.find(cur_input_tensor_id);
    if (tensor_entry == tfe_handle_map_.end()) {
      NAPI_THROW_ERROR(env, "Input Tensor ID not found (tensor_id: %d)",
                       cur_input_tensor_id);
      return nullptr;
    }
    TF_Tensor *input_tensor =
        TFE_TensorHandleResolve(tensor_entry->second, tf_status.status);
    if (TF_GetCode(tf_status.status) != TF_OK) {
      NAPI_THROW_ERROR(
          env, "Failed to get input tensor (tensor_id: %d) for session.",
          cur_input_tensor_id);
      return nullptr;
    }
    // Hand the resolved tensor to `res` at once, so any later failure frees it.
    res.input_values.push_back(input_tensor);

    std::string input_op_name;
    int input_tensor_index;
    parse_op_spec(res.input_op_names[i], input_op_name, input_tensor_index);

    // TODO(kangyizhang): Store these TF_Operations somewhere so they don't need
    // to be generated every time.
    TF_Operation *input_op =
        TF_GraphOperationByName(graph, input_op_name.c_str());
    if (input_op == nullptr) {
      NAPI_THROW_ERROR(env, "Input op name can not be found in the graph.");
      return nullptr;
    }
    TF_Output in = {input_op, input_tensor_index};
    inputs.push_back(in);
  }

  // Resolve each output spec (e.g. "StatefulPartitionedCall:0") to a TF_Output.
  std::vector<TF_Output> outputs;
  outputs.reserve(res.output_op_names.size());
  for (const char *spec : res.output_op_names) {
    std::string output_op_name;
    int output_tensor_index;
    parse_op_spec(spec, output_op_name, output_tensor_index);

    TF_Operation *output_op =
        TF_GraphOperationByName(graph, output_op_name.c_str());
    if (output_op == nullptr) {
      NAPI_THROW_ERROR(env, "Output op name can not be found in the graph.");
      return nullptr;
    }
    TF_Output out = {output_op, output_tensor_index};
    outputs.push_back(out);
  }

  res.output_values.resize(outputs.size(), nullptr);

  TF_SessionRun(session, nullptr, inputs.data(), res.input_values.data(),
                static_cast<int>(inputs.size()), outputs.data(),
                res.output_values.data(), static_cast<int>(outputs.size()),
                nullptr, 0, nullptr, tf_status.status);

  if (TF_GetCode(tf_status.status) != TF_OK) {
    NAPI_THROW_ERROR(env, "Session fail to run with error: %s",
                     TF_Message(tf_status.status));
    return nullptr;
  }

  napi_value output_tensor_infos;
  nstatus = napi_create_array_with_length(env, outputs.size(),
                                          &output_tensor_infos);
  ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);

  // Generate output tensors for JS.
  for (uint32_t i = 0; i < res.output_values.size(); i++) {
    TFE_TensorHandle *tfe_handle =
        TFE_NewTensorHandle(res.output_values[i], tf_status.status);
    // Deallocate output TF_Tensor in C++. Clear the slot so that `res` does
    // not free it a second time.
    TF_DeleteTensor(res.output_values[i]);
    res.output_values[i] = nullptr;
    if (TF_GetCode(tf_status.status) != TF_OK) {
      NAPI_THROW_ERROR(env, "Failed to create output tensor handle: %s",
                       TF_Message(tf_status.status));
      return nullptr;
    }

    napi_value tensor_info_value = GenerateOutputTensorInfo(env, tfe_handle);
    // Push into output array
    nstatus = napi_set_element(env, output_tensor_infos, i, tensor_info_value);
    ENSURE_NAPI_OK_RETVAL(env, nstatus, nullptr);
  }

  return output_tensor_infos;
}