in tensorflow_networking/mpi_collectives/kernels/mpi_ops.cc [1044:1116]
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
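// Input 0 is the tensor to gather; input 1 is the sizing tensor giving each
// rank's first-dimension size.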
const Tensor* input_tensor = &context->input(0);
const Tensor* sizing_tensor = &context->input(1);
// The record is allocated on the stack so the op can fail without leaking memory.
CollectiveOpRecord record;
record.name = name();
record.context = context;
record.in_t = input_tensor;
record.on_gpu = IsGPUDevice<Device>();
// Construct the output size from the sizing tensor
size_t output_first_dim = 0;
if (sizing_tensor->shape().dims() == 0) {
// A 0-dim sizing_tensor means the op gathers a single first-dimension
// slice from each rank, so every per-rank size is 1.
output_first_dim = mpi_global.size;
for (int i = 0; i < mpi_global.size; i++) {
record.sizes_vec.push_back(1);
}
} else {
// Read each rank's first-dimension size from the sizing tensor and
// accumulate the total output size.
// NOTE: The sizing tensor is forced to be placed on the CPU by
// declaring the input as HostMemory, so it is valid to read it here.
const int64* sizing_array =
(const int64*)sizing_tensor->tensor_data().data();
for (int i = 0; i < mpi_global.size; i++) {
record.sizes_vec.push_back(sizing_array[i]);
output_first_dim += sizing_array[i];
}
}
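// The output keeps the input's non-first dimensions; its first dimension is
// the sum of the per-rank first-dimension sizes computed above.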
TensorShape output_shape;
output_shape.AddDim(output_first_dim);
for (int i = 1; i < input_tensor->shape().dims(); i++) {
output_shape.AddDim(input_tensor->shape().dim_size(i));
}
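// Allocate the output tensor now and stash it in the record so the
// collective can write the gathered result directly into it.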
Tensor* output_tensor;
OP_REQUIRES_OK_ASYNC(
context, context->allocate_output(0, output_shape, &output_tensor),
done);
record.out_t = output_tensor;
record.dtype = input_tensor->dtype();
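// When the allgather completes, propagate its status to the op context and
// signal that the async kernel is done.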
auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
context->SetStatus(status.status());
done();
};
record.callback = allgather_done_callback;
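// The launch callback enqueues the record as an ALLGATHER collective request.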
auto allgather_launch_callback = [record] {
EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
};
// If we are on a CPU, the device context is null and there is no stream to
// enqueue the launch on. On a CPU this op is called once its input data is
// already available, so the allgather can be launched immediately; there is
// no need to wait for the data to be populated.
#if GOOGLE_CUDA
auto device_context = context->op_device_context();
if (device_context == nullptr) {
allgather_launch_callback();
} else {
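// On GPU, run the launch as a host callback on the device stream so the
// allgather starts only after the work that produces the input tensor has
// completed.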
auto stream = device_context->stream();
stream->ThenDoHostCallback(allgather_launch_callback);
}
#else
allgather_launch_callback();
#endif
}