in lib/distributed_runtime/kernels.cc [704:830]
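// Gathers tensors whose shapes may differ along `axis` into one output tensor.
// `shapes_tensor` holds the already-gathered shapes of all kGroupSize
// participants (kRank dimensions each); `shape_tensor_view` is the reference
// shape that every participant's non-axis dimensions must match. This computes
// the output shape and per-participant copy offsets/step sizes, then
// dispatches DoAllGather<T> to move the data.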
template <typename T>
void DoAllGatherAnyShape(const ExecutionContext& exec_ctx,
AsyncValueRef<DistributedContext> dist_ctx,
const InstanceKey& instance_key,
const CollectiveGroup& collective_group,
const int my_index, const DenseHostTensor& in_tensor,
AsyncValueRef<DenseHostTensor> out_tensor,
DenseHostTensor& shapes_tensor,
MutableDHTArrayView<Index> shape_tensor_view,
AsyncValueRef<Chain> out_chain, size_t axis,
size_t kGroupSize, size_t kRank) {
llvm::SmallVector<llvm::SmallVector<size_t, 4>, 4> offsets;
llvm::SmallVector<size_t, 4> step_sizes;
llvm::SmallVector<Index, 4> dimensions;
llvm::SmallVector<Index, 4> tensor_sizes;
size_t gathered_dimension = 0;
DHTArrayView<size_t> shapes_array(&shapes_tensor);
// Check that all participating tensors have the same rank.
// E.g. an AllGather on tensors of shape 2x2 and 2x2x2 does not make sense.
if (shapes_array.NumElements() / kGroupSize != kRank) {
out_chain.SetError("All workers must have tensors of the same rank.");
return;
}
// Go through the participating shapes and compute the following
// (a worked example follows the loop):
// - the size of the output tensor along `axis` (gathered_dimension)
// - the number of elements in each participating tensor (tensor_sizes)
// - the step_size (contiguous elements copied at a time) for each tensor
for (size_t i = 0; i < kGroupSize; ++i) {
size_t num_element = 1;
size_t step_size = 1;
for (size_t j = 0; j < kRank; ++j) {
// This refers to i-th tensor's j-th dimension.
// Example:
// Suppose there are 2 tensors to be gathered.
// T0 has shape 1x2 and T1 has shape 1x3.
// Then, we have
// shapes_array = [1,2,1,3]
// dim is used to access T0's 1,2 and T1's 1,3.
size_t dim = kRank * i + j;
if (j != axis) {
if (shapes_array[dim] != shape_tensor_view.Elements()[j]) {
out_chain.SetError(
"All dimensions in the input must be equal except the axis");
return;
}
} else {
gathered_dimension += shapes_array[dim];
}
num_element *= shapes_array[dim];
if (j >= axis) {
step_size *= shapes_array[dim];
}
}
tensor_sizes.push_back(num_element);
step_sizes.push_back(step_size);
}
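// Running example (T0 = 1x2, T1 = 1x3): gathering is only valid along
// axis = 1, since all other dimensions must match. The loop above yields
// tensor_sizes = {2, 3}, step_sizes = {2, 3}, gathered_dimension = 2 + 3 = 5.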
// Create the vector of output dimensions.
for (size_t l = 0; l < kRank; ++l) {
if (l != axis) {
dimensions.push_back(shape_tensor_view.Elements()[l]);
} else {
dimensions.push_back(gathered_dimension);
}
}
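// Running example: dimensions = {1, 5}, i.e. the output is a 1x5 tensor.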
// Compute the interval: the distance, in output elements, between consecutive
// offsets of the same input tensor.
// See line 658 for more details.
size_t interval = 1;
for (size_t m = axis; m < kRank; ++m) {
interval *= dimensions[m];
}
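// Running example: interval = dimensions[1] = 5, i.e. consecutive offsets for
// the same input tensor are one output row apart.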
// For each participating tensor, compute its offsets: the positions in the
// output tensor that its chunks of step_size elements are copied to.
size_t pos = 0;
for (size_t n = 0; n < kGroupSize; ++n) {
llvm::SmallVector<size_t, 4> offset;
size_t num_offsets = tensor_sizes[n] / step_sizes[n];
size_t each_offset = pos;
for (size_t o = 0; o < num_offsets; ++o) {
offset.push_back(each_offset);
each_offset += interval;
}
offsets.push_back(offset);
pos += step_sizes[n];
}
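// Running example: offsets = {{0}, {2}}, i.e. T0's two elements land at
// output positions 0..1 and T1's three elements at positions 2..4. With two
// rows per tensor (e.g. T0 = 2x2, T1 = 2x3), this would instead give
// offsets = {{0, 5}, {2, 7}}: each tensor contributes one chunk per output
// row, spaced `interval` elements apart.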
// Create an output tensor
TensorShape shape(dimensions);
TensorMetadata md(in_tensor.metadata().dtype, shape);
auto output_tensor =
DenseHostTensor::MakeConstructedAsyncValueRef(md, exec_ctx.host());
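// `done` fires once the gather completes: on success it moves the gathered
// tensor into `out_tensor` and marks the out chain ready; on failure it
// propagates the error to the chain instead.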
auto done = [output_tensor = output_tensor.CopyRef(),
out_chain = std::move(out_chain),
out_tensor = std::move(out_tensor),
dist_ctx = dist_ctx.CopyRef()](Error e) mutable {
if (e) {
out_chain.SetError(e);
} else {
out_chain.emplace();
out_tensor.emplace(std::move(output_tensor.get()));
}
};
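// Wrap `done` in a ref-counted callback so it runs only after every
// outstanding piece of the gather has released its reference. If the final
// release happens off a worker thread (e.g. on a network thread), re-enqueue
// the callback so completion work runs on the work queue.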
auto refcounted_done = TakeRef(
new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
done = std::move(done)](Error e) mutable {
if (host->IsInWorkerThread()) {
done(std::move(e));
} else {
EnqueueWork(exec_ctx,
[done = std::move(done), e = std::move(e)]() mutable {
done(std::move(e));
});
}
}));
// Dispatch the final AllGather (of the actual tensor data) on the work queue.
EnqueueWork(exec_ctx, [exec_ctx, my_index, instance_key, axis,
collective_group, dist_ctx = dist_ctx.CopyRef(),
in_tensor_ref = in_tensor.CopyRef(),
output_tensor = output_tensor.CopyRef(),
refcounted_done = refcounted_done,
offsets = std::move(offsets),
step_sizes = std::move(step_sizes)] {
DoAllGather<T>(exec_ctx, dist_ctx.CopyRef(), instance_key, collective_group,
my_index, in_tensor_ref.CopyRef(), output_tensor.get(),
refcounted_done, axis, offsets, step_sizes);
});
}