Status EnqueueTensorAlltoall()

in horovod/common/operations.cc [979:1040]


// Validates the arguments of an alltoall operation, builds the coordination
// Request and the TensorTableEntry for it, and adds both to
// horovod_global.tensor_queue.
//
// Parameters:
//   context     - stored on the queued entry.
//   tensor      - input tensor; must have at least one dimension. Its dtype
//                 and full shape are recorded in the Request.
//   splits      - 1D tensor of 32-bit integers. Either it has exactly one
//                 non-negative entry per worker whose sum does not exceed the
//                 first dimension of `tensor`, or it is empty, in which case
//                 the first dimension of `tensor` is divided evenly among the
//                 workers (and must be a multiple of the worker count).
//   ready_event - stored on the queued entry.
//   name        - tensor name used in the Request, the entry and the trace log.
//   device      - device id recorded in the Request and the entry.
//   callback    - stored on the queued entry.
//
// Returns:
//   Status::InvalidArgument if any of the checks above fails,
//   SHUT_DOWN_ERROR if horovod_global.shut_down is set, otherwise the status
//   returned by AddToTensorQueue.
Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
                             std::shared_ptr<Tensor> tensor,
                             std::shared_ptr<Tensor> splits,
                             std::shared_ptr<ReadyEvent> ready_event,
                             const std::string name, const int device,
                             StatusCallback callback) {
  // Check arguments. Both shapes are indexed at dimension 0 further down, so
  // rank-0 inputs have to be rejected here rather than only rank > 1 splits.
  if (splits->shape().dims() != 1) {
    return Status::InvalidArgument("alltoall expects a 1D splits tensor");
  }
  if (splits->dtype() != HOROVOD_INT32) {
    return Status::InvalidArgument("alltoall expects splits to contain 32-bit integer elements.");
  }
  if (tensor->shape().dims() < 1) {
    return Status::InvalidArgument("alltoall expects a tensor with at least one dimension.");
  }

  Request message;
  message.set_request_rank(horovod_global.controller->GetRank());
  message.set_tensor_name(name);
  message.set_tensor_type(tensor->dtype());
  message.set_device(device);
  message.set_request_type(Request::ALLTOALL);
  for (int i = 0; i < tensor->shape().dims(); ++i) {
    message.add_tensor_shape(static_cast<int64_t>(tensor->shape().dim_size(i)));
  }

  TensorTableEntry e;
  e.tensor_name = name;
  e.context = context;
  e.tensor = tensor;
  e.ready_event = ready_event;
  e.device = device;
  e.callback = callback;

  const int64_t splits_first_dim = splits->shape().dim_size(0);
  const int64_t tensor_first_dim = tensor->shape().dim_size(0);
  const int world_size = horovod_global.controller->GetSize();
  if (splits_first_dim == world_size) {
    // Explicit splits: one entry per worker.
    const auto* splits_data = static_cast<const int32_t*>(splits->data());

    // Accumulate in 64 bits: summing int32 entries in an int can overflow and
    // wrap to a value that incorrectly passes the bound check below.
    int64_t sum = 0;
    for (int64_t i = 0; i < splits_first_dim; ++i) {
      if (splits_data[i] < 0) {
        // A negative entry is not a valid count and could also bring an
        // oversized set of splits under the sum check.
        return Status::InvalidArgument("alltoall expects splits entries to be non-negative.");
      }
      sum += splits_data[i];
    }
    if (sum > tensor_first_dim) {
      return Status::InvalidArgument("Sum of splits entries is greater than the first dimension of tensor.");
    }
    // splits is 1D, so its element count equals its first dimension.
    e.splits.assign(splits_data, splits_data + splits_first_dim);
  } else if (splits_first_dim == 0) {
    // No splits provided: divide the first dimension evenly among workers.
    if (tensor_first_dim % world_size != 0) {
      return Status::InvalidArgument("splits not provided, but first dimension of tensor is not an even "
                                     "multiple of the number of workers.");
    }
    e.splits.resize(world_size, tensor_first_dim / world_size);
  } else {
    return Status::InvalidArgument("Number of entries in splits does not equal number of workers.");
  }

  if (horovod_global.shut_down) {
    return SHUT_DOWN_ERROR;
  }
  Status status = horovod_global.tensor_queue.AddToTensorQueue(e, message);
  if (status.ok()) {
    LOG(TRACE, horovod_global.controller->GetRank()) << "Enqueued " << name;
  }
  return status;
}