void BackgroundThreadLoop()

in tensorflow_networking/mpi_collectives/kernels/mpi_ops.cc [573:795]


void BackgroundThreadLoop() {
#if GOOGLE_CUDA
  // Set the device, so that this thread uses the same GPU context as the
  // calling thread.
  // TODO: Ensure that this is operating correctly. The background thread
  // needs to be able to control all GPUs that the rank has access to, which
  // may be more than one. Tensors could be resident in any of those GPUs, so
  // the background thread's accumulate and copy kernels may need to set the
  // device explicitly, and the background thread may need to manage multiple
  // streams.
  cudaSetDevice(mpi_global.device);
  cudaStreamCreate(&mpi_global.stream);
#endif
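  // A hedged sketch of what the multi-GPU TODO above might look like; this is
  // not part of the current implementation and assumes a per-device `streams`
  // vector that does not exist here:
  //
  //   int device_count = 0;
  //   cudaGetDeviceCount(&device_count);
  //   std::vector<cudaStream_t> streams(device_count);
  //   for (int d = 0; d < device_count; ++d) {
  //     cudaSetDevice(d);
  //     cudaStreamCreate(&streams[d]);
  //   }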

  // Initialize MPI. This must happen on the background thread, since not all
  // MPI implementations support being called from multiple threads.
  auto init_result = MPI_Init(NULL, NULL);
  if (init_result != MPI_SUCCESS) {
    mpi_global.init_status =
        errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
    mpi_global.initialization_done = true;
    mpi_global.cv.notify_all();
    return;
  } else {
    mpi_global.init_status = Status::OK();
  }
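
  // Note: because this loop performs all of its MPI communication on this one
  // thread, plain MPI_Init() is used. An implementation that wanted to
  // negotiate thread support explicitly could instead call (sketch, not used
  // here):
  //
  //   int provided;
  //   MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided);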

  // Get MPI rank to determine if we are rank zero.
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  bool is_coordinator = rank == 0;

  // Get MPI size to determine how many tensors to wait for before reducing.
  int size;
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Determine local rank by querying the local communicator.
  MPI_Comm local_comm;
  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
                      &local_comm);
  int local_rank;
  MPI_Comm_rank(local_comm, &local_rank);

  mpi_global.rank = rank;
  mpi_global.local_rank = local_rank;
  mpi_global.size = size;
  mpi_global.initialization_done = true;

  // Notify calling thread that initialization is complete
  mpi_global.cv.notify_all();
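  // The calling thread is expected to block on this condition variable until
  // initialization_done is set, roughly (sketch, assuming mpi_global.mu
  // guards the flag):
  //
  //   mutex_lock guard(mpi_global.mu);
  //   while (!mpi_global.initialization_done) {
  //     mpi_global.cv.wait(guard);
  //   }
  //   Status s = mpi_global.init_status;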

  // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
  // Initialize the tensor count table. No tensors are available yet.
  if (is_coordinator) {
    mpi_global.message_table =
        std::unique_ptr<MessageTable>(new MessageTable());
  }
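  // The MessageTable tracks, per tensor name, which ranks have requested a
  // collective on that tensor so far. IncrementTensorCount() below returns
  // true once all `size` ranks have requested the same tensor, at which point
  // the tensor is ready to be reduced or gathered.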

  // The coordinator sends a SHUTDOWN message to trigger shutdown.
  bool should_shut_down = false;
  do {
    // TODO: Eliminate the need for this sleep by making the loop fully
    // event-driven (e.g. condition variables or MPI waits).
    std::this_thread::sleep_for(std::chrono::milliseconds(1));

    // Copy the data structures from global state under this lock.
    // However, don't keep the lock for the rest of the loop, so that
    // enqueued stream callbacks can continue.
    std::queue<MPIRequest> message_queue;
    {
      mutex_lock guard(mpi_global.mu);
      while (!mpi_global.message_queue.empty()) {
        MPIRequest message = mpi_global.message_queue.front();
        mpi_global.message_queue.pop();
        message_queue.push(message);
      }
    }

    // Collect all tensors that are ready to be reduced. Record them in the
    // tensor count table (rank zero) or send them to rank zero to be
    // recorded (everyone else).
    std::vector<std::string> ready_to_reduce;
    while (!message_queue.empty()) {
      // Pop the first available message.
      MPIRequest message = message_queue.front();
      message_queue.pop();

      if (is_coordinator) {
        bool reduce =
            IncrementTensorCount(mpi_global.message_table, message, size);
        if (reduce) {
          ready_to_reduce.push_back(message.tensor_name());
        }
      } else {
        std::string encoded_message;
        message.SerializeToString(&encoded_message);
        MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
                 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    }
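
    // Worker -> coordinator wire protocol for a tick: each request is a
    // serialized MPIRequest sent as length() + 1 bytes (the trailing '\0'
    // from c_str() is included), and a zero-length message on TAG_NOTIFY
    // means "this rank is done sending requests for this tick".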

    // Rank zero has put all its own tensors in the tensor count table.
    // Now, it should count all the tensors that are coming from other
    // ranks at this tick. It should keep getting tensors until it gets a
    // DONE message from all the other ranks.
    if (is_coordinator) {
      // Count of ranks that are done sending requests this tick. Keep
      // receiving messages until every rank has been counted. Initialize to
      // one since the coordinator is effectively done.
      int completed_ranks = 1;
      while (completed_ranks != size) {
        MPI_Status status;
        MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int source_rank = status.MPI_SOURCE;
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // If the length is zero, this is a DONE message.
        if (msg_length == 0) {
          completed_ranks++;
          MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
                   &status);
          continue;
        }

        // Receive the serialized MPIRequest into an std::string. Construct
        // the string with the explicit length so that any embedded '\0'
        // bytes in the serialized proto are preserved; the final byte is the
        // terminating '\0' added by the sender.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
                 MPI_COMM_WORLD, &status);
        std::string received_data(buffer, msg_length - 1);
        delete[] buffer;

        MPIRequest received_message;
        received_message.ParseFromString(received_data);
        auto received_name = received_message.tensor_name();

        bool reduce = IncrementTensorCount(mpi_global.message_table,
                                           received_message, size);
        if (reduce) {
          ready_to_reduce.push_back(received_name);
        }
      }

      // At this point, rank zero should have a fully updated tensor
      // count table and should know all the tensors that need to be
      // reduced or gathered, and everyone else should have sent all
      // their information to rank zero. We can now do reductions and
      // gathers; rank zero will choose which ones and in what order,
      // and will notify the other ranks before doing each reduction.
      for (size_t i = 0; i < ready_to_reduce.size(); i++) {
        // Notify all nodes which tensor we'd like to reduce now
        auto name = ready_to_reduce[i];
        MPIResponse response =
            ConstructMPIResponse(mpi_global.message_table, name);

        std::string encoded_response;
        response.SerializeToString(&encoded_response);
        for (int r = 1; r < size; r++) {
          MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                   MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
        }

        // Perform the reduction. All nodes should end up performing
        // the same reduction.
        PerformCollectiveOp(mpi_global.tensor_table, response);
      }
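
      // Coordinator -> worker wire protocol for a tick: one serialized
      // MPIResponse per tensor to reduce, followed by a single DONE (or
      // SHUTDOWN) response that tells the workers the tick is over (handled
      // in the else-branch below).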

      // Notify all nodes that we are done with the reductions for this
      // tick.
      MPIResponse done_response;
      should_shut_down = mpi_global.shut_down;
      done_response.set_response_type(
          mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
      std::string encoded_response;
      done_response.SerializeToString(&encoded_response);
      for (int r = 1; r < size; r++) {
        MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
      }
    } else {
      // Notify the coordinator that this node is done sending messages.
      // A DONE message is encoded as a zero-length message.
      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);

      // Receive MPIResponse messages from rank zero describing which tensors
      // to reduce. Once we receive a DONE response, stop waiting for more
      // this tick.
      while (true) {
        MPI_Status status;
        MPI_Probe(RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD, &status);

        // Find number of characters in message (including zero byte).
        int msg_length;
        MPI_Get_count(&status, MPI_BYTE, &msg_length);

        // Receive the serialized MPIResponse into an std::string. As above,
        // construct the string with the explicit length so embedded '\0'
        // bytes are preserved; the final byte is the sender's terminating
        // '\0'.
        char* buffer = new char[msg_length];
        MPI_Recv(buffer, msg_length, MPI_BYTE, RANK_ZERO, TAG_NOTIFY,
                 MPI_COMM_WORLD, &status);
        std::string received_message(buffer, msg_length - 1);
        delete[] buffer;

        MPIResponse response;
        response.ParseFromString(received_message);
        if (response.response_type() == MPIResponse::DONE) {
          // No more messages this tick
          break;
        } else if (response.response_type() == MPIResponse::SHUTDOWN) {
          // No more messages this tick, and the background thread
          // should shut down
          should_shut_down = true;
          break;
        } else {
          // Process the current message
          PerformCollectiveOp(mpi_global.tensor_table, response);
        }
      }
    }
  } while (!should_shut_down);

  MPI_Finalize();
}
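
For reference, a standalone sketch of the probe-then-receive pattern the loop
uses to distinguish zero-length DONE notifications from serialized messages.
This is not part of mpi_ops.cc: ReceiveNotify is a hypothetical helper, and it
assumes an already-initialized MPI environment and a tag equivalent to the
TAG_NOTIFY constant used above.

// Sketch only; illustrates the MPI_Probe / MPI_Get_count / MPI_Recv pattern.
#include <mpi.h>

#include <string>
#include <vector>

// Receives one message with the given tag from any rank. Returns true if it
// was a zero-length DONE notification, false if a serialized payload was
// received (in which case *payload holds the bytes without the trailing '\0').
bool ReceiveNotify(int tag, std::string* payload, int* source_rank) {
  MPI_Status status;
  // Probe first so the receive buffer can be sized to the incoming message.
  MPI_Probe(MPI_ANY_SOURCE, tag, MPI_COMM_WORLD, &status);
  *source_rank = status.MPI_SOURCE;

  int msg_length = 0;
  MPI_Get_count(&status, MPI_BYTE, &msg_length);

  if (msg_length == 0) {
    // Zero-length message: the sender is done for this tick.
    MPI_Recv(NULL, 0, MPI_BYTE, *source_rank, tag, MPI_COMM_WORLD, &status);
    return true;
  }

  // Non-empty message: receive the serialized bytes into a sized buffer.
  std::vector<char> buffer(msg_length);
  MPI_Recv(buffer.data(), msg_length, MPI_BYTE, *source_rank, tag,
           MPI_COMM_WORLD, &status);
  payload->assign(buffer.data(), msg_length - 1);  // drop the trailing '\0'
  return false;
}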