in gloo/mpi/context.cc [88:149]
// Establishes a pairwise connection between this rank and every other
// rank in the collective. Addresses are exchanged over MPI: each rank
// publishes the listening address of the pair it created for every
// peer, all ranks allgather these tables, and each rank then connects
// its pair for peer i to the address peer i published for this rank.
//
// @param dev transport device used to create the context and pairs;
//            stored in device_ on success.
// @throws IoException if any MPI call fails.
void Context::connectFullMesh(std::shared_ptr<transport::Device>& dev) {
  std::vector<std::vector<char>> addresses(size);
  unsigned long maxLength = 0;
  int rv;

  // Create a pair for every other rank in the collective and record its
  // serialized listening address. We never connect to ourselves, so the
  // slot for our own rank stays empty.
  auto transportContext = dev->createContext(rank, size);
  transportContext->setTimeout(getTimeout());
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }
    auto& pair = transportContext->createPair(i);
    // Store address for pair for this rank
    auto address = pair->address().bytes();
    maxLength = std::max(maxLength, address.size());
    addresses[i] = std::move(address);
  }

  // Agree on the maximum address length across all ranks so every
  // process can use identical fixed-size slots in the exchange buffers.
  rv = MPI_Allreduce(
      MPI_IN_PLACE, &maxLength, 1, MPI_UNSIGNED_LONG, MPI_MAX, comm_);
  if (rv != MPI_SUCCESS) {
    GLOO_THROW_IO_EXCEPTION("MPI_Allreduce: ", rv);
  }

  // Lay out our addresses in fixed-size slots: slot i holds the address
  // of the pair we created for rank i. Unused bytes (including our own
  // slot) remain zero-initialized.
  std::vector<char> in(size * maxLength);
  std::vector<char> out(size * size * maxLength);
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }
    auto& address = addresses[i];
    memcpy(in.data() + (i * maxLength), address.data(), address.size());
  }

  // Gather every rank's address table; rank j's table occupies
  // out[j * size * maxLength, (j + 1) * size * maxLength). MPI counts
  // are plain ints, so cast explicitly instead of narrowing implicitly
  // (buffers larger than INT_MAX bytes are not representable here).
  rv = MPI_Allgather(
      in.data(),
      static_cast<int>(in.size()),
      MPI_BYTE,
      out.data(),
      static_cast<int>(in.size()),
      MPI_BYTE,
      comm_);
  if (rv != MPI_SUCCESS) {
    GLOO_THROW_IO_EXCEPTION("MPI_Allgather: ", rv);
  }

  // Connect every pair: the address rank i published for us sits at
  // slot `rank` within rank i's table.
  for (int i = 0; i < size; i++) {
    if (i == rank) {
      continue;
    }
    auto offset = (rank + i * size) * maxLength;
    std::vector<char> address(maxLength);
    memcpy(address.data(), out.data() + offset, maxLength);
    transportContext->getPair(i)->connect(address);
  }

  device_ = dev;
  transportContext_ = std::move(transportContext);
}