void PerformCollectiveOp()

in tensorflow_networking/mpi_collectives/mpi_ops.cc [425:517]


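// Executes the collective described by `response` (allreduce, allgather, or
// an error report) on the tensor that was previously queued in
// `tensor_table`, then invokes that tensor's completion callback with either
// the output tensor or the failing Status.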
void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
  OpKernelContext* context;
  const Tensor* input_tensor;
  std::vector<size_t> sizes_vec;
  Tensor temp_tensor;
  Tensor* output_tensor;
  CommunicationDoneCallback callback;
  bool on_gpu;
  {
    // Lock on the tensor table.
    mutex_lock guard(mpi_global.mu);

    // We should never fail at finding this key in the tensor table.
    auto name = response.tensor_name();
    auto iter = tensor_table.find(name);
    assert(iter != tensor_table.end());

    assert(response.response_type() == MPIResponse::ALLREDUCE ||
           response.response_type() == MPIResponse::ALLGATHER ||
           response.response_type() == MPIResponse::ERROR);

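    // Copy the fields we need out of the record so the table lock can be
    // released before the (potentially long-running) MPI work below.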
    CollectiveOpRecord record = iter->second;
    context = record.context;
    input_tensor = record.in_t;
    sizes_vec = record.sizes_vec;
    temp_tensor = record.temp_t;
    output_tensor = record.out_t;
    on_gpu = record.on_gpu;
    callback = record.callback;

    // Remove this tensor's entry (and its callback) from the table; the rest
    // of this function completes the collective and invokes the callback.
    tensor_table.erase(iter);
  }

  // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
  // link to non-existent symbols.
#if GOOGLE_CUDA
#define GPU_DEVICE_IF_CUDA GPUDevice
#else
#define GPU_DEVICE_IF_CUDA CPUDevice
#endif

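  // Dispatch on the response type and tensor dtype; the device template
  // parameter selects the CPU or (when built with CUDA) GPU ring
  // implementation.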
  Status status;
  auto dtype = input_tensor->dtype();
  if (response.response_type() == MPIResponse::ALLGATHER) {
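    // Allgather needs the per-rank sizes in sizes_vec to size and fill the
    // output tensor.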
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, float>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, int>(context, input_tensor,
                                                      sizes_vec, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, long long>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allgather.");
    }
  } else if (response.response_type() == MPIResponse::ALLREDUCE) {
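    // Allreduce additionally takes temp_tensor, used as a staging buffer
    // during the ring exchange.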
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, float>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, int>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, long long>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allreduce.");
    }
  } else if (response.response_type() == MPIResponse::ERROR) {
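    // The response itself reports an error; propagate its message as the
    // operation's status.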
    status = errors::FailedPrecondition(response.error_message());
  }

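  // Hand the result back to the op that queued this tensor: the filled
  // output tensor on success, or the error status otherwise.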
  if (status.ok()) {
    callback(StatusOr<Tensor>(*output_tensor));
  } else {
    callback(StatusOr<Tensor>(status));
  }
}
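
For orientation, here is a minimal sketch of how a caller might queue a tensor and hand control to PerformCollectiveOp. It is not taken from mpi_ops.cc: the field names (context, in_t, temp_t, out_t, on_gpu, callback) and types (TensorTable, CollectiveOpRecord, MPIResponse, StatusOr) come from the function above, but the helper EnqueueAllreduce, the protobuf-style setters on MPIResponse, and the map-style insertion into TensorTable are assumptions.

// Hypothetical usage sketch -- not part of mpi_ops.cc. Assumes the types used
// by PerformCollectiveOp (TensorTable, CollectiveOpRecord, MPIResponse,
// StatusOr<Tensor>, mpi_global) are in scope.
void EnqueueAllreduce(TensorTable& tensor_table, OpKernelContext* ctx,
                      const Tensor* input, Tensor* temp, Tensor* output) {
  CollectiveOpRecord record;
  record.context = ctx;
  record.in_t = input;
  record.temp_t = *temp;   // staging buffer for the ring allreduce
  record.out_t = output;
  record.on_gpu = false;   // CPU path for this sketch
  record.callback = [ctx](StatusOr<Tensor> result) {
    // A real callback would also signal the waiting op's completion; here we
    // only forward the status to the kernel context.
    ctx->SetStatus(result.status());
  };

  {
    // Publish the record under the same lock PerformCollectiveOp takes.
    mutex_lock guard(mpi_global.mu);
    tensor_table["my_tensor"] = record;  // assumes a map-like TensorTable
  }

  // Normally the background coordination thread builds the MPIResponse once
  // every rank has requested "my_tensor"; here one is constructed directly
  // (the set_* accessors assume MPIResponse is a protobuf message).
  MPIResponse response;
  response.set_response_type(MPIResponse::ALLREDUCE);
  response.set_tensor_name("my_tensor");

  PerformCollectiveOp(tensor_table, response);
}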