in horovod/common/operations.cc [979:1040]
// Validates an alltoall request and, if valid, adds it to the global tensor
// queue.
//
// Parameters:
//   context, tensor, ready_event, device, callback - stored unchanged on the
//       TensorTableEntry that is enqueued.
//   splits - must be a 1D tensor of 32-bit integers. It must contain either
//       one non-negative entry per worker (presumably the number of
//       first-dimension slices destined for each worker - confirm against the
//       alltoall op implementation), whose sum does not exceed the first
//       dimension of `tensor`, or no entries at all. In the latter case the
//       first dimension of `tensor` is divided evenly between all workers.
//   name - used both as the request's tensor name and the entry's name.
//
// Returns:
//   Status::InvalidArgument if any of the checks above fails,
//   SHUT_DOWN_ERROR if horovod_global.shut_down is set, and otherwise the
//   status returned by the tensor queue.
Status EnqueueTensorAlltoall(std::shared_ptr<OpContext> context,
                             std::shared_ptr<Tensor> tensor,
                             std::shared_ptr<Tensor> splits,
                             std::shared_ptr<ReadyEvent> ready_event,
                             const std::string name, const int device,
                             StatusCallback callback) {
  // Check arguments
  // Exactly one dimension is required: dim_size(0) is queried below, which
  // would index past the end of the shape for a zero-dimensional tensor.
  if (splits->shape().dims() != 1) {
    return Status::InvalidArgument("alltoall expects a 1D splits tensor");
  }
  if (splits->dtype() != HOROVOD_INT32) {
    return Status::InvalidArgument("alltoall expects splits to contain 32-bit integer elements.");
  }
  // The first dimension of the tensor is what gets partitioned, so it has to
  // exist before dim_size(0) may be queried.
  if (tensor->shape().dims() < 1) {
    return Status::InvalidArgument("alltoall expects a tensor with at least one dimension.");
  }

  Request message;
  message.set_request_rank(horovod_global.controller->GetRank());
  message.set_tensor_name(name);
  message.set_tensor_type(tensor->dtype());
  message.set_device(device);
  message.set_request_type(Request::ALLTOALL);
  for (int i = 0; i < tensor->shape().dims(); ++i) {
    message.add_tensor_shape(static_cast<int64_t>(tensor->shape().dim_size(i)));
  }

  TensorTableEntry e;
  e.tensor_name = name;
  e.context = context;
  e.tensor = tensor;
  e.ready_event = ready_event;
  e.device = device;
  e.callback = callback;

  const int64_t splits_first_dim = splits->shape().dim_size(0);
  const int64_t tensor_first_dim = tensor->shape().dim_size(0);
  const int world_size = horovod_global.controller->GetSize();

  if (splits_first_dim == world_size) {
    // Explicit splits: one entry per worker.
    const auto* splits_data = static_cast<const int32_t*>(splits->data());

    // Accumulate in 64 bits: the sum of several int32 values can exceed the
    // range of int, and a wrapped sum would slip past the bound check below.
    int64_t splits_sum = 0;
    for (int64_t i = 0; i < splits_first_dim; ++i) {
      // A negative entry is never a valid count and could offset an oversized
      // entry so that the sum check alone would not catch it.
      if (splits_data[i] < 0) {
        return Status::InvalidArgument("alltoall expects splits entries to be non-negative.");
      }
      splits_sum += splits_data[i];
    }
    if (splits_sum > tensor_first_dim) {
      return Status::InvalidArgument("Sum of splits entries is greater than the first dimension of tensor.");
    }
    // splits is 1D, so its first dimension is its total element count.
    e.splits.assign(splits_data, splits_data + splits_first_dim);
  } else if (splits_first_dim == 0) {
    // No splits provided: divide the first dimension evenly between workers.
    if (tensor_first_dim % world_size != 0) {
      return Status::InvalidArgument("splits not provided, but first dimension of tensor is not an even "
                                     "multiple of the number of workers.");
    }
    e.splits.resize(world_size, tensor_first_dim / world_size);
  } else {
    return Status::InvalidArgument("Number of entries in splits does not equal number of workers.");
  }

  if (horovod_global.shut_down) {
    return SHUT_DOWN_ERROR;
  }
  Status status = horovod_global.tensor_queue.AddToTensorQueue(e, message);
  if (status.ok()) {
    LOG(TRACE, horovod_global.controller->GetRank()) << "Enqueued " << name;
  }
  return status;
}