in src/torch_ucc.cpp [703:762]
// Common post path for all UCC collectives: wraps the prepared ucc_coll_args_t
// in a WorkUCC, dispatches it to the progress engine for the given device, and
// wires up the (optional) c10 Future used by callers to await completion.
//
// @param opType        c10d operation type recorded on the returned Work.
// @param preproc       hook run before enqueue (CUDA path only, under the
//                      internal stream guard) — e.g. staging copies.
// @param postproc      hook run after enqueue (CUDA path only).
// @param coll          fully-populated UCC collective arguments; timeout is
//                      applied here via set_timeout().
// @param data          buffers that must stay alive until the collective
//                      finishes; ownership moves into the progress queue.
// @param dev           device the collective runs on (CPU or CUDA).
// @param outputTensors output tensors; retained on the Work for result() and
//                      used to complete the Future.
// @param prof_title    profiler label; only used when profiling is enabled.
// @return              Work handle tracking the asynchronous collective.
// @throws std::runtime_error for unsupported device types.
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
OpType opType,
PreProcess preproc,
PostProcess postproc,
ucc_coll_args_t& coll,
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::Device dev,
std::vector<at::Tensor> &outputTensors,
const char* prof_title) {
// Apply the configured collective timeout to the UCC args before posting.
set_timeout(coll);
auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
opType, torch_ucc_config.enable_profiling ? prof_title : nullptr);
// Store references to outputs to be used by result
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
switch (dev.type()) {
case c10::DeviceType::CPU: {
// NOTE(review): preproc/postproc are not invoked on the CPU path —
// presumably CPU collectives operate on the tensors in place; confirm
// against the CUDA branch below, which does run both hooks.
if (torch_ucc_config.use_future) {
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()));
}
// Hand buffer ownership to the progress thread; the Future (if any) is
// completed elsewhere once the collective finishes.
comm->enqueue_collective(std::move(data), work, coll, team);
return work;
}
#ifdef USE_CUDA
case c10::DeviceType::CUDA: {
// Order the internal UCC stream after all work already queued on the
// caller's current CUDA stream: record an event there, then make the
// internal stream wait on it.
auto cuda_ev = getPooledEvent();
cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
cuda_ev->block(*stream);
// Run the pre/post hooks and the enqueue with the internal stream
// current, so any kernels they launch land on that stream.
at::cuda::CUDAStreamGuard guard(*stream);
preproc();
comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee);
postproc();
// Record a fence after the collective (and postproc) so waiters can
// synchronize the caller's stream against it; `ep` is kept so the event
// can be returned to its pool when the Work is finalized.
cuda_ev->record(*stream);
work->fence = std::move(cuda_ev);
work->ep = &ep;
if (torch_ucc_config.use_future) {
// Associate the Future's CUDA events with the internal stream.
c10::cuda::CUDAMultiStreamGuard streamGuard(*stream);
std::vector<c10::Device> devList{dev};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devList);
// Add a callback that runs profiling end callbacks
if (work->recordFunctionEndCallback_) {
work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
work->recordFunctionEndCallback_();
});
}
// Completed immediately: the outputs are "ready" in the CUDA sense —
// consumers are expected to synchronize via the Future's stream/events
// rather than via host-side completion (standard c10d CUDA pattern).
work->future_->markCompleted(c10::IValue(outputTensors));
}
return work;
}
#endif // #ifdef USE_CUDA
default: {
// Unsupported device: log with the collective-post context, then throw
// using UCC's status string for a consistent error message.
TORCH_UCC_LOG_ERROR(
TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}
}
}