c10::intrusive_ptr<c10::ivalue::Future> DeviceFairring::allGather()

in fairring/device.cc [367:403]


// Enqueues an all-gather of `input` into `output` on this device's command
// queue and returns a CUDA-aware future that completes with {output} once the
// collective has been issued (or carries the exception if it failed).
c10::intrusive_ptr<c10::ivalue::Future> DeviceFairring::allGather(
    at::Tensor input,
    at::Tensor output) {
  // All-gather cannot be sliced into chunks, so the algorithm requires at
  // least two devices per machine to be applicable.
  MY_CHECK(numDevicesPerMachine_ >= 2);
  // Validate the alignment precondition synchronously, BEFORE enqueuing.
  // Previously this check ran inside the worker lambda but outside its
  // try/catch, so a failure would throw out of the command-queue thread and
  // the returned future would never be completed. Checking here surfaces the
  // error immediately to the caller (`input` has not been moved from yet).
  MY_CHECK(kAlignment % input.element_size() == 0);

  // Record an event on the caller's current stream for this device —
  // presumably so allGatherOneSlice can order the collective after any
  // pending work that produces `input` (TODO confirm against
  // allGatherOneSlice).
  at::cuda::CUDAEvent initialEvent;
  initialEvent.record(c10::cuda::getCurrentCUDAStream(myDeviceIdxOnProcess_));

  // The future advertises this CUDA device so that completion/error are
  // associated with it.
  c10::intrusive_ptr<c10::ivalue::Future> future =
      c10::make_intrusive<c10::ivalue::Future>(
          c10::ListType::ofTensors(),
          std::vector<c10::Device>(
              {c10::Device(c10::kCUDA, myDeviceIdxOnProcess_)}));

  // NOTE(review): the event is wrapped in a shared_ptr — likely because the
  // queue requires copyable callables while CUDAEvent is move-only; confirm
  // against cmdQueue_'s callable type.
  cmdQueue_.enqueue([this,
                     input = std::move(input),
                     output = std::move(output),
                     initialEvent = std::make_shared<at::cuda::CUDAEvent>(
                         std::move(initialEvent)),
                     future]() mutable {
    // Sequence number used only for error reporting below.
    int64_t seqNum = nextSlot_++;
    try {
      allGatherOneSlice(input, output, std::move(*initialEvent));
      // Complete the future with the all-gather stream as the current stream,
      // so consumers of the future synchronize against that stream.
      c10::cuda::CUDAStreamGuard g(allGatherStream_);
      future->markCompleted(std::vector<at::Tensor>{output});
    } catch (const std::exception& e) {
      LOG(ERROR) << "Function for chunk #" << seqNum
                 << " threw exception: " << e.what();
      c10::cuda::CUDAStreamGuard g(allGatherStream_);
      future->setError(std::current_exception());
    }
  });

  return future;
}