int MXExecutorSimpleBind()

in src/c_api/c_api_executor.cc [224:540]
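
A minimal calling sketch (not part of the source file): it assumes the caller already holds a valid SymbolHandle for a network whose single data input is named "data", binds on cpu(0) using the single-string grad_req form, provides one input shape through the flattened shape-data/offset arrays, and passes shared_buffer_len = -1 to disable the shared data buffer. The helper name SimpleBindExample and the concrete shape values are illustrative assumptions, not part of the MXNet API.

#include <mxnet/c_api.h>

// Hypothetical helper: bind `sym` on cpu(0) with grad_req="write" and no shared buffer.
ExecutorHandle SimpleBindExample(SymbolHandle sym) {
  // One provided shape: data -> (32, 3, 224, 224), encoded as flat shape data plus
  // prefix offsets (shape i spans shape_data[shape_idx[i], shape_idx[i+1])).
  const char* shape_names[] = {"data"};
  const mx_uint shape_data[] = {32, 3, 224, 224};
  const mx_uint shape_idx[]  = {0, 4};

  const char* grad_req = "write";   // single-string form: applies to all input arguments
  int shared_buffer_len = -1;       // negative length => no shared data buffer
  const char** updated_buffer_names = nullptr;
  NDArrayHandle* updated_buffer_handles = nullptr;

  mx_uint num_in_args = 0, num_aux_states = 0;
  NDArrayHandle *in_args = nullptr, *arg_grads = nullptr, *aux_states = nullptr;
  ExecutorHandle exec = nullptr;

  int rc = MXExecutorSimpleBind(sym,
                                /*dev_type=*/1, /*dev_id=*/0,   // 1 == kCPU
                                /*num_g2c_keys=*/0, nullptr, nullptr, nullptr,
                                /*provided_grad_req_list_len=*/0, nullptr, &grad_req,
                                /*num_provided_arg_shapes=*/1, shape_names, shape_data, shape_idx,
                                /*num_provided_arg_dtypes=*/0, nullptr, nullptr,
                                /*num_provided_arg_stypes=*/0, nullptr, nullptr,
                                /*num_shared_arg_names=*/0, nullptr,
                                &shared_buffer_len, nullptr, nullptr,
                                &updated_buffer_names, &updated_buffer_handles,
                                &num_in_args, &in_args, &arg_grads,
                                &num_aux_states, &aux_states,
                                /*shared_exec_handle=*/nullptr, &exec);
  return rc == 0 ? exec : nullptr;
}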


int MXExecutorSimpleBind(SymbolHandle symbol_handle,
                         int dev_type,
                         int dev_id,
                         const mx_uint num_g2c_keys,
                         const char** g2c_keys,
                         const int* g2c_dev_types,
                         const int* g2c_dev_ids,
                         const mx_uint provided_grad_req_list_len,
                         const char** provided_grad_req_names,
                         const char** provided_grad_req_types,
                         const mx_uint num_provided_arg_shapes,
                         const char** provided_arg_shape_names,
                         const mx_uint* provided_arg_shape_data,
                         const mx_uint* provided_arg_shape_idx,
                         const mx_uint num_provided_arg_dtypes,
                         const char** provided_arg_dtype_names,
                         const int* provided_arg_dtypes,
                         const mx_uint num_provided_arg_stypes,
                         const char** provided_arg_stype_names,
                         const int* provided_arg_stypes,
                         const mx_uint num_shared_arg_names,
                         const char** shared_arg_name_list,
                         int* shared_buffer_len,
                         const char** shared_buffer_name_list,
                         NDArrayHandle* shared_buffer_handle_list,
                         const char*** updated_shared_buffer_name_list,
                         NDArrayHandle** updated_shared_buffer_handle_list,
                         mx_uint* num_in_args,
                         NDArrayHandle** in_args,
                         NDArrayHandle** arg_grads,
                         mx_uint* num_aux_states,
                         NDArrayHandle** aux_states,
                         ExecutorHandle shared_exec_handle,
                         ExecutorHandle* out) {
  MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
  API_BEGIN();
  nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);

  // get in_arg and aux_state names
  std::vector<std::string> in_arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_state_names = sym->ListInputNames(nnvm::Symbol::kAuxiliaryStates);

  // attr_dict for setting up type_dict and arg/aux ctx
  std::unordered_map<std::string, std::unordered_map<std::string, std::string>> attr_dict;
  if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) {
    std::vector<std::tuple<std::string, std::string, std::string>> attrs =
      sym->ListAttrsRecursive();
    attr_dict.reserve(attrs.size());
    for (const auto& tp : attrs) {
      attr_dict[std::get<0>(tp)][std::get<1>(tp)] = std::get<2>(tp);
    }
  }

  // setup arg_dtype_map
  std::unordered_map<std::string, int> arg_dtype_map;
  if (nullptr == provided_arg_dtypes) {  // use attr_dict
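    // default to float32 only for arguments whose __dtype__ attribute is not set in the graph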
    for (const auto& arg_name : in_arg_names) {
      const auto it = attr_dict.find(arg_name);
      if (it == attr_dict.end() || !it->second.count("__dtype__")) {
        arg_dtype_map[arg_name] = mshadow::kFloat32;
      }
    }
  } else {  // use user input type_dict
    // create dtype map for in_args and aux_states
    arg_dtype_map.reserve(num_provided_arg_dtypes);
    for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) {
      arg_dtype_map[provided_arg_dtype_names[i]] = provided_arg_dtypes[i];
    }
  }

  // setup arg_stype_map
  std::unordered_map<std::string, int> arg_stype_map;
  if (nullptr == provided_arg_stypes) {  // use attr_dict
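    // default to kDefaultStorage (dense) only for arguments whose __storage_type__ attribute is not set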
    for (const auto& arg_name : in_arg_names) {
      const auto it = attr_dict.find(arg_name);
      if (it == attr_dict.end() || !it->second.count("__storage_type__")) {
        arg_stype_map[arg_name] = kDefaultStorage;
      }
    }
  } else {  // use user input stype_dict
    // create stype map for in_args and aux_states
    arg_stype_map.reserve(num_provided_arg_stypes);
    for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) {
      arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i];
    }
  }

  // create default ctx
  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
  // create ctx map
  std::map<std::string, Context> ctx_map;
  std::vector<Context> in_arg_ctx_vec(in_arg_names.size(), ctx);
  std::vector<Context> aux_state_ctx_vec(aux_state_names.size(), ctx);
  if (nullptr != g2c_keys) {  // use user input group2ctx dict
    for (mx_uint i = 0; i < num_g2c_keys; ++i) {
      ctx_map[g2c_keys[i]] = Context::Create(
          static_cast<Context::DeviceType>(g2c_dev_types[i]), g2c_dev_ids[i]);
    }

    // initialize in_arg_ctx_vec using group2ctx if there are any
    for (size_t i = 0; i < in_arg_ctx_vec.size(); ++i) {
      const auto it1 = attr_dict.find(in_arg_names[i]);
      if (it1 != attr_dict.end()) {
        const auto it2 = it1->second.find("__ctx_group__");
        if (it2 != it1->second.end()) {
          const auto it3 = ctx_map.find(it2->second);
          if (it3 != ctx_map.end()) {
            in_arg_ctx_vec[i] = it3->second;
          }
        }
      }
    }

    // initialize aux_state_ctx_vec using group2ctx if there are any
    for (size_t i = 0; i < aux_state_ctx_vec.size(); ++i) {
      const auto it1 = attr_dict.find(aux_state_names[i]);
      if (it1 != attr_dict.end()) {
        const auto it2 = it1->second.find("__ctx_group__");
        if (it2 != it1->second.end()) {
          const auto it3 = ctx_map.find(it2->second);
          if (it3 != ctx_map.end()) {
            aux_state_ctx_vec[i] = it3->second;
          }
        }
      }
    }
  }

  // create provided_grad_req_map
  const std::map<std::string, OpReqType> req_map =
    {{"null", kNullOp}, {"write", kWriteTo}, {"add", kAddTo}};
  std::unordered_map<std::string, std::string> provided_grad_req_map;
  std::string grad_req_type;
  if (0 == provided_grad_req_list_len
      && nullptr == provided_grad_req_names
      && nullptr != provided_grad_req_types) {  // string, grad_req='write'
    CHECK_EQ(req_map.count(provided_grad_req_types[0]), 1U)
      << "grad_req=" << provided_grad_req_types[0] << " is not a valid input in simple_bind; "
      "only \'null\', \'write\', and \'add\' are supported";
    grad_req_type = "string";
  } else if (provided_grad_req_list_len > 0
      && nullptr == provided_grad_req_names
      && nullptr != provided_grad_req_types) {  // list, grad_req=['null', 'write']
    grad_req_type = "list";
    CHECK_EQ(provided_grad_req_list_len, in_arg_names.size())
      << "The length of the grad_req list does not match the number of input arguments in "
      "simple_bind; expected " << in_arg_names.size() << ", provided " << provided_grad_req_list_len;
  } else if (provided_grad_req_list_len > 0
      && nullptr != provided_grad_req_names
      && nullptr != provided_grad_req_types) {  // dict, grad_req={'lhs': 'null', 'rhs': 'write'}
    grad_req_type = "dict";
    provided_grad_req_map.reserve(provided_grad_req_list_len);
    for (mx_uint i = 0; i < provided_grad_req_list_len; ++i) {
      CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U)
        << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; "
        "only \'null\', \'write\', and \'add\' are supported";
      provided_grad_req_map[provided_grad_req_names[i]] = provided_grad_req_types[i];
    }
  } else {  // grad_req is None
    grad_req_type = "none";
  }

  // initialize arg_grad_ctx_vec and grad_req_type_vec
  std::vector<Context> arg_grad_ctx_vec(in_arg_names.size(), ctx);
  std::vector<OpReqType> grad_req_type_vec(in_arg_names.size(), kNullOp);
  if ("none" != grad_req_type) {
    for (size_t i = 0; i < in_arg_names.size(); ++i) {
      OpReqType cur_req = kNullOp;
      if ("string" == grad_req_type) {
        cur_req = req_map.at(provided_grad_req_types[0]);
      } else if ("list" == grad_req_type) {
        CHECK_EQ(req_map.count(provided_grad_req_types[i]), 1U)
          << "grad_req=" << provided_grad_req_types[i] << " is not a valid input in simple_bind; "
          "only \'null\', \'write\', and \'add\' are supported";
        cur_req = req_map.at(provided_grad_req_types[i]);
      } else if ("dict" == grad_req_type) {
        const auto it = provided_grad_req_map.find(in_arg_names[i]);
        if (it != provided_grad_req_map.end()) {
          cur_req = req_map.at(it->second);
        }
      }
      if (kNullOp != cur_req) {
        arg_grad_ctx_vec[i] = in_arg_ctx_vec[i];
        grad_req_type_vec[i] = static_cast<OpReqType>(cur_req);
      }
    }
  }

  // create shape map for in_args and aux_states
  std::unordered_map<std::string, TShape> arg_shape_map(num_provided_arg_shapes);
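  // provided_arg_shape_idx stores prefix offsets into the flattened shape data:
  // the i-th shape spans provided_arg_shape_data[idx[i], idx[i+1])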
  for (mx_uint i = 0; i < num_provided_arg_shapes; ++i) {
    auto p = arg_shape_map.emplace(provided_arg_shape_names[i],
        TShape(provided_arg_shape_data+provided_arg_shape_idx[i],
          provided_arg_shape_data+provided_arg_shape_idx[i+1]));
    CHECK(p.second) << "Duplicate shapes are provided for argument "
      << provided_arg_shape_names[i] << " in simple_bind";
  }

  // create param name set for sharing data array memory
  std::unordered_set<std::string> shared_arg_name_set(num_shared_arg_names);
  for (mx_uint i = 0; i < num_shared_arg_names; ++i) {
    shared_arg_name_set.insert(shared_arg_name_list[i]);
  }

  // create shared_buffer_map
  std::unordered_map<std::string, NDArray> shared_buffer_map;
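  // a negative *shared_buffer_len indicates that the caller did not provide a shared buffer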
  bool use_shared_buffer = (*shared_buffer_len >= 0);
  if (*shared_buffer_len > 0) {
    // create shared_buffer_map
    shared_buffer_map.reserve(*shared_buffer_len);
    NDArray** shared_buffer_ptrs =
      reinterpret_cast<NDArray**>(shared_buffer_handle_list);
    for (int i = 0; i < *shared_buffer_len; ++i) {
      shared_buffer_map[shared_buffer_name_list[i]] = *(shared_buffer_ptrs[i]);
    }
  }

  // create temporary placeholders for the initialized NDArrays
  // to be passed back to the front end
  std::vector<NDArray> in_arg_vec;
  std::vector<NDArray> arg_grad_vec;
  std::vector<NDArray> aux_state_vec;
#if MXNET_USE_TENSORRT
  // If MXNet was built with TensorRT support, users can opt in to a TensorRT-backed
  // executor (TrtGraphExecutor) via the MXNET_USE_TENSORRT env var, which is also
  // useful for A/B performance testing.
  if (dmlc::GetEnv("MXNET_USE_TENSORRT", false)) {
    *out = exec::TrtGraphExecutor::TensorRTBind(*sym, ctx, ctx_map, &in_arg_ctx_vec,
                                                &arg_grad_ctx_vec, &aux_state_ctx_vec,
                                                &arg_shape_map, &arg_dtype_map, &arg_stype_map,
                                                &grad_req_type_vec, shared_arg_name_set,
                                                &in_arg_vec, &arg_grad_vec, &aux_state_vec,
                                                use_shared_buffer ? &shared_buffer_map : nullptr,
                                                reinterpret_cast<Executor*>(shared_exec_handle));
  } else {
    // If the user built MXNet with TensorRT support but has not enabled TRT at
    // inference time, warn them and describe how to turn it on.
    const int unset_indicator = 0;  // dmlc::GetEnv returns this default when the env var is unset
    if (dmlc::GetEnv("MXNET_USE_TENSORRT", unset_indicator) == unset_indicator) {
      LOG(INFO) << "TensorRT not enabled by default. Please set the MXNET_USE_TENSORRT "
                   "environment variable to 1 or call mx.contrib.tensorrt.set_use_tensorrt(True) "
                   "to enable.";
    }
#endif  // MXNET_USE_TENSORRT
    *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec,
                                aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map,
                                grad_req_type_vec, shared_arg_name_set, &in_arg_vec,
                                &arg_grad_vec, &aux_state_vec,
                                use_shared_buffer ? &shared_buffer_map : nullptr,
                                reinterpret_cast<Executor*>(shared_exec_handle));
#if MXNET_USE_TENSORRT
  }
#endif  // MXNET_USE_TENSORRT

  // copy NDArray ptrs into ret->ret_handles so that the front end
  // can access them
  ret->ret_handles.clear();
  ret->ret_handles.reserve(in_arg_vec.size()+arg_grad_vec.size()+aux_state_vec.size()
                           +shared_buffer_map.size());
  size_t nd_idx = 0;
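  // nd_idx marks where each output segment (in_args, arg_grads, aux_states) starts in ret_handles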
  for (const auto& nd : in_arg_vec) {
    if (nd.is_none()) {
      LOG(FATAL) << "Input argument NDArray cannot be un-allocated";
    }
    ret->ret_handles.push_back(new NDArray(nd));
  }
  if (in_arg_vec.size() > 0) {
    *num_in_args = in_arg_vec.size();
    *in_args = &(ret->ret_handles[nd_idx]);
    nd_idx = ret->ret_handles.size();
  }

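  // gradient handles are optional: a nullptr is returned for arguments whose grad_req is 'null'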
  for (const auto& nd : arg_grad_vec) {
    if (nd.is_none()) {
      ret->ret_handles.push_back(nullptr);
    } else {
      ret->ret_handles.push_back(new NDArray(nd));
    }
  }
  if (arg_grad_vec.size() > 0) {
    *arg_grads = &(ret->ret_handles[nd_idx]);
    nd_idx = ret->ret_handles.size();
  }

  for (const auto& nd : aux_state_vec) {
    if (nd.is_none()) {
      LOG(FATAL) << "Auxiliary argument NDArray cannot be un-allocated";
    }
    ret->ret_handles.push_back(new NDArray(nd));
  }
  if (aux_state_vec.size() > 0) {
    *num_aux_states = aux_state_vec.size();
    *aux_states = &(ret->ret_handles[nd_idx]);
    nd_idx = ret->ret_handles.size();
  }

  if (use_shared_buffer) {
    ret->ret_vec_str.clear();
    ret->ret_vec_str.reserve(shared_buffer_map.size());
    ret->ret_vec_charp.clear();
    ret->ret_vec_charp.reserve(shared_buffer_map.size());
    for (const auto& kv : shared_buffer_map) {
      if (kv.second.is_none()) {
        LOG(FATAL) << "Shared data NDArray cannot be un-allocated";
      }
      ret->ret_handles.push_back(new NDArray(kv.second));
      ret->ret_vec_str.emplace_back(kv.first);
      ret->ret_vec_charp.push_back(ret->ret_vec_str.back().c_str());
    }
    *shared_buffer_len = shared_buffer_map.size();
    *updated_shared_buffer_handle_list = &(ret->ret_handles[nd_idx]);
    *updated_shared_buffer_name_list = &(ret->ret_vec_charp[0]);
  }

  API_END();
}