in lib/distributed_runtime/kernels.cc [557:701]
// Gathers identically-shaped tensors from every member of a collective group
// into `out_tensor`, concatenating contributions along `axis`.
//
// The kernel precomputes, for each group member, the element offsets and chunk
// sizes ("step sizes") describing where that member's tensor lands inside the
// output, then hands the actual communication off to DoAllGather<T> on a work
// queue. `out_chain` is fulfilled (or set to an error) when the gather
// completes.
void AllGatherFixedShape(Argument<DistributedContext> dist_ctx,
                         Argument<std::string> collective_group_name,
                         Argument<InstanceKey> instance_key,
                         Argument<DenseHostTensor> in_tensor,
                         Argument<DenseHostTensor> out_tensor,
                         Argument<size_t> axis, Argument<Chain> in_chain,
                         Result<Chain> out_chain,
                         const ExecutionContext& exec_ctx) {
  auto out_chain_indirect = out_chain.Allocate();
  const auto& collective_group =
      dist_ctx->GetCollectiveGroup(collective_group_name.get());
  int my_index =
      FindMyIndex(collective_group.members, dist_ctx->GetTaskHandle());
  if (my_index == -1) {
    out_chain_indirect.SetError(
        "This worker is not part of the collective group ");
    return;
  }
  const auto kGroupSize = collective_group.members.size();

  // Compute offsets and step_sizes for all workers. This (and all input
  // validation) must happen BEFORE `out_chain_indirect` is moved into the
  // completion callback below; an earlier revision reported validation errors
  // through the moved-from chain, which is a use-after-move.
  //
  // offsets: offsets in output tensor for each tensor
  // step_sizes: # of elements to be copied for each offset
  // Example:
  //   t0 = [[0,1,2],[3,4,5]] dimension = 2x3
  //   t1 = [[60,70,80],[90,100,110]] dimension = 2x3
  llvm::SmallVector<llvm::SmallVector<size_t, 4>, 4> offsets;
  llvm::SmallVector<size_t, 4> step_sizes;
  const size_t num_elements = in_tensor->NumElements();
  if (*axis == 0) {
    // If axis = 0, each whole tensor (# elements = 6) is concatenated after
    // the previous one:
    //   offsets = [[0],[6]]
    //   step_sizes = [6,6]
    //   result = [[0,1,2],[3,4,5],[60,70,80],[90,100,110]] dimension = 4x3
    size_t pos = 0;
    for (size_t i = 0; i < kGroupSize; ++i) {
      step_sizes.push_back(num_elements);
      offsets.push_back({pos});
      pos += num_elements;
    }
  } else {
    // Otherwise, more work is needed to figure out the offsets and step_sizes
    // as we no longer simply concat whole tensors.
    //   offsets = [[0,6],[3,9]]
    //   step_sizes = [3,3]
    // For worker_0's tensor, it is split into 2 chunks of 3 elements each;
    // each chunk is copied into out_tensor at positions 0 and 6.
    // For worker_1's tensor, it is split into 2 chunks of 3 elements each;
    // each chunk is copied into out_tensor at positions 3 and 9.
    //   result = [[0,1,2,60,70,80],[3,4,5,90,100,110]] dimension = 2x6
    llvm::SmallVector<Index, 4> in_dimension;
    in_tensor->shape().GetDimensions(&in_dimension);
    llvm::SmallVector<Index, 4> out_dimension;
    out_tensor->shape().GetDimensions(&out_dimension);
    // Guard the per-dimension comparison below against an out-of-bounds read
    // when the output rank does not match the input rank.
    if (out_dimension.size() != in_dimension.size()) {
      out_chain_indirect.SetError(
          "Incorrect output rank. The output must have the same number of "
          "dimensions as the input.");
      return;
    }
    // Every non-axis dimension of the output must equal the input's. Note the
    // previous code only checked dimensions AFTER the axis; dimensions before
    // the axis must match as well.
    for (size_t j = 0; j < in_dimension.size(); ++j) {
      if (j != *axis && in_dimension[j] != out_dimension[j]) {
        out_chain_indirect.SetError(
            "Incorrect output dimension. All dimensions in the output must be "
            "equal except the axis.");
        return;
      }
    }
    // step_size: # of elements of the input tensor copied for each offset.
    //
    //   step_size = in_dim[axis] * in_dim[axis+1] * ... * in_dim[n-1]
    //
    // i.e. the product of input dimensions from the axis to the end. To gather
    // on axis k means to gather all elements starting from that axis:
    //   t0, t1 both 2x3.
    //   Gathering on axis 0: step_size = 2*3 = 6 -> whole tensor copied.
    //   Gathering on axis 1: step_size = 3 -> 3 elements per offset.
    //
    // interval: distance between consecutive offsets of one worker's chunks,
    // determined by the OUTPUT dimensions from the axis to the end:
    //
    //   interval = out_dim[axis] * out_dim[axis+1] * ... * out_dim[n-1]
    size_t step_size = 1;
    size_t interval = 1;
    for (size_t j = *axis; j < in_dimension.size(); ++j) {
      step_size *= in_dimension[j];
      interval *= out_dimension[j];
    }
    size_t pos = 0;
    for (size_t i = 0; i < kGroupSize; ++i) {
      step_sizes.push_back(step_size);
      const size_t num_offsets = num_elements / step_size;
      llvm::SmallVector<size_t, 4> offset;
      size_t each_offset = pos;
      for (size_t l = 0; l < num_offsets; ++l) {
        offset.push_back(each_offset);
        each_offset += interval;
      }
      offsets.push_back(offset);
      pos += step_size;
    }
  }

  // Completion callback: fulfills (or errors) the output chain. Captures keep
  // the tensors and the DistributedContext alive until the gather finishes.
  // Constructed only after all early-return validation above, since it
  // consumes `out_chain_indirect`.
  auto done = [out_tensor = out_tensor.ValueRef(), instance_key = *instance_key,
               out_chain = std::move(out_chain_indirect),
               dist_ctx = dist_ctx.ValueRef()](Error e) mutable {
    if (e) {
      out_chain.SetError(std::move(e));
    } else {
      out_chain.emplace();
    }
  };
  // Wrap `done` so it runs on a work-queue thread: if the final release of the
  // ref-counted callback happens off the worker threadpool, bounce the
  // invocation through EnqueueWork instead of running it inline.
  auto refcounted_done = TakeRef(
      new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
                              done = std::move(done)](Error e) mutable {
        if (host->IsInWorkerThread()) {
          done(std::move(e));
        } else {
          EnqueueWork(exec_ctx,
                      [done = std::move(done), e = std::move(e)]() mutable {
                        done(std::move(e));
                      });
        }
      }));

  EnqueueWork(exec_ctx,
              [exec_ctx, my_index, instance_key = *instance_key, axis = *axis,
               collective_group, dist_ctx = dist_ctx.ValueRef(),
               in_tensor_ref = in_tensor.ValueRef(),
               out_tensor_ref = out_tensor.ValueRef(),
               refcounted_done = std::move(refcounted_done),
               offsets = std::move(offsets),
               step_sizes = std::move(step_sizes)] {
                DoAllGather<T>(exec_ctx, dist_ctx.CopyRef(), instance_key,
                               collective_group, my_index, in_tensor_ref.get(),
                               out_tensor_ref.get(), refcounted_done, axis,
                               offsets, step_sizes);
              });
}