void allreduce()

in gloo/allreduce.cc [97:146]


// Entry point for allreduce: validates the options, builds the local
// reduce/broadcast helpers, and hands off to the implementation selected
// by opts.algorithm.
//
// Returns immediately, before any validation, when opts.elements is zero.
// Otherwise enforces (through the GLOO_ENFORCE* macros) that:
//   - at least one output buffer was supplied,
//   - opts.elementSize is positive,
//   - opts.reduce is non-null,
//   - every input and output buffer has a size of exactly
//     opts.elements * opts.elementSize bytes.
void allreduce(const detail::AllreduceOptionsImpl& opts) {
  // Empty payload: nothing to reduce.
  if (opts.elements == 0) {
    return;
  }

  using BufferList = std::vector<std::unique_ptr<transport::UnboundBuffer>>;
  const BufferList& in = opts.in;
  const BufferList& out = opts.out;

  // NOTE(review): `slot` is never read again inside this function. It is
  // kept because the side effects of Slot::build are not visible from here.
  const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);

  // Basic argument validation.
  GLOO_ENFORCE_GT(out.size(), 0);
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(opts.reduce != nullptr);

  // Every buffer, output and input alike, must match the payload size.
  // Outputs are checked before inputs.
  const size_t totalBytes = opts.elements * opts.elementSize;
  for (size_t i = 0; i != out.size(); ++i) {
    GLOO_ENFORCE_EQ(out[i]->size, totalBytes);
  }
  for (size_t i = 0; i != in.size(); ++i) {
    GLOO_ENFORCE_EQ(in[i]->size, totalBytes);
  }

  // Helpers for the process-local reduction of the inputs and the
  // process-local broadcast to the outputs. Both are no-ops when a single
  // output is specified and doubles as the input.
  const auto localReduce =
      genLocalReduceFunction(in, out, opts.elementSize, opts.reduce);
  const auto localBroadcast = genLocalBroadcastFunction(out);

  // With a single process there is no communication step: reduce and
  // broadcast locally over the whole byte range and finish.
  if (opts.context->size == 1) {
    localReduce(0, totalBytes);
    localBroadcast(0, totalBytes);
    return;
  }

  // Dispatch on the requested algorithm; UNSPECIFIED is treated as RING.
  const auto algorithm = opts.algorithm;
  if (algorithm == detail::AllreduceOptionsImpl::UNSPECIFIED ||
      algorithm == detail::AllreduceOptionsImpl::RING) {
    ring(opts, localReduce, localBroadcast);
  } else if (algorithm == detail::AllreduceOptionsImpl::BCUBE) {
    bcube(opts, localReduce, localBroadcast);
  } else {
    GLOO_ENFORCE(false, "Algorithm not handled.");
  }
}