int MXPredCreatePartialOut()

in src/c_api/c_predict_api.cc [71:224]


int MXPredCreatePartialOut(const char* symbol_json_str,
                           const void* param_bytes,
                           int param_size,
                           int dev_type, int dev_id,
                           mx_uint num_input_nodes,
                           const char** input_keys,
                           const mx_uint* input_shape_indptr,
                           const mx_uint* input_shape_data,
                           mx_uint num_output_nodes,
                           const char** output_keys,
                           PredictorHandle* out) {
  using nnvm::Symbol;

  MXAPIPredictor* ret = new MXAPIPredictor();
  API_BEGIN();
  Symbol sym;
  // make sure symbols are registered
  {
  mx_uint outSize;
  const char **outArray;
  MXListAllOpNames(&outSize, &outArray);
  }
  // load in the symbol.
  {
    nnvm::Graph g;
    g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
    sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
  }
  // looks likely to output the internal results
  if (num_output_nodes != 0) {
    Symbol internal = sym.GetInternals();
    std::vector<std::string> all_out = internal.ListOutputNames();
    std::vector<Symbol> out_syms(num_output_nodes);
    for (mx_uint i = 0; i < num_output_nodes; ++i) {
      std::string out_key(output_keys[i]);
      out_key += "_output";
      for (size_t j = 0; j < all_out.size(); ++j) {
        if (all_out[j] == out_key) {
          out_syms[i] = internal[j];
          break;
        }
        CHECK_NE(j, all_out.size() - 1) << "didn't find node name: " << out_key;
      }
    }
    sym = nnvm::Symbol::CreateGroup(out_syms);
  }

  // load the parameters
  std::unordered_map<std::string, NDArray> arg_params, aux_params;
  {
    std::unordered_set<std::string> arg_names, aux_names;
    std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
    std::vector<std::string> aux_names_vec = sym.ListInputNames(Symbol::kAuxiliaryStates);
    for (size_t i = 0; i < arg_names_vec.size(); ++i) {
      arg_names.insert(arg_names_vec[i]);
    }
    for (size_t i = 0; i < aux_names_vec.size(); ++i) {
      aux_names.insert(aux_names_vec[i]);
    }
    std::vector<NDArray> data;
    std::vector<std::string> names;
    dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size);  // NOLINT(*)
    NDArray::Load(&fi, &data, &names);
    CHECK_EQ(names.size(), data.size())
        << "Invalid param file format";
    for (size_t i = 0; i < names.size(); ++i) {
      if (!strncmp(names[i].c_str(), "aux:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (aux_names.count(name) != 0) {
          aux_params[name] = data[i];
        }
      }
      if (!strncmp(names[i].c_str(), "arg:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (arg_names.count(name) != 0) {
          arg_params[name] = data[i];
        }
      }
    }
  }

  // shape inference and bind
  std::unordered_map<std::string, TShape> known_shape;
  for (mx_uint i = 0; i < num_input_nodes; ++i) {
    known_shape[std::string(input_keys[i])] =
        TShape(input_shape_data + input_shape_indptr[i],
               input_shape_data + input_shape_indptr[i + 1]);
  }
  std::vector<std::string> arg_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
  std::vector<TShape> out_shapes(sym.ListOutputNames().size());
  std::vector<TShape> aux_shapes(aux_names.size());
  std::vector<TShape> arg_shapes;
  for (size_t i = 0; i < arg_names.size(); ++i) {
    std::string key = arg_names[i];
    ret->key2arg[key] = i;
  }

  try {
    std::vector<TShape> in_shapes;
    for (std::string key : sym.ListInputNames(Symbol::kAll)) {
      if (known_shape.count(key) != 0) {
        in_shapes.push_back(known_shape[key]);
      } else {
        in_shapes.push_back(TShape());
      }
    }
    nnvm::Graph g; g.outputs = sym.outputs;
    g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__");
    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
    CHECK(infer_complete)
      << "The shape information of is not enough to get the shapes";
    CopyAttr(g.indexed_graph(),
             g.GetAttr<nnvm::ShapeVector>("shape"),
             &arg_shapes, &out_shapes, &aux_shapes);
  } catch (const mxnet::op::InferShapeError &err) {
    throw dmlc::Error(err.msg);
  }

  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);

  std::vector<NDArray> arg_arrays, aux_arrays;
  for (size_t i = 0; i < arg_shapes.size(); ++i) {
    NDArray nd = NDArray(arg_shapes[i], ctx);
    if (arg_params.count(arg_names[i]) != 0) {
      CopyFromTo(arg_params[arg_names[i]], &nd);
    }
    arg_arrays.push_back(nd);
  }
  for (size_t i = 0; i < aux_shapes.size(); ++i) {
    NDArray nd = NDArray(aux_shapes[i], ctx);
    if (aux_params.count(aux_names[i]) != 0) {
      CopyFromTo(aux_params[aux_names[i]], &nd);
    }
    aux_arrays.push_back(nd);
  }
  ret->arg_arrays = arg_arrays;
  // bind
  {
    std::map<std::string, Context> ctx_map;
    std::vector<NDArray> grad_store(arg_arrays.size());
    std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);


    ret->exec.reset(Executor::Bind(sym, ctx, ctx_map,
                                   arg_arrays,
                                   grad_store, grad_req,
                                   aux_arrays));
    ret->out_shapes = out_shapes;
    ret->out_arrays = ret->exec->outputs();
  }
  *out = ret;
  API_END_HANDLE_ERROR(delete ret);
}