in horovod/mxnet/mpi_ops.cc [139:195]
// Schedules a Horovod collective (op_type) on the MXNet engine.
//
// `input` is read and `output` is written by DoHorovodOperation once the
// engine runs the op. Its parameter block is built by CreateMpiOpsParam and
// released by DeleteMpiOpsParam.
//
// `splits`, when non-null, must be readable on the host. In CUDA builds, a
// non-CPU `splits` is staged into a fresh CPU NDArray via an asynchronous
// device-to-host copy. That staging array's var is then declared as a read
// dependency, so the collective is ordered after the copy.
//
// `root_rank`, `average`, `prescale_factor` and `postscale_factor` are
// forwarded unchanged to CreateMpiOpsParam.
inline void PushHorovodOperation(OperationType op_type, NDArray* input,
                                 NDArray* output, const char* name,
                                 int priority, int root_rank = -1,
                                 bool average = true,
                                 NDArray* splits = nullptr,
                                 double prescale_factor = 1.0,
                                 double postscale_factor = 1.0) {
  auto op_type_name = GetOpTypeName(op_type);
  auto op_name = GetOpName(op_type_name, name);

  // Shallow NDArray copies owned by shared_ptrs. This keeps the NDArray
  // objects alive until the MXNet engine has processed the operation, even if
  // the caller's objects are freed first.
  auto input_copy = std::make_shared<NDArray>(*input);
  auto output_copy = std::make_shared<NDArray>(*output);

  // Produces a CPU-side handle for the splits tensor. Either it is a shallow
  // copy of a CPU array, or (CUDA builds only) it is a new CPU array filled
  // asynchronously from device memory.
  auto host_splits = [](NDArray* src) -> std::shared_ptr<NDArray> {
#if HAVE_CUDA
    if (!IsTensorOnCPU(src)) {
      auto staged = std::make_shared<NDArray>(
          Context::Create(Context::kCPU, 0), src->dtype());
      TensorUtil::AsyncCopyCudaToCPU(src, staged.get());
      return staged;
    }
#endif
    return std::make_shared<NDArray>(*src);
  };
  std::shared_ptr<NDArray> splits_tensor =
      splits != nullptr ? host_splits(splits) : std::shared_ptr<NDArray>();

  auto ops_param = CreateMpiOpsParam(input_copy, output_copy, output,
                                     nullptr /* cpu_input_tensor */,
                                     nullptr /* cpu_output_tensor */,
                                     op_type, op_name, root_rank, average,
                                     splits_tensor, prescale_factor,
                                     postscale_factor);

  auto input_var = input->var();
  auto output_var = output->var();
  const bool in_place = (input_var == output_var);

  // Read-only engine dependencies. For an in-place op the input var is the
  // output var, which is already passed as the single mutable var, so it is
  // not listed again here. The splits var is listed so that the op waits on a
  // possible pending device-to-host copy.
  std::vector<void*> const_vars;
  if (!in_place) {
    const_vars.push_back(input_var);
  }
  if (splits_tensor) {
    const_vars.push_back(splits_tensor->var());
  }

  MXEnginePushAsync(DoHorovodOperation, ops_param, DeleteMpiOpsParam,
                    &MX_EXEC_CTX, const_vars.data(), const_vars.size(),
                    &output_var, 1, &MX_FUNC_PROP, priority, op_type_name);
}