Status EnqueueTensorAllreduce()

in horovod/common/operations.cc [840:901]


// Builds an allreduce Request (type ALLREDUCE, or ADASUM when reduce_op is
// ADASUM) together with the matching TensorTableEntry for `tensor`, and adds
// both to horovod_global.tensor_queue.
//
// Parameters:
//   context          - op context; stored in the table entry.
//   tensor           - input tensor; its dtype and shape are copied into the
//                      request, and the pointer is stored in the table entry.
//   output           - output tensor; stored in the table entry.
//   ready_event      - stored in the table entry. NOTE(review): presumably
//                      signals when the input is ready to be read — confirm.
//   name             - tensor name; used for both the request and the entry.
//   device           - device id recorded in the request and the entry;
//                      compared against CPU_DEVICE_ID for ADASUM scaling.
//   callback         - stored in the table entry. NOTE(review): presumably
//                      invoked by the background loop on completion — confirm.
//   reduce_op        - AVERAGE divides postscale_factor by the global size
//                      (rejected on ROCm builds); ADASUM on NCCL non-ROCm
//                      builds with a non-CPU device divides postscale_factor
//                      by the local size.
//   prescale_factor  - forwarded unchanged into the request.
//   postscale_factor - forwarded into the request after the adjustment above.
//
// Returns:
//   SHUT_DOWN_ERROR if horovod_global.shut_down is set, an Aborted status for
//   AVERAGE on ROCm builds, otherwise the status of AddToTensorQueue.
Status EnqueueTensorAllreduce(std::shared_ptr<OpContext> context,
                              std::shared_ptr<Tensor> tensor,
                              std::shared_ptr<Tensor> output,
                              std::shared_ptr<ReadyEvent> ready_event,
                              const std::string name, const int device,
                              StatusCallback callback,
                              ReduceOp reduce_op,
                              double prescale_factor,
                              double postscale_factor) {
  // Fail fast: avoid touching the controller or building a request that can
  // no longer be enqueued.
  if (horovod_global.shut_down) {
    return SHUT_DOWN_ERROR;
  }

  if (reduce_op == ReduceOp::AVERAGE) {
#if !HAVE_ROCM
    // Averaging happens via postscale_factor.
    postscale_factor /= horovod_global.controller->GetSize();
#else
    LOG(ERROR, horovod_global.controller->GetRank())
        << "Enqueuing AVERAGE allreduce is not allowed.";
    return Status::Aborted("AVERAGE not allowed.");
#endif
  } else if (reduce_op == ReduceOp::ADASUM) {
#if HAVE_NCCL && !HAVE_ROCM
    if (device != CPU_DEVICE_ID) {
      // Averaging by local size happens via postscale_factor.
      postscale_factor /= horovod_global.controller->GetLocalSize();
    }
#endif
  }

  Request message;
  message.set_request_rank(horovod_global.controller->GetRank());
  message.set_tensor_name(name);
  message.set_tensor_type(tensor->dtype());
  message.set_device(device);
  message.set_prescale_factor(prescale_factor);
  message.set_postscale_factor(postscale_factor);
  message.set_request_type(reduce_op == ReduceOp::ADASUM ? Request::ADASUM
                                                         : Request::ALLREDUCE);

  // Binds to either a returned reference or a temporary shape object.
  const auto& shape = tensor->shape();
  for (int i = 0; i < shape.dims(); ++i) {
    message.add_tensor_shape(static_cast<int64_t>(shape.dim_size(i)));
  }

  // The by-value sink parameters are not used after this point, so move them
  // into the entry instead of paying for extra shared_ptr refcount updates
  // and a possible std::function copy allocation.
  TensorTableEntry e;
  e.tensor_name = name;
  e.context = std::move(context);
  e.tensor = std::move(tensor);
  e.output = std::move(output);
  e.ready_event = std::move(ready_event);
  e.device = device;
  e.callback = std::move(callback);

  Status status = horovod_global.tensor_queue.AddToTensorQueue(e, message);
  if (status.ok()) {
    LOG(TRACE, horovod_global.controller->GetRank()) << "Enqueued " << name;
  }
  return status;
}