inline void PushHorovodOperation()

in horovod/mxnet/mpi_ops.cc [139:195]


// Queues a Horovod operation of kind `op_type` on the MXNet engine through
// MXEnginePushAsync, with DoHorovodOperation as the async function and
// DeleteMpiOpsParam as the parameter deleter.
//
// Parameters:
//   op_type           operation kind; also used to derive the operator names.
//   input, output     source / destination arrays. When both report the same
//                     engine var the call is treated as in-place.
//   name              caller-supplied name, combined with the op type name
//                     by GetOpName.
//   priority          forwarded to MXEnginePushAsync.
//   root_rank, average, prescale_factor, postscale_factor
//                     forwarded unchanged to CreateMpiOpsParam.
//   splits            optional; when non-null it is added to the read
//                     dependencies. In HAVE_CUDA builds a non-CPU splits
//                     array is first copied asynchronously to a CPU array.
inline void PushHorovodOperation(OperationType op_type, NDArray* input,
                                 NDArray* output, const char* name,
                                 int priority, int root_rank = -1,
                                 bool average = true,
                                 NDArray* splits = nullptr,
                                 double prescale_factor = 1.0,
                                 double postscale_factor = 1.0) {
  auto type_label = GetOpTypeName(op_type);
  auto full_op_name = GetOpName(type_label, name);

  // Hold shallow NDArray copies in shared_ptrs so the arrays are not freed
  // before the MXNet engine has processed the request.
  auto held_input = std::make_shared<NDArray>(*input);
  auto held_output = std::make_shared<NDArray>(*output);

  std::shared_ptr<NDArray> held_splits;
  if (splits != nullptr) {
#if HAVE_CUDA
    if (IsTensorOnCPU(splits)) {
      held_splits = std::make_shared<NDArray>(*splits);
    } else {
      // splits is expected to live on the CPU; stage it into a new CPU array.
      held_splits = std::make_shared<NDArray>(
          Context::Create(Context::kCPU, 0), splits->dtype());
      TensorUtil::AsyncCopyCudaToCPU(splits, held_splits.get());
    }
#else
    held_splits = std::make_shared<NDArray>(*splits);
#endif
  }

  auto op_param = CreateMpiOpsParam(
      held_input, held_output, output,
      nullptr /* cpu_input_tensor */, nullptr /* cpu_output_tensor */,
      op_type, full_op_name, root_rank, average, held_splits,
      prescale_factor, postscale_factor);

  auto read_var = input->var();
  auto write_var = output->var();

  // Read dependencies handed to the engine. For an in-place call the shared
  // var is listed only as the mutable var, never as a read dependency.
  std::vector<void*> read_deps;
  if (read_var != write_var) {
    read_deps.push_back(read_var);
  }
  if (splits != nullptr) {
    // Listing the splits var enforces ordering after a possible async
    // device-to-host copy issued above.
    read_deps.push_back(held_splits->var());
  }

  MXEnginePushAsync(DoHorovodOperation, op_param, DeleteMpiOpsParam,
                    &MX_EXEC_CTX, read_deps.data(), read_deps.size(),
                    &write_var, 1, &MX_FUNC_PROP, priority, type_label);
}