void allgatherv()

in gloo/allgatherv.cc [71:140]


// Gathers a variable-sized contribution from every process into the output
// buffer of every process.
//
// The i-th entry of opts.elements is the number of elements contributed by
// rank i; contributions are laid out back to back in the output buffer in
// rank order, each occupying elements[i] * opts.elementSize bytes.
//
// If opts.in is set, its contents are copied into this rank's region of the
// output buffer first. If it is not set, nothing here writes this rank's
// region before it is sent, so the caller is expected to have filled it.
//
// Data is exchanged in (context->size - 1) steps: in every step this process
// sends one region to rank + 1 and receives one region from rank - 1 (both
// modulo context->size).
//
// All preconditions are checked with GLOO_ENFORCE / GLOO_ENFORCE_EQ.
// opts.timeout is forwarded to every waitSend / waitRecv call.
void allgatherv(AllgathervOptions& opts) {
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto slot = Slot::build(kAllgatherSlotPrefix, opts.tag);

  // Sanity checks
  GLOO_ENFORCE(opts.elementSize > 0);
  // The output buffer is dereferenced unconditionally below.
  GLOO_ENFORCE(out != nullptr, "output buffer must be specified");
  // Every rank indexes the per-rank tables below, so there must be exactly
  // one element count per rank.
  GLOO_ENFORCE(
      opts.elements.size() == static_cast<size_t>(context->size),
      "expected one element count per rank; got " +
          std::to_string(opts.elements.size()) + " counts for " +
          std::to_string(context->size) + " ranks");
  const auto recvRank = (context->size + context->rank - 1) % context->size;
  GLOO_ENFORCE(
      recvRank == context->rank || context->getPair(recvRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(recvRank));
  const auto sendRank = (context->size + context->rank + 1) % context->size;
  GLOO_ENFORCE(
      sendRank == context->rank || context->getPair(sendRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(sendRank));

  // Compute byte counts and offsets into output buffer.
  std::vector<size_t> byteCounts;
  std::vector<size_t> byteOffsets;
  byteCounts.reserve(context->size);
  byteOffsets.reserve(context->size);
  size_t offset = 0;
  for (const auto& elements : opts.elements) {
    const auto bytes = elements * opts.elementSize;
    byteCounts.push_back(bytes);
    byteOffsets.push_back(offset);
    offset += bytes;
  }

  // After the loop, `offset` is the total number of bytes over all ranks.
  // Every region must fit in the output buffer, otherwise the copy below
  // would write past its end.
  GLOO_ENFORCE(
      offset <= out->size,
      "output buffer holds " + std::to_string(out->size) +
          " bytes but the per-rank counts require " + std::to_string(offset));

  // If the input buffer is specified, the output buffer needs to be primed.
  if (in != nullptr) {
    GLOO_ENFORCE_EQ(byteCounts[context->rank], in->size);
    if (byteCounts[context->rank] > 0) {
      memcpy(
          static_cast<uint8_t*>(out->ptr) + byteOffsets[context->rank],
          static_cast<uint8_t*>(in->ptr),
          in->size);
    }
  }

  // Short circuit if there is only a single process.
  if (context->size == 1) {
    return;
  }

  // Adding context->size keeps (baseIndex - i - 1) non-negative for every
  // step, so the modulo below always yields a valid rank index.
  const auto baseIndex = context->size + context->rank;
  for (auto i = 0; i < context->size - 1; i++) {
    const size_t sendIndex = (baseIndex - i) % context->size;
    const size_t recvIndex = (baseIndex - i - 1) % context->size;

    // The region sent in step i is the region received in step i - 1, so
    // the previous operations must complete before kicking off new ones.
    // There is nothing to wait for in the first step.
    if (i > 0) {
      out->waitSend(opts.timeout);
      out->waitRecv(opts.timeout);
    }

    out->send(sendRank, slot, byteOffsets[sendIndex], byteCounts[sendIndex]);
    out->recv(recvRank, slot, byteOffsets[recvIndex], byteCounts[recvIndex]);
  }

  // Wait for final operations to complete.
  out->waitSend(opts.timeout);
  out->waitRecv(opts.timeout);
}