in src/imperative/imperative_utils.h [1222:1407]
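/*!
 * \brief Create a single engine operator that runs a bulk of OpExecutors.
 *
 * Read dependencies (use_vars) are collected from each executor's input arrays;
 * write dependencies (mutate_vars) come from its requested resources, its output
 * arrays and its own state variable. `opr_names` becomes the operator's display
 * name (e.g. in profiler output).
 */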
inline Engine::OprHandle CreateEngineOp(
    const Context& default_ctx,
    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
    const char* opr_names) {
  CHECK_GT(execs.size(), 0);
  std::vector<Engine::VarHandle> use_vars, mutate_vars;

  for (const auto& exec : execs) {
    CHECK_GT(exec->out_array.size(), 0);
    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);

    // collect the read/write variable dependencies
    for (const auto& nd : exec->in_array) {
      use_vars.push_back(nd.var());
    }
    for (auto& r : exec->op_ctx.requested) {
      mutate_vars.push_back(r.var);
    }
    for (auto& nd : exec->out_array) {
      mutate_vars.push_back(nd.var());
    }
    if (exec->var() != nullptr) {
      mutate_vars.push_back(exec->var());
    }
  }

  // dedup vars
  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);

  bool is_gpu   = default_ctx.dev_mask() == gpu::kDevMask;
  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;

#if CUDA_GRAPHS_AVAILABLE
  // Provide initialized `cuda_graphs_exec`, which when captured
  // by exec_fun, acts like a static variable inside the mutable closure.
  cuda_graphs::CudaGraphsExec cuda_graphs_exec(execs, is_gpu, opr_names);
  auto exec_fun = [cuda_graphs_exec, execs, is_async, is_gpu](
                      RunContext ctx,
                      Engine::CallbackOnStart on_start,
                      Engine::CallbackOnComplete on_complete) mutable {
    on_start();
    if (is_async) {
      execs[0]->op_ctx.async_on_complete = on_complete;
    }
    // Run all ops in the sub-graph with the CUDA graphs executor if possible
    cuda_graphs_exec.RunAll(execs, ctx, is_gpu);
#else
  auto exec_fun = [execs, is_async, is_gpu](RunContext ctx,
                                            Engine::CallbackOnStart on_start,
                                            Engine::CallbackOnComplete on_complete) {
    on_start();
    if (is_async) {
      execs[0]->op_ctx.async_on_complete = on_complete;
    }
    exec::OpExecutor::RunAll(execs, ctx, is_gpu);
#endif
    // Call on_complete here only for sync ops; async ops invoke it later
    // through op_ctx.async_on_complete.
    if (!is_async) {
      if (is_gpu) {
#if !MXNET_USE_CUDA
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
      }
      on_complete();
    }
  };

  return Engine::Get()->NewOperator(
      exec_fun, use_vars, mutate_vars, FnProperty::kNormal, opr_names);
}
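
/*!
 * \brief Partition graph nodes [start_nid, end_nid) into engine operator segments.
 *
 * Consecutive synchronous nodes are bulked (up to `bulk_size` per segment) into a
 * single engine operator created via CreateEngineOp; async nodes and nodes whose
 * outputs are not yet allocated end the current segment. Each segment carries a
 * human-readable label listing its ops and their attributes.
 */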
inline void CreateEngineOpSeg(const nnvm::IndexedGraph& idx,
                              const Context default_ctx,
                              const size_t start_nid,
                              const size_t end_nid,
                              const size_t bulk_size,
                              const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
                              const std::vector<int> skip_plus_node,
                              std::vector<EngineOprSeg>* opr_segs) {
  size_t seg_start = start_nid;
  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
  std::string opr_names = "[";

  for (size_t nid = start_nid; nid < end_nid; ++nid) {
    const auto& node = idx[nid];
    if (node.source->is_variable())
      continue;
    if (skip_plus_node.size() && skip_plus_node[nid])
      continue;
    auto& exec          = execs[nid];
    const auto& op_name = node.source->op()->name;
    bool is_async = exec->exec_type() != ExecType::kSync;
    bool valid    = exec->out_array.size() > 0;

    // Stop at async nodes, at invalid nodes (inputs/outputs not yet allocated),
    // and when the current bulk reaches bulk_size.
    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;

    // Create an opr segment for the previous nodes.
    if (stop && nid > seg_start) {
      auto& seg = (*opr_segs)[seg_start];
      if (seg_execs.size()) {
        seg = EngineOprSeg{false, nid};
        opr_names.pop_back();
        opr_names += "]";
        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
      } else {
        seg = EngineOprSeg{true, nid, nullptr};
      }
      seg_start = nid;
      seg_execs.clear();
      opr_names = "[";  // restart the bracketed label for the next segment
    }

    seg_execs.push_back(exec);

    // Append "op{name=...;attr=value;...}," for this node to the segment label.
    const auto& inode = idx[nid];
    opr_names += op_name;
    opr_names += "{name=" + inode.source->attrs.name + ";";
    const std::unordered_map<std::string, std::string>& dict = inode.source->attrs.dict;
    auto num_dict_entries = dict.size();
    for (auto& k : dict) {
      opr_names += k.first + "=" + k.second;
      if (--num_dict_entries != 0)
        opr_names += ";";
    }
    opr_names += "},";

    auto& seg = (*opr_segs)[nid];
    if (!valid) {
      seg = EngineOprSeg{false, nid + 1, nullptr};
      seg_execs.clear();
      opr_names = "[";
      seg_start = nid + 1;
    } else if (is_async) {
      // Async nodes always form a segment of their own.
      seg = EngineOprSeg{false, nid + 1};
      opr_names.pop_back();
      opr_names += "]";
      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
      seg_execs.clear();
      opr_names = "[";
      seg_start = nid + 1;
    }
  }

  // The last segment
  if (end_nid > seg_start) {
    auto& seg = (*opr_segs)[seg_start];
    if (seg_execs.size()) {
      seg = EngineOprSeg{false, end_nid};
      opr_names.pop_back();
      opr_names += "]";
      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
    } else {
      seg = EngineOprSeg{true, end_nid, nullptr};
    }
  }
}
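
/*!
 * \brief Execute nodes [node_start, node_end) of an indexed graph through the engine.
 *
 * `arrays` holds the NDArrays backing every graph entry, `array_reqs` the write
 * request per entry and `ref_count` the per-entry reference counts; `callback` and
 * `monitor_all_` are optional operator-monitoring hooks.
 */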
void RunGraph(const bool retain_graph,
              const nnvm::IndexedGraph& idx,
              const std::vector<NDArray*>& arrays,
              size_t node_start,
              size_t node_end,
              std::vector<OpReqType>&& array_reqs,
              std::vector<uint32_t>&& ref_count,
              std::vector<OpStatePtr>* p_states,
              const DispatchModeVector& dispatch_modes,
              bool recording,
              mxnet::ShapeVector* shapes = nullptr,
              const CachedOpMonCallback& callback = nullptr,
              const bool monitor_all_ = false);
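
/*!
 * \brief Variant of RunGraph that takes the default context explicitly, updates
 *        `shapes` while executing, and can bypass the dependency engine when
 *        `skip_engine` is set.
 */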
void NaiveRunGraph(const bool retain_graph,
                   const Context& default_ctx,
                   const nnvm::IndexedGraph& idx,
                   const std::vector<NDArray*>& arrays,
                   size_t node_start,
                   size_t node_end,
                   std::vector<OpReqType>&& array_reqs,
                   std::vector<uint32_t>&& ref_count,
                   std::vector<OpStatePtr>* p_states,
                   const DispatchModeVector& dispatch_modes,
                   bool recording,
                   mxnet::ShapeVector* shapes,
                   const CachedOpMonCallback& callback = nullptr,
                   const bool monitor_all_ = false,
                   const bool skip_engine = false);

}  // namespace imperative