in src/c_api/c_predict_api.cc [71:224]
int MXPredCreatePartialOut(const char* symbol_json_str,
const void* param_bytes,
int param_size,
int dev_type, int dev_id,
mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
mx_uint num_output_nodes,
const char** output_keys,
PredictorHandle* out) {
using nnvm::Symbol;
MXAPIPredictor* ret = new MXAPIPredictor();
API_BEGIN();
Symbol sym;
// make sure symbols are registered
{
mx_uint outSize;
const char **outArray;
MXListAllOpNames(&outSize, &outArray);
}
// load in the symbol.
{
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
}
// looks likely to output the internal results
if (num_output_nodes != 0) {
Symbol internal = sym.GetInternals();
std::vector<std::string> all_out = internal.ListOutputNames();
std::vector<Symbol> out_syms(num_output_nodes);
for (mx_uint i = 0; i < num_output_nodes; ++i) {
std::string out_key(output_keys[i]);
out_key += "_output";
for (size_t j = 0; j < all_out.size(); ++j) {
if (all_out[j] == out_key) {
out_syms[i] = internal[j];
break;
}
CHECK_NE(j, all_out.size() - 1) << "didn't find node name: " << out_key;
}
}
sym = nnvm::Symbol::CreateGroup(out_syms);
}
// load the parameters
std::unordered_map<std::string, NDArray> arg_params, aux_params;
{
std::unordered_set<std::string> arg_names, aux_names;
std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
std::vector<std::string> aux_names_vec = sym.ListInputNames(Symbol::kAuxiliaryStates);
for (size_t i = 0; i < arg_names_vec.size(); ++i) {
arg_names.insert(arg_names_vec[i]);
}
for (size_t i = 0; i < aux_names_vec.size(); ++i) {
aux_names.insert(aux_names_vec[i]);
}
std::vector<NDArray> data;
std::vector<std::string> names;
dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size); // NOLINT(*)
NDArray::Load(&fi, &data, &names);
CHECK_EQ(names.size(), data.size())
<< "Invalid param file format";
for (size_t i = 0; i < names.size(); ++i) {
if (!strncmp(names[i].c_str(), "aux:", 4)) {
std::string name(names[i].c_str() + 4);
if (aux_names.count(name) != 0) {
aux_params[name] = data[i];
}
}
if (!strncmp(names[i].c_str(), "arg:", 4)) {
std::string name(names[i].c_str() + 4);
if (arg_names.count(name) != 0) {
arg_params[name] = data[i];
}
}
}
}
// shape inference and bind
std::unordered_map<std::string, TShape> known_shape;
for (mx_uint i = 0; i < num_input_nodes; ++i) {
known_shape[std::string(input_keys[i])] =
TShape(input_shape_data + input_shape_indptr[i],
input_shape_data + input_shape_indptr[i + 1]);
}
std::vector<std::string> arg_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
std::vector<std::string> aux_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
std::vector<TShape> out_shapes(sym.ListOutputNames().size());
std::vector<TShape> aux_shapes(aux_names.size());
std::vector<TShape> arg_shapes;
for (size_t i = 0; i < arg_names.size(); ++i) {
std::string key = arg_names[i];
ret->key2arg[key] = i;
}
try {
std::vector<TShape> in_shapes;
for (std::string key : sym.ListInputNames(Symbol::kAll)) {
if (known_shape.count(key) != 0) {
in_shapes.push_back(known_shape[key]);
} else {
in_shapes.push_back(TShape());
}
}
nnvm::Graph g; g.outputs = sym.outputs;
g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__");
bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
CHECK(infer_complete)
<< "The shape information of is not enough to get the shapes";
CopyAttr(g.indexed_graph(),
g.GetAttr<nnvm::ShapeVector>("shape"),
&arg_shapes, &out_shapes, &aux_shapes);
} catch (const mxnet::op::InferShapeError &err) {
throw dmlc::Error(err.msg);
}
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
std::vector<NDArray> arg_arrays, aux_arrays;
for (size_t i = 0; i < arg_shapes.size(); ++i) {
NDArray nd = NDArray(arg_shapes[i], ctx);
if (arg_params.count(arg_names[i]) != 0) {
CopyFromTo(arg_params[arg_names[i]], &nd);
}
arg_arrays.push_back(nd);
}
for (size_t i = 0; i < aux_shapes.size(); ++i) {
NDArray nd = NDArray(aux_shapes[i], ctx);
if (aux_params.count(aux_names[i]) != 0) {
CopyFromTo(aux_params[aux_names[i]], &nd);
}
aux_arrays.push_back(nd);
}
ret->arg_arrays = arg_arrays;
// bind
{
std::map<std::string, Context> ctx_map;
std::vector<NDArray> grad_store(arg_arrays.size());
std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
ret->exec.reset(Executor::Bind(sym, ctx, ctx_map,
arg_arrays,
grad_store, grad_req,
aux_arrays));
ret->out_shapes = out_shapes;
ret->out_arrays = ret->exec->outputs();
}
*out = ret;
API_END_HANDLE_ERROR(delete ret);
}