in horovod/mxnet/mpi_ops.cc [75:137]
void DoHorovodOperation(void*, void* on_complete_ptr, void* param) {
ThrowIfError(common::CheckInitialized());
auto on_complete = *static_cast<CallbackOnComplete*>(on_complete_ptr);
auto ops_param = static_cast<MpiOpsParam*>(param);
auto input_tensor = ops_param->input_tensor.get();
auto output_tensor = ops_param->output_tensor.get();
auto output = ops_param->output;
auto name = ops_param->op_name;
auto average = ops_param->average;
auto prescale_factor = ops_param->prescale_factor;
auto postscale_factor = ops_param->postscale_factor;
auto device = TensorUtil::GetDevice(input_tensor);
auto hvd_tensor = std::make_shared<MXTensor>(input_tensor);
auto hvd_context = std::make_shared<MXOpContext>(device, output);
std::shared_ptr<Tensor> hvd_output = nullptr;
Status enqueue_result;
switch (ops_param->op_type) {
case OperationType::ALLREDUCE:
hvd_output = std::make_shared<MXTensor>(output_tensor);
enqueue_result = EnqueueTensorAllreduce(
hvd_context, hvd_tensor, hvd_output, nullptr, name, device,
[on_complete](const Status& status) {
InvokeCompleteCallback(on_complete, status);
}, (average) ? ReduceOp::AVERAGE : ReduceOp::SUM, prescale_factor, postscale_factor);
break;
case OperationType::ALLGATHER:
enqueue_result = EnqueueTensorAllgather(
hvd_context, hvd_tensor, nullptr, name, device,
[on_complete](const Status& status) {
InvokeCompleteCallback(on_complete, status);
});
break;
case OperationType::BROADCAST:
if (horovod_rank() != ops_param->root_rank) {
hvd_output = std::make_shared<MXTensor>(output_tensor);
}
enqueue_result = EnqueueTensorBroadcast(
hvd_context, hvd_tensor, hvd_output, ops_param->root_rank,
nullptr, name, device,
[on_complete](const Status& status) {
InvokeCompleteCallback(on_complete, status);
});
break;
case OperationType::ALLTOALL:
{
auto hvd_splits = std::make_shared<MXTensor>(ops_param->splits_tensor.get());
enqueue_result = EnqueueTensorAlltoall(
hvd_context, hvd_tensor, hvd_splits, nullptr, name, device,
[on_complete](const Status& status) {
InvokeCompleteCallback(on_complete, status);
});
break;
}
default:
throw std::logic_error("Unsupported Horovod operation type.");
}
ThrowIfError(enqueue_result);
}