void ComputeAsync()

in tensorflow_networking/mpi_collectives/mpi_ops.cc [1120:1192]


  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
    const Tensor* input_tensor = &context->input(0);
    const Tensor* sizing_tensor = &context->input(1);

    // The record is allocated on the stack so the op can fail without
    // leaking memory.
    CollectiveOpRecord record;
    record.name = name();
    record.context = context;
    record.in_t = input_tensor;
    record.on_gpu = IsGPUDevice<Device>();

    // Construct the output size from the sizing tensor
    size_t output_first_dim = 0;
    if (sizing_tensor->shape().dims() == 0) {
      // 0-dim sizing_tensor implies that the op is just gathering
      // a single element from each rank
      output_first_dim = mpi_global.size;
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(1);
      }
    } else {
      // Collect the total output tensor sizing from the sizing tensor
      // NOTE: The sizing tensor is forced to be placed on the CPU by
      // declaring the input as HostMemory, so it is valid to read it here.
      const int64* sizing_array =
          reinterpret_cast<const int64*>(sizing_tensor->tensor_data().data());
      for (int i = 0; i < mpi_global.size; i++) {
        record.sizes_vec.push_back(sizing_array[i]);
        output_first_dim += sizing_array[i];
      }
    }

    TensorShape output_shape;
    output_shape.AddDim(output_first_dim);
    for (int i = 1; i < input_tensor->shape().dims(); i++) {
      output_shape.AddDim(input_tensor->shape().dim_size(i));
    }
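    // Example: with 3 ranks and sizing_tensor = [4, 2, 3], sizes_vec becomes
    // {4, 2, 3} and output_first_dim is 9; if each rank's input has shape
    // [n_i, 5], the output shape is [9, 5], i.e. the rank inputs concatenated
    // along dimension 0.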

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, output_shape, &output_tensor),
        done);
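    // If allocation fails, OP_REQUIRES_OK_ASYNC invokes done() and returns
    // early; the stack-allocated record is destroyed with the frame, so
    // nothing leaks.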

    record.out_t = output_tensor;
    record.dtype = input_tensor->dtype();

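    // Completion callback: record the collective's final status on the op
    // context, then signal the async kernel as done.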
    auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
      context->SetStatus(status.status());
      done();
    };
    record.callback = allgather_done_callback;

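    // The launch callback captures the record by value: on the GPU path it
    // runs from a host callback after ComputeAsync has returned, so it must
    // not reference this stack frame.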
    auto allgather_launch_callback = [record] {
      EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
    };

    // If we are on a CPU, the device context is null and there is no stream
    // to enqueue on. On the CPU this op is only called once its input data
    // is already available, so we can run the allgather immediately; there
    // is nothing to wait for.
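    // On a GPU the input may still be pending on the device stream, so the
    // launch is deferred to a host callback that the stream runs only after
    // all previously enqueued device work has completed.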
#if GOOGLE_CUDA
    auto device_context = context->op_device_context();
    if (device_context == nullptr) {
      allgather_launch_callback();
    } else {
      auto stream = device_context->stream();
      stream->ThenDoHostCallback(allgather_launch_callback);
    }
#else
    allgather_launch_callback();
#endif
  }
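
For context, the HostMemory note above relies on the kernel registration pinning the sizing input to host memory. A minimal sketch of such a registration, assuming the op name "MPIAllgather", the input name "sizes", and the MPIAllgatherOp / GPUDevice names used elsewhere in this file:

  REGISTER_KERNEL_BUILDER(Name("MPIAllgather")
                              .Device(DEVICE_GPU)
                              .HostMemory("sizes"),
                          MPIAllgatherOp<GPUDevice>);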