MPIResponse ConstructMPIResponse()

in tensorflow_networking/mpi_collectives/kernels/mpi_ops.cc [294:421]


// Given a tensor name and the message table holding every rank's queued
// MPIRequest for that name, construct a single MPIResponse describing the
// collective to run, or an ERROR response if the ranks' requests disagree.
//
// Cross-rank validations performed (first failure wins):
//   - all tensor data types are identical,
//   - all requested operations (ALLREDUCE vs ALLGATHER) are identical,
//   - ALLREDUCE: all tensor shapes are identical,
//   - ALLGATHER: all tensors have the same (non-zero) rank and identical
//     dimensions except the first; the output's first dimension is the sum
//     across ranks.
//
// Side effect: the queued requests for `name` are erased from the table
// before returning — they are considered handled by the response.
MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
                                 std::string name) {
  bool error = false;
  auto it = message_table->find(name);
  assert(it != message_table->end());

  // Borrow (don't copy) the queued requests. The table entry is erased only
  // after the last use of `requests`, so the reference stays valid.
  const std::vector<MPIRequest>& requests = it->second;
  assert(requests.size() > 0);

  std::ostringstream error_message_stream;

  // Check that all data types being reduced or gathered are identical.
  auto data_type = requests[0].tensor_type();
  for (unsigned int i = 1; i < requests.size(); i++) {
    auto request_type = requests[i].tensor_type();
    if (data_type != request_type) {
      error = true;
      error_message_stream << "Mismatched data types: One rank had type "
                           << DataType_Name(data_type)
                           << ", but another rank had type "
                           << DataType_Name(request_type) << ".";
      break;
    }
  }

  // Check that all requested operations are the same.
  auto message_type = requests[0].request_type();
  for (unsigned int i = 1; i < requests.size() && !error; i++) {
    auto request_type = requests[i].request_type();
    if (message_type != request_type) {
      error = true;
      error_message_stream << "Mismatched MPI operations: One rank did an "
                           << message_type << ", but another rank did an "
                           << request_type << ".";
      break;
    }
  }

  // If we are doing an allreduce, check that all tensor shapes
  // are identical.
  if (message_type == MPIRequest::ALLREDUCE) {
    TensorShape tensor_shape = requests[0].tensor_shape();
    for (unsigned int i = 1; i < requests.size() && !error; i++) {
      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape != request_shape) {
        error = true;
        error_message_stream << "Mismatched allreduce tensor shapes: "
                             << "One rank reduced a tensor of shape "
                             << tensor_shape.DebugString()
                             << ", but another rank sent a tensor of shape "
                             << request_shape.DebugString() << ".";
        break;
      }
    }
  }

  // If we are doing an allgather, make sure all but the first dimension are
  // the same. The first dimension may be different and the output tensor is
  // the sum of the first dimension. Collect the sizes by rank.
  if (message_type == MPIRequest::ALLGATHER) {
    TensorShape tensor_shape = requests[0].tensor_shape();

    // Scalars (rank-zero tensors) have no first dimension to gather along.
    if (tensor_shape.dims() == 0) {
      error = true;
      error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
    }

    for (unsigned int i = 1; i < requests.size() && !error; i++) {
      TensorShape request_shape = requests[i].tensor_shape();
      if (tensor_shape.dims() != request_shape.dims()) {
        error = true;
        error_message_stream << "Mismatched allgather tensor shapes: "
                             << "One rank gathered a tensor of rank "
                             << tensor_shape.dims()
                             << ", but another rank sent a tensor of rank "
                             << request_shape.dims() << ".";
        break;
      }

      // All dimensions except the first must match exactly. dims() returns a
      // signed int, so the loop index is signed to avoid a signed/unsigned
      // comparison.
      for (int dim = 1; dim < tensor_shape.dims(); dim++) {
        if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
          error = true;
          error_message_stream
              << "Mismatched allgather tensor shapes: "
              << "One rank gathered a tensor with dimension " << dim
              << " equal to " << tensor_shape.dim_size(dim)
              << ", but another rank sent a tensor with dimension " << dim
              << " equal to " << request_shape.dim_size(dim) << ".";
          break;
        }
      }
      // On an inner-loop mismatch, `error` terminates the outer loop via
      // its condition on the next iteration check.
    }
  }

  MPIResponse response;
  response.set_tensor_name(name);
  if (error) {
    response.set_response_type(MPIResponse::ERROR);
    response.set_error_message(error_message_stream.str());
  } else {
    // No inconsistency found: echo back the agreed-upon operation.
    response.set_response_type(message_type == MPIRequest::ALLREDUCE
                                   ? MPIResponse::ALLREDUCE
                                   : MPIResponse::ALLGATHER);
  }

  // Clear all queued up requests for this name. They are now taken care of
  // by the constructed MPI response. (This invalidates `requests`, which is
  // not used past this point.)
  message_table->erase(it);

  return response;
}