void DoHorovodOperation()

in horovod/mxnet/mpi_ops.cc [75:137]


void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
  ThrowIfError(common::CheckInitialized());

  auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
  auto ops_param = static_cast<MpiOpsParam*>(param);
  auto input_tensor = ops_param->input_tensor.get();
  auto output_tensor = ops_param->output_tensor.get();
  auto output = ops_param->output;
  auto name = ops_param->op_name;
  auto average = ops_param->average;
  auto prescale_factor = ops_param->prescale_factor;
  auto postscale_factor = ops_param->postscale_factor;
  auto device = TensorUtil::GetDevice(input_tensor);

  auto hvd_tensor = std::make_shared<MXTensor>(input_tensor);
  auto hvd_context = std::make_shared<MXOpContext>(device, output);  
  std::shared_ptr<Tensor> hvd_output = nullptr;  

  Status enqueue_result;
  switch (ops_param->op_type) {
    case OperationType::ALLREDUCE:
      hvd_output = std::make_shared<MXTensor>(output_tensor);
      enqueue_result = EnqueueTensorAllreduce(
          hvd_context, hvd_tensor, hvd_output, nullptr, name, device,
          [on_complete](const Status& status) {
            InvokeCompleteCallback(on_complete, status);
      }, (average) ? ReduceOp::AVERAGE : ReduceOp::SUM, prescale_factor, postscale_factor);
      break;
    case OperationType::ALLGATHER:
      enqueue_result = EnqueueTensorAllgather(
          hvd_context, hvd_tensor, nullptr, name, device,
          [on_complete](const Status& status) {
            InvokeCompleteCallback(on_complete, status);
      });
      break;
    case OperationType::BROADCAST:
      if (horovod_rank() != ops_param->root_rank) {
        hvd_output = std::make_shared<MXTensor>(output_tensor);
      }

      enqueue_result = EnqueueTensorBroadcast(
          hvd_context, hvd_tensor, hvd_output, ops_param->root_rank,
          nullptr, name, device,
          [on_complete](const Status& status) {
            InvokeCompleteCallback(on_complete, status);
      });
      break;
    case OperationType::ALLTOALL:
    {
      auto hvd_splits = std::make_shared<MXTensor>(ops_param->splits_tensor.get());
      enqueue_result = EnqueueTensorAlltoall(
          hvd_context, hvd_tensor, hvd_splits, nullptr, name, device,
          [on_complete](const Status& status) {
            InvokeCompleteCallback(on_complete, status);
      });
      break;
    }
    default:
      throw std::logic_error("Unsupported Horovod operation type.");
  }

  ThrowIfError(enqueue_result);
}