in tensorflow_networking/mpi_collectives/mpi_ops.cc [574:796]
void BackgroundThreadLoop() {
#if GOOGLE_CUDA
// Set the device, so that this thread uses the same GPU context as the
// calling thread.
// TODO: Ensure that this is operating correctly. The background thread
// needs to be able to control all GPUs that the rank has access to, which
// might be more than one GPU. Tensors could be resident on any of those
// GPUs, so the background thread's accumulate and copy kernels may need to
// set the device explicitly, and the background thread may need to manage
// multiple streams.
cudaSetDevice(mpi_global.device);
cudaStreamCreate(&mpi_global.stream);
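// Note: the return codes of cudaSetDevice() and cudaStreamCreate() are not
// checked here; if either call fails, the error only surfaces later when
// the stream is used.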
#endif
// Initialize MPI. This must happen on the background thread, since not all
// MPI implementations support being called from multiple threads.
auto init_result = MPI_Init(NULL, NULL);
if (init_result != MPI_SUCCESS) {
mpi_global.init_status =
errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
mpi_global.initialization_done = true;
mpi_global.cv.notify_all();
return;
} else {
mpi_global.init_status = Status::OK();
}
// Get MPI rank to determine if we are rank zero.
int rank;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
bool is_coordinator = rank == 0;
// Get MPI size to determine how many tensors to wait for before reducing.
int size;
MPI_Comm_size(MPI_COMM_WORLD, &size);
// Determine local rank by querying the local communicator.
MPI_Comm local_comm;
MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
&local_comm);
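// MPI_COMM_TYPE_SHARED splits MPI_COMM_WORLD into communicators whose
// ranks share a memory domain (i.e. run on the same node), so the rank
// within local_comm is this process's index on its host. Callers typically
// use it to decide which GPU a rank should drive.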
int local_rank;
MPI_Comm_rank(local_comm, &local_rank);
mpi_global.rank = rank;
mpi_global.local_rank = local_rank;
mpi_global.size = size;
mpi_global.initialization_done = true;
// Notify calling thread that initialization is complete
mpi_global.cv.notify_all();
// TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
// Initialize the tensor count table. No tensors are available yet.
if (is_coordinator) {
mpi_global.message_table =
std::unique_ptr<MessageTable>(new MessageTable());
}
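// Main coordination loop. Each iteration ("tick") proceeds roughly as
// follows:
//   1. Every rank drains its local queue of MPIRequests.
//   2. Worker ranks forward their requests to rank zero; rank zero
//      tallies requests in the message table.
//   3. Once a tensor has been requested by all ranks, rank zero
//      broadcasts an MPIResponse for it and every rank performs the
//      collective operation.
//   4. Rank zero ends the tick with a DONE response, or SHUTDOWN when
//      the global shutdown flag is set.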
// The coordinator sends a SHUTDOWN message to trigger shutdown.
bool should_shut_down = false;
do {
// TODO: Eliminate the need for thread sleep by making all activity
// depend on other activity (e.g. condition or MPI waits).
std::this_thread::sleep_for(std::chrono::milliseconds(1));
// Copy the data structures from global state under this lock.
// However, don't keep the lock for the rest of the loop, so that
// enqueued stream callbacks can continue.
std::queue<MPIRequest> message_queue;
{
mutex_lock guard(mpi_global.mu);
while (!mpi_global.message_queue.empty()) {
MPIRequest message = mpi_global.message_queue.front();
mpi_global.message_queue.pop();
message_queue.push(message);
}
}
// Collect all tensors that are ready to be reduced. Record them in the
// tensor count table (rank zero) or send them to rank zero to be
// recorded (everyone else).
std::vector<std::string> ready_to_reduce;
while (!message_queue.empty()) {
// Pop the first available message.
MPIRequest message = message_queue.front();
message_queue.pop();
if (is_coordinator) {
bool reduce =
IncrementTensorCount(mpi_global.message_table, message, size);
if (reduce) {
ready_to_reduce.push_back(message.tensor_name());
}
} else {
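// Forward this request to the coordinator. The payload is the serialized
// MPIRequest plus a trailing zero byte; a zero-length message is reserved
// for the DONE signal sent at the end of the tick.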
std::string encoded_message;
message.SerializeToString(&encoded_message);
MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
}
}
// Rank zero has put all its own tensors in the tensor count table.
// Now, it should count all the tensors that are coming from other
// ranks at this tick. It should keep receiving requests until it has
// gotten a DONE message from every other rank.
if (is_coordinator) {
// Count of ranks that are done sending requests this tick. Keep
// receiving messages until this count equals the number of processes.
// Initialize to one since the coordinator itself is already done.
int completed_ranks = 1;
while (completed_ranks != size) {
MPI_Status status;
MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
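// MPI_Probe blocks until a matching message is available and fills in
// status, which lets us query the sender and payload size before
// allocating a receive buffer.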
// Find number of characters in message (including zero byte).
int source_rank = status.MPI_SOURCE;
int msg_length;
MPI_Get_count(&status, MPI_BYTE, &msg_length);
// If the length is zero, this is a DONE message.
if (msg_length == 0) {
completed_ranks++;
MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
&status);
continue;
}
// Receive the serialized MPIRequest into an std::string. The payload is
// binary protobuf data and may contain embedded zero bytes, so construct
// the string with an explicit length, dropping the trailing zero byte
// appended by the sender.
char* buffer = new char[msg_length];
MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
MPI_COMM_WORLD, &status);
std::string received_data(buffer, msg_length - 1);
delete[] buffer;
MPIRequest received_message;
received_message.ParseFromString(received_data);
auto received_name = received_message.tensor_name();
bool reduce = IncrementTensorCount(mpi_global.message_table,
received_message, size);
if (reduce) {
ready_to_reduce.push_back(received_name);
}
}
// At this point, rank zero should have a fully updated tensor
// count table and should know all the tensors that need to be
// reduced or gathered, and everyone else should have sent all
// their information to rank zero. We can now do reductions and
// gathers; rank zero will choose which ones and in what order,
// and will notify the other ranks before doing each reduction.
for (size_t i = 0; i < ready_to_reduce.size(); i++) {
// Notify all nodes which tensor we'd like to reduce now
auto name = ready_to_reduce[i];
MPIResponse response =
ConstructMPIResponse(mpi_global.message_table, name);
std::string encoded_response;
response.SerializeToString(&encoded_response);
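// Broadcast the response to every worker rank; rank zero does not send
// to itself and instead applies the response directly below.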
for (int r = 1; r < size; r++) {
MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
}
// Perform the reduction. All nodes should end up performing
// the same reduction.
PerformCollectiveOp(mpi_global.tensor_table, response);
}
// Notify all nodes that we are done with the reductions for this
// tick.
MPIResponse done_response;
should_shut_down = mpi_global.shut_down;
done_response.set_response_type(
mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
std::string encoded_response;
done_response.SerializeToString(&encoded_response);
for (int r = 1; r < size; r++) {
MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
}
} else {
// Notify the coordinator that this node is done sending messages.
// A DONE message is encoded as a zero-length message.
MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
// Receive names of tensors to reduce from rank zero. Once we receive a
// DONE response, stop waiting for more names this tick.
while (true) {
MPI_Status status;
MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
// Find number of characters in message (including zero byte).
int msg_length;
MPI_Get_count(&status, MPI_BYTE, &msg_length);
// Receive the serialized MPIResponse into an std::string, again using an
// explicit length so embedded zero bytes in the protobuf payload are
// preserved.
char* buffer = new char[msg_length];
MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
&status);
std::string received_message(buffer, msg_length - 1);
delete[] buffer;
MPIResponse response;
response.ParseFromString(received_message);
if (response.response_type() == MPIResponse::DONE) {
// No more messages this tick
break;
} else if (response.response_type() == MPIResponse::SHUTDOWN) {
// No more messages this tick, and the background thread
// should shut down
should_shut_down = true;
break;
} else {
// Process the current message
PerformCollectiveOp(mpi_global.tensor_table, response);
}
}
}
} while (!should_shut_down);
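// Finalize MPI on the background thread as well, since the MPI standard
// requires MPI_Finalize to be called by the same thread that called
// MPI_Init.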
MPI_Finalize();
}