void DoAllReduce()

in lib/distributed_runtime/kernels.cc [161:311]


// Performs an all-reduce across the members of `collective_group_name`, with
// each task exchanging data only with its successor in the group's member list
// (index `my_index + 1`, wrapping around).
//
// The input buffer is viewed as `kGroupSize` splits and processed in
// `2 * kGroupSize - 1` steps:
//   * step 0: send one local split to the neighbor;
//   * steps 1..kLastScatterStep (scatter): when a chunk arrives, call
//     `reduction_fn` on (incoming buffer, local split, split byte size) and
//     forward the incoming buffer to the neighbor. On the last scatter step
//     `final_fn` is applied to the incoming buffer and the result is copied
//     into the local split before forwarding;
//   * remaining steps (gather): copy the incoming chunk into the local split
//     and forward it, except on the last gather step where nothing is sent.
//
// Completion is reported through `out_chain`: it is set to an error if the
// current task is not in the group or if any step reports an error, and
// emplaced otherwise once all pending steps have released `refcounted_done`.
//
// NOTE(review): `T` (used for GetSplit<T>) is not declared in the visible
// span; presumably a `template <typename T>` header precedes this definition
// -- confirm.
// NOTE(review): `out_tensor` is not referenced in this body; both the "in" and
// "out" splits are views into `in_tensor`'s buffer. Confirm callers expect the
// reduction to happen in place.
void DoAllReduce(const ExecutionContext& exec_ctx,
                 AsyncValueRef<DistributedContext> dist_ctx,
                 const InstanceKey& instance_key,
                 const std::string& collective_group_name,
                 const DenseHostTensor& in_tensor,
                 const DenseHostTensor& out_tensor,
                 ElementWiseReductionFunction reduction_fn,
                 ElementWiseFinalFunction final_fn,
                 AsyncValueRef<Chain> out_chain) {
  const auto& collective_group =
      dist_ctx->GetCollectiveGroup(collective_group_name);
  const int my_index =
      FindMyIndex(collective_group.members, dist_ctx->GetTaskHandle());
  if (my_index == -1) {
    out_chain.SetError(StrCat("The current task ", dist_ctx->GetTaskName(),
                              " is not part of the collective group ",
                              collective_group_name));
    return;
  }
  const size_t kGroupSize = collective_group.members.size();
  const size_t kLastScatterStep = kGroupSize - 1;
  const size_t kLastGatherStep = 2 * kGroupSize - 2;
  // Bind by reference: the name outlives every synchronous use below, so there
  // is no need to copy the string.
  const std::string& kPrefix = collective_group_name;
  const int kTotalSteps = 2 * kGroupSize - 1;

  // Every step sends to the next member in the group, wrapping around.
  const int neighbor_index = (my_index + 1) % kGroupSize;
  const TaskHandle neighbor_task = collective_group.members[neighbor_index];

  auto in_tensor_ref =
      llvm::StringRef(reinterpret_cast<const char*>(in_tensor.data()),
                      in_tensor.DataSizeInBytes());
  auto* callback_registry = dist_ctx->GetCallbackRegistry();
  RemoteClientInterface* neighbor_client =
      dist_ctx->GetRemoteClient(neighbor_task);

  // Final completion handler. Holds a reference to `dist_ctx` so that it stays
  // alive until the chain has been resolved.
  auto done = [out_chain = out_chain.CopyRef(),
               dist_ctx = dist_ctx.CopyRef()](Error e) mutable {
    if (e) {
      out_chain.SetError(e);
    } else {
      out_chain.emplace();
    }
  };

  // Ref counted callback to keep track of pending steps in all reduce.
  // Add one ref before starting each step, and drop one ref when the step
  // finishes (for steps with async RPCs, drop the reference when RPC finishes).
  auto refcounted_done = TakeRef(
      new RefCountedCallback([host = dist_ctx->GetHostContext(), exec_ctx,
                              done = std::move(done)](Error e) mutable {
        // NOTE: we might be executing this in either HostContext work queue
        // threads or the FabricCommunicator callback threads. Must make sure
        // AsyncValue Chain gets emplaced (or set error) in the work queue
        // threadpool, so that:
        //   * subsequent operations (i.e., AndThen) for this AsyncValue are
        //     executed in the work queue threads;
        //   * the AsyncValue drops its last ref and gets deallocated in the
        //     work queue threads
        // Otherwise, the HostContext might get destroyed before the AsyncValue
        // is deallocated or finishes its AndThen work, leading to segfault.
        if (host->IsInWorkerThread()) {
          done(std::move(e));
        } else {
          EnqueueWork(exec_ctx,
                      [done = std::move(done), e = std::move(e)]() mutable {
                        done(std::move(e));
                      });
        }
      }));

  for (int step = 0; step < kTotalSteps; ++step) {
    // Incoming data for this step is keyed by `step_key`; whatever this step
    // sends is addressed to the neighbor's next step.
    const InstanceKey step_key = StepKey(kPrefix, instance_key, step);
    const InstanceKey next_step_key = StepKey(kPrefix, instance_key, step + 1);
    const size_t split_id = SplitIndex(my_index, kGroupSize, step);
    llvm::StringRef split_data = GetSplit<T>(in_tensor_ref, kGroupSize,
                                             in_tensor.NumElements(), split_id);
    auto request = std::make_unique<SendDataRequest>();
    auto response = std::make_unique<SendDataResponse>();
    request->set_context_id(dist_ctx->GetContextId());
    request->set_instance_key(next_step_key);

    if (step == 0) {
      // Kick off the exchange by sending the first local split as is.
      request->add_payload(split_data.data(), split_data.size());
      // Read the raw pointers before the call: the completion lambda's
      // init-captures move `request`/`response`, and the evaluation order of
      // function arguments is unspecified, so calling `.get()` inside the
      // argument list could observe an already moved-from (null) owner.
      auto* request_ptr = request.get();
      auto* response_ptr = response.get();
      neighbor_client->SendDataAsync(
          RemoteCallContext::GetDefault(), request_ptr, response_ptr,
          [request = std::move(request), response = std::move(response),
           refcounted_done = refcounted_done](Error e) {
            refcounted_done->UpdateState(std::move(e));
          });
    } else if (step <= kLastScatterStep) {
      // Scatter stage: send a chunk to the neighbor, aggregate the incoming
      // chunk with local buffer.
      callback_registry->SetCallback(
          step_key,
          [step, in_split = split_data, out_split = split_data,
           request = std::move(request), response = std::move(response),
           neighbor_client, reduction_fn, final_fn, kLastScatterStep,
           kGroupSize, refcounted_done = refcounted_done](
              const InstanceKey&,
              CallbackRegistry::CallbackValue callback_value) mutable {
            RCReference<HostBuffer> data = callback_value.buffers[0];
            // Scatter aggregates the results with the local buffer.
            reduction_fn(static_cast<char*>(data->data()),
                         const_cast<char*>(in_split.data()), in_split.size());

            if (step == kLastScatterStep) {
              // This chunk has now been combined across the group: apply the
              // final function and store the result into the local split.
              final_fn(static_cast<char*>(data->data()), in_split.size(),
                       kGroupSize);
              std::copy(static_cast<char*>(data->data()),
                        static_cast<char*>(data->data()) + data->size(),
                        const_cast<char*>(out_split.begin()));
            }
            request->add_payload(data->data(), data->size());
            // See step 0: take the raw pointers before the owners are moved
            // into the completion lambda.
            auto* request_ptr = request.get();
            auto* response_ptr = response.get();
            neighbor_client->SendDataAsync(
                RemoteCallContext::GetDefault(), request_ptr, response_ptr,
                // `callback_value` is moved into the completion lambda so the
                // received buffers stay alive until the RPC finishes.
                [request = std::move(request), response = std::move(response),
                 callback_value = std::move(callback_value),
                 refcounted_done = refcounted_done](Error e) mutable {
                  refcounted_done->UpdateState(std::move(e));
                });
          });
    } else {
      // Gather stage: an incoming chunk is final; just assign it to local
      // buffer and pass it to the neighbor as is.
      callback_registry->SetCallback(
          step_key,
          [step, out_split = split_data, kLastGatherStep,
           request = std::move(request), response = std::move(response),
           neighbor_client, refcounted_done = refcounted_done](
              const InstanceKey&,
              CallbackRegistry::CallbackValue callback_value) mutable {
            RCReference<HostBuffer> data = callback_value.buffers[0];
            // Gather assigns the incoming data to the local buffer
            std::copy(static_cast<char*>(data->data()),
                      static_cast<char*>(data->data()) + data->size(),
                      const_cast<char*>(out_split.begin()));
            // The last gather step only stores the chunk; nothing is
            // forwarded, and the reference on `refcounted_done` held by this
            // callback is released when the callback itself is destroyed.
            if (step < kLastGatherStep) {
              request->add_payload(data->data(), data->size());
              // See step 0: take the raw pointers before the owners are moved
              // into the completion lambda.
              auto* request_ptr = request.get();
              auto* response_ptr = response.get();
              neighbor_client->SendDataAsync(
                  RemoteCallContext::GetDefault(), request_ptr, response_ptr,
                  [request = std::move(request), response = std::move(response),
                   callback_value = std::move(callback_value),
                   refcounted_done = refcounted_done](Error e) mutable {
                    refcounted_done->UpdateState(std::move(e));
                  });
            }
          });
    }
  }
}