inline Engine::OprHandle CreateEngineOp()

in src/imperative/imperative_utils.h [1222:1407]


inline Engine::OprHandle CreateEngineOp(
    const Context& default_ctx,
    const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
    const char* opr_names) {
  CHECK_GT(execs.size(), 0);
  std::vector<Engine::VarHandle> use_vars, mutate_vars;

  for (const auto& exec : execs) {
    CHECK_GT(exec->out_array.size(), 0);
    CHECK(execs.size() == 1 || exec->exec_type() == ExecType::kSync);

    // Collect engine dependencies: inputs are read-only vars; outputs,
    // requested resources, and the executor's own var are mutated.
    for (const auto& nd : exec->in_array) {
      use_vars.push_back(nd.var());
    }
    for (auto& r : exec->op_ctx.requested) {
      mutate_vars.push_back(r.var);
    }
    for (auto& nd : exec->out_array) {
      mutate_vars.push_back(nd.var());
    }
    if (exec->var() != nullptr) {
      mutate_vars.push_back(exec->var());
    }
  }

  // Deduplicate var handles; vars that appear in both lists are kept
  // only in mutate_vars.
  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
  bool is_gpu   = default_ctx.dev_mask() == gpu::kDevMask;
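  // Only a single-executor segment may be asynchronous; multi-executor
  // (bulked) segments were checked above to be ExecType::kSync.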
  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;

#if CUDA_GRAPHS_AVAILABLE
  // Provide an initialized `cuda_graphs_exec`, which, when captured
  // by exec_fun, acts like a static variable inside the mutable closure.
  cuda_graphs::CudaGraphsExec cuda_graphs_exec(execs, is_gpu, opr_names);
  auto exec_fun = [cuda_graphs_exec, execs, is_async, is_gpu](
                      RunContext ctx,
                      Engine::CallbackOnStart on_start,
                      Engine::CallbackOnComplete on_complete) mutable {
    on_start();
    if (is_async) {
      execs[0]->op_ctx.async_on_complete = on_complete;
    }
    // Run all ops in the sub-graph with the CUDA graphs executor when possible
    cuda_graphs_exec.RunAll(execs, ctx, is_gpu);
#else
  auto exec_fun = [execs, is_async, is_gpu](RunContext ctx,
                                            Engine::CallbackOnStart on_start,
                                            Engine::CallbackOnComplete on_complete) {
    on_start();
    if (is_async) {
      execs[0]->op_ctx.async_on_complete = on_complete;
    }
    exec::OpExecutor::RunAll(execs, ctx, is_gpu);
#endif
    // Sync ops signal completion here; async ops do so later
    // via async_on_complete.
    if (!is_async) {
      if (is_gpu) {
#if !MXNET_USE_CUDA
        LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
      }
      on_complete();
    }
  };

  return Engine::Get()->NewOperator(
      exec_fun, use_vars, mutate_vars, FnProperty::kNormal, opr_names);
}
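
The closure above implements the engine's completion contract: a synchronous
op invokes on_complete itself before returning, while an asynchronous op hands
the callback to its executor through async_on_complete and returns immediately.
A minimal standalone sketch of that contract, using hypothetical MockExec,
Callback, and RunWrapped stand-ins rather than the real engine types:

#include <functional>
#include <iostream>

// Hypothetical stand-ins for Engine::CallbackOnStart/CallbackOnComplete.
using Callback = std::function<void()>;

struct MockExec {
  bool is_async = false;
  Callback async_on_complete;  // set by the wrapper for async ops
  void Run() { std::cout << "running op\n"; }
};

// Mirrors exec_fun: async ops defer completion to the executor,
// sync ops signal it before returning.
void RunWrapped(MockExec* exec, Callback on_start, Callback on_complete) {
  on_start();
  if (exec->is_async) {
    exec->async_on_complete = on_complete;  // fired later by the executor
  }
  exec->Run();
  if (!exec->is_async) {
    on_complete();  // sync path: completion is signaled here
  }
}

int main() {
  MockExec sync_op;
  RunWrapped(&sync_op, [] { std::cout << "start\n"; },
             [] { std::cout << "complete\n"; });

  MockExec async_op;
  async_op.is_async = true;
  RunWrapped(&async_op, [] { std::cout << "start\n"; },
             [] { std::cout << "complete (deferred)\n"; });
  async_op.async_on_complete();  // the async executor would call this when done
}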

inline void CreateEngineOpSeg(const nnvm::IndexedGraph& idx,
                              const Context default_ctx,
                              const size_t start_nid,
                              const size_t end_nid,
                              const size_t bulk_size,
                              const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
                              const std::vector<int>& skip_plus_node,
                              std::vector<EngineOprSeg>* opr_segs) {
  size_t seg_start = start_nid;
  std::vector<std::shared_ptr<exec::OpExecutor> > seg_execs;
  std::string opr_names = "[";
  for (size_t nid = start_nid; nid < end_nid; ++nid) {
    const auto& node = idx[nid];
    if (node.source->is_variable())
      continue;
    if (skip_plus_node.size() && skip_plus_node[nid])
      continue;
    auto& exec          = execs[nid];
    const auto& op_name = node.source->op()->name;
    bool is_async       = exec->exec_type() != ExecType::kSync;
    bool valid          = exec->out_array.size() > 0;

    // Stop at async or invalid nodes (inputs/outputs not yet allocated),
    // or once the segment has reached bulk_size
    bool stop = is_async || !valid || seg_execs.size() >= bulk_size;

    // Create opr segment for previous nodes.
    if (stop && nid > seg_start) {
      auto& seg = (*opr_segs)[seg_start];
      if (seg_execs.size()) {
        seg = EngineOprSeg{false, nid};
        opr_names.pop_back();
        opr_names += "]";
        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
      } else {
        seg = EngineOprSeg{true, nid, nullptr};
      }
      seg_start = nid;
      seg_execs.clear();
      opr_names = "[";
    }

    seg_execs.push_back(exec);

    // Append a profiler-friendly description of this op, e.g.
    // "Activation{name=relu0;act_type=relu},"; the trailing ',' of the
    // last op is replaced by ']' when the segment closes.
    opr_names += op_name;
    opr_names += "{name=" + node.source->attrs.name + ";";
    const std::unordered_map<std::string, std::string>& dict = node.source->attrs.dict;
    auto num_dict_entries                                    = dict.size();
    for (auto& k : dict) {
      opr_names += k.first + "=" + k.second;
      if (--num_dict_entries != 0)
        opr_names += ";";
    }
    opr_names += "},";

    auto& seg = (*opr_segs)[nid];
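    // An invalid or async node closes out its own single-node segment.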
    if (!valid) {
      seg = EngineOprSeg{false, nid + 1, nullptr};
      seg_execs.clear();
      opr_names = "[";
      seg_start = nid + 1;
    } else if (is_async) {
      seg = EngineOprSeg{false, nid + 1};
      opr_names.pop_back();
      opr_names += "]";
      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
      seg_execs.clear();
      opr_names = "[";
      seg_start = nid + 1;
    }
  }
  // The last segment
  if (end_nid > seg_start) {
    auto& seg = (*opr_segs)[seg_start];
    if (seg_execs.size()) {
      seg = EngineOprSeg{false, end_nid};
      opr_names.pop_back();
      opr_names += "]";
      seg.opr.reset(CreateEngineOp(default_ctx, seg_execs, opr_names.c_str()));
    } else {
      seg = EngineOprSeg{true, end_nid, nullptr};
    }
  }
}
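
The segmentation policy above is easier to see in isolation: greedily bulk
consecutive sync, valid nodes up to bulk_size, and cut a single-node segment
at every async or invalid node. A minimal sketch of that policy, using
hypothetical Node, Seg, and Segment stand-ins rather than the MXNet graph
structures:

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for the graph nodes and opr segments above.
struct Node {
  bool is_async = false;  // exec_type() != ExecType::kSync
  bool valid    = true;   // out_array.size() > 0
};
struct Seg {
  size_t begin, end;  // node id range [begin, end)
};

// Same policy as CreateEngineOpSeg: close the running segment when we hit
// an async node, an invalid node, or the bulk-size limit; async and invalid
// nodes each get a segment of their own.
std::vector<Seg> Segment(const std::vector<Node>& nodes, size_t bulk_size) {
  std::vector<Seg> segs;
  size_t seg_start = 0, seg_len = 0;
  for (size_t nid = 0; nid < nodes.size(); ++nid) {
    const Node& n = nodes[nid];
    bool stop = n.is_async || !n.valid || seg_len >= bulk_size;
    if (stop && nid > seg_start) {  // flush the segment built so far
      segs.push_back({seg_start, nid});
      seg_start = nid;
      seg_len   = 0;
    }
    ++seg_len;
    if (!n.valid || n.is_async) {  // this node ends its own segment
      segs.push_back({seg_start, nid + 1});
      seg_start = nid + 1;
      seg_len   = 0;
    }
  }
  if (nodes.size() > seg_start)  // the last segment
    segs.push_back({seg_start, nodes.size()});
  return segs;
}

int main() {
  std::vector<Node> nodes(6);
  nodes[3].is_async = true;  // e.g. a cross-device copy
  for (const Seg& s : Segment(nodes, /*bulk_size=*/4))
    std::cout << "[" << s.begin << ", " << s.end << ")\n";
  // prints: [0, 3) [3, 4) [4, 6)
}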

void RunGraph(const bool retain_graph,
              const nnvm::IndexedGraph& idx,
              const std::vector<NDArray*>& arrays,
              size_t node_start,
              size_t node_end,
              std::vector<OpReqType>&& array_reqs,
              std::vector<uint32_t>&& ref_count,
              std::vector<OpStatePtr>* p_states,
              const DispatchModeVector& dispatch_modes,
              bool recording,
              mxnet::ShapeVector* shapes          = nullptr,
              const CachedOpMonCallback& callback = nullptr,
              const bool monitor_all_             = false);

void NaiveRunGraph(const bool retain_graph,
                   const Context& default_ctx,
                   const nnvm::IndexedGraph& idx,
                   const std::vector<NDArray*>& arrays,
                   size_t node_start,
                   size_t node_end,
                   std::vector<OpReqType>&& array_reqs,
                   std::vector<uint32_t>&& ref_count,
                   std::vector<OpStatePtr>* p_states,
                   const DispatchModeVector& dispatch_modes,
                   bool recording,
                   mxnet::ShapeVector* shapes,
                   const CachedOpMonCallback& callback = nullptr,
                   const bool monitor_all_             = false,
                   const bool skip_engine              = false);

}  // namespace imperative