in lib/distributed_runtime/kernels.cc [161:311]
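// Performs a ring all-reduce over the members of `collective_group_name`.
// The tensor is divided into one split per group member. During the scatter
// phase each task reduces an incoming split into its local data and forwards
// the partial result around the ring; during the gather phase the fully
// reduced splits circulate so that every task fills `out_tensor`.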
template <typename T>
void DoAllReduce(const ExecutionContext& exec_ctx,
AsyncValueRef<DistributedContext> dist_ctx,
const InstanceKey& instance_key,
const std::string& collective_group_name,
const DenseHostTensor& in_tensor,
const DenseHostTensor& out_tensor,
ElementWiseReductionFunction reduction_fn,
ElementWiseFinalFunction final_fn,
AsyncValueRef<Chain> out_chain) {
const auto& collective_group =
dist_ctx->GetCollectiveGroup(collective_group_name);
const int my_index =
FindMyIndex(collective_group.members, dist_ctx->GetTaskHandle());
if (my_index == -1) {
out_chain.SetError(StrCat("The current task ", dist_ctx->GetTaskName(),
" is not part of the collective group ",
collective_group_name));
return;
}
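  // A ring all-reduce over N tasks takes 2 * (N - 1) communication rounds:
  // N - 1 scatter-reduce rounds followed by N - 1 gather rounds. Step 0
  // below only kicks off the ring, so the loop runs kTotalSteps = 2 * N - 1
  // iterations, from step 0 through kLastGatherStep.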
const size_t kGroupSize = collective_group.members.size();
const size_t kLastScatterStep = kGroupSize - 1;
const size_t kLastGatherStep = 2 * kGroupSize - 2;
const auto kPrefix = collective_group_name;
const int kTotalSteps = 2 * kGroupSize - 1;
  const int neighbor_index = (my_index + 1) % kGroupSize;
const TaskHandle neighbor_task = collective_group.members[neighbor_index];
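  // The ring order follows collective_group.members: each task only ever
  // sends to its immediate successor and receives from its predecessor.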
  auto in_tensor_ref =
      llvm::StringRef(reinterpret_cast<const char*>(in_tensor.data()),
                      in_tensor.DataSizeInBytes());
  auto out_tensor_ref =
      llvm::StringRef(reinterpret_cast<const char*>(out_tensor.data()),
                      out_tensor.DataSizeInBytes());
auto* callback_registry = dist_ctx->GetCallbackRegistry();
RemoteClientInterface* neighbor_client =
dist_ctx->GetRemoteClient(neighbor_task);
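  // Chunks arriving from the predecessor are delivered through the callback
  // registry, keyed by a per-step instance key; chunks for the successor go
  // out through its remote client.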
auto done = [out_chain = out_chain.CopyRef(),
dist_ctx = dist_ctx.CopyRef()](Error e) mutable {
if (e) {
      out_chain.SetError(std::move(e));
} else {
out_chain.emplace();
}
};
  // Ref-counted callback to keep track of pending steps in the all-reduce.
  // Add one ref before starting each step, and drop one ref when the step
  // finishes (for steps with async RPCs, drop the ref when the RPC finishes).
auto refcounted_done = TakeRef(
new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
done = std::move(done)](Error e) mutable {
        // NOTE: This may run on either a HostContext work queue thread or a
        // FabricCommunicator callback thread. The Chain AsyncValue must be
        // emplaced (or have its error set) on the work queue threadpool, so
        // that:
        //  * subsequent operations (i.e., AndThen) on this AsyncValue run on
        //    work queue threads;
        //  * the AsyncValue drops its last ref and is deallocated on work
        //    queue threads.
        // Otherwise, the HostContext might be destroyed before the AsyncValue
        // is deallocated or finishes its AndThen work, leading to a segfault.
if (host->IsInWorkerThread()) {
done(std::move(e));
} else {
EnqueueWork(exec_ctx,
[done = std::move(done), e = std::move(e)]() mutable {
done(std::move(e));
});
}
}));
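  // Each iteration either sends immediately (step 0) or registers a callback
  // that fires when the chunk for `step_key` arrives. Outgoing payloads are
  // tagged with the key of the *next* step, which is what triggers the
  // receiver's callback for that step.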
for (int step = 0; step < kTotalSteps; ++step) {
const InstanceKey step_key = StepKey(kPrefix, instance_key, step);
const InstanceKey next_step_key = StepKey(kPrefix, instance_key, step + 1);
    const size_t split_id = SplitIndex(my_index, kGroupSize, step);
    llvm::StringRef in_split = GetSplit<T>(in_tensor_ref, kGroupSize,
                                           in_tensor.NumElements(), split_id);
    llvm::StringRef out_split = GetSplit<T>(out_tensor_ref, kGroupSize,
                                            out_tensor.NumElements(), split_id);
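    // With kGroupSize = 4, for example, each step operates on one quarter of
    // the tensor; which quarter is chosen rotates with `step` via SplitIndex
    // (defined elsewhere in this file), lining up each task's send with its
    // neighbor's receive.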
auto request = std::make_unique<SendDataRequest>();
auto response = std::make_unique<SendDataResponse>();
request->set_context_id(dist_ctx->GetContextId());
request->set_instance_key(next_step_key);
if (step == 0) {
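      // Step 0 kicks off the ring: send this task's local split to the
      // neighbor. No callback is registered for step 0 itself, since no
      // payload is ever tagged with step 0's key.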
      request->add_payload(in_split.data(), in_split.size());
neighbor_client->SendDataAsync(
RemoteCallContext::GetDefault(), request.get(), response.get(),
[request = std::move(request), response = std::move(response),
refcounted_done = refcounted_done](Error e) {
refcounted_done->UpdateState(std::move(e));
});
} else if (step <= kLastScatterStep) {
      // Scatter stage: aggregate the incoming chunk with the local input
      // split, then forward the partial result to the neighbor.
callback_registry->SetCallback(
step_key,
          [step, in_split, out_split,
request = std::move(request), response = std::move(response),
neighbor_client, reduction_fn, final_fn, kLastScatterStep,
kGroupSize, refcounted_done = refcounted_done](
const InstanceKey&,
CallbackRegistry::CallbackValue callback_value) mutable {
RCReference<HostBuffer> data = callback_value.buffers[0];
// Scatter aggregates the results with the local buffer.
reduction_fn(static_cast<char*>(data->data()),
const_cast<char*>(in_split.data()), in_split.size());
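            // On the last scatter step this chunk has accumulated every
            // task's contribution: finalize it and store it in the output
            // split before forwarding it into the gather phase.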
if (step == kLastScatterStep) {
final_fn(static_cast<char*>(data->data()), in_split.size(),
kGroupSize);
std::copy(static_cast<char*>(data->data()),
static_cast<char*>(data->data()) + data->size(),
const_cast<char*>(out_split.begin()));
}
request->add_payload(data->data(), data->size());
neighbor_client->SendDataAsync(
RemoteCallContext::GetDefault(), request.get(), response.get(),
[request = std::move(request), response = std::move(response),
callback_value = std::move(callback_value),
refcounted_done = refcounted_done](Error e) mutable {
refcounted_done->UpdateState(std::move(e));
});
});
} else {
      // Gather stage: the incoming chunk is already fully reduced; copy it
      // into the output split and forward it to the neighbor unchanged.
callback_registry->SetCallback(
step_key,
          [step, out_split, kLastGatherStep,
request = std::move(request), response = std::move(response),
neighbor_client, refcounted_done = refcounted_done](
const InstanceKey&,
CallbackRegistry::CallbackValue callback_value) mutable {
RCReference<HostBuffer> data = callback_value.buffers[0];
            // Copy the fully reduced incoming chunk into the output split.
std::copy(static_cast<char*>(data->data()),
static_cast<char*>(data->data()) + data->size(),
const_cast<char*>(out_split.begin()));
if (step < kLastGatherStep) {
request->add_payload(data->data(), data->size());
neighbor_client->SendDataAsync(
RemoteCallContext::GetDefault(), request.get(),
response.get(),
[request = std::move(request), response = std::move(response),
callback_value = std::move(callback_value),
refcounted_done = refcounted_done](Error e) mutable {
refcounted_done->UpdateState(std::move(e));
});
}
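            // On the last gather step there is nothing left to forward; the
            // refcounted_done reference captured above is released when this
            // callback is destroyed, marking the final step complete.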
});
}
}
}