in tensorflow_networking/mpi_collectives/mpi_ops.cc [425:517]
void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
  OpKernelContext* context;
  const Tensor* input_tensor;
  std::vector<size_t> sizes_vec;
  Tensor temp_tensor;
  Tensor* output_tensor;
  CommunicationDoneCallback callback;
  bool on_gpu;
  {
    // Lock on the tensor table.
    mutex_lock guard(mpi_global.mu);

    // We should never fail at finding this key in the tensor table.
    auto name = response.tensor_name();
    auto iter = tensor_table.find(name);
    assert(iter != tensor_table.end());

    assert(response.response_type() == MPIResponse::ALLREDUCE ||
           response.response_type() == MPIResponse::ALLGATHER ||
           response.response_type() == MPIResponse::ERROR);

    CollectiveOpRecord record = iter->second;
    context = record.context;
    input_tensor = record.in_t;
    sizes_vec = record.sizes_vec;
    temp_tensor = record.temp_t;
    output_tensor = record.out_t;
    on_gpu = record.on_gpu;
    callback = record.callback;

    // Clear the tensor table of this tensor and its callbacks; the rest of
    // this function takes care of it.
    tensor_table.erase(iter);
  }
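
  // The record was copied out under the lock so that the MPI work below runs
  // with mpi_global.mu released.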

  // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
  // link to non-existent symbols.
#if GOOGLE_CUDA
#define GPU_DEVICE_IF_CUDA GPUDevice
#else
#define GPU_DEVICE_IF_CUDA CPUDevice
#endif
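  // Note: without CUDA, GPU_DEVICE_IF_CUDA expands to CPUDevice, so the
  // on_gpu branches below still compile against the CPU specializations.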

  Status status;
  auto dtype = input_tensor->dtype();
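  // Dispatch on the response type, then on dtype and device placement: each
  // supported dtype routes to the matching templated ring implementation.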
  if (response.response_type() == MPIResponse::ALLGATHER) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, float>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, int>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, sizes_vec, output_tensor)
                      : RingAllgather<CPUDevice, long long>(
                            context, input_tensor, sizes_vec, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allgather.");
    }
  } else if (response.response_type() == MPIResponse::ALLREDUCE) {
    if (dtype == DT_FLOAT) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, float>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT32) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, int>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else if (dtype == DT_INT64) {
      status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
                            context, input_tensor, &temp_tensor, output_tensor)
                      : RingAllreduce<CPUDevice, long long>(
                            context, input_tensor, &temp_tensor, output_tensor);
    } else {
      status = errors::Unknown("Invalid tensor type for MPI allreduce.");
    }
  } else if (response.response_type() == MPIResponse::ERROR) {
    status = errors::FailedPrecondition(response.error_message());
  }
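
  // Hand the result back to the waiting op kernel: the output tensor on
  // success, or the error status otherwise.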
  if (status.ok()) {
    callback(StatusOr<Tensor>(*output_tensor));
  } else {
    callback(StatusOr<Tensor>(status));
  }
}