in src/torch_ucc.cpp [570:614]
/**
 * Body of the progress thread.
 *
 * Repeatedly takes the work item at the front of `progress_queue`, calls
 * `ucc_comm.progress()` and `ucx_comm.progress()` for as long as the item's
 * request status is positive, and then finalizes the item. The item is
 * finalized with an exception pointer if the request finished with a negative
 * status or if progressing threw; otherwise the pointer is empty.
 *
 * `mutex` is held whenever `progress_queue`, `stop_progress_loop` and
 * `collective_inprogress` are accessed, and is released while a work item is
 * being progressed and finalized. The function returns once
 * `stop_progress_loop` is observed to be set.
 */
void CommPG::progress_loop() {
  std::unique_lock<std::mutex> lock(mutex);
#ifdef USE_CUDA
  bool device_set = false;
#endif
  while (!stop_progress_loop) {
    if (progress_queue.empty()) {
      // Sleep until a producer signals; the loop condition and the emptiness
      // check are re-evaluated after every wakeup, spurious or not.
      queue_produce_cv.wait(lock);
      continue;
    }
    collective_inprogress = true;
    auto work = progress_queue.front();
    progress_queue.pop_front();
    // Progressing may take a long time; do not block producers meanwhile.
    lock.unlock();
#ifdef USE_CUDA
    // Select the CUDA device for this thread once, the first time a device
    // index has been configured.
    if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
      c10::cuda::set_device(cuda_device_index);
      device_set = true;
    }
#endif
    std::exception_ptr eptr;
    try {
      // Drive both communication layers until the request leaves the
      // "positive status" state.
      while (work->request_->status > 0) {
        ucc_comm.progress();
        ucx_comm.progress();
      }
      if (work->request_->status < 0) {
        eptr = std::make_exception_ptr(
            std::runtime_error(ucc_status_string(work->request_->status)));
        // TODO: report exact op type or id?
        // c10::str concatenates its arguments verbatim, so the separator
        // has to be part of the prefix.
        std::string err_log = c10::str(
            "Failed to progress communication: ",
            ucc_status_string(work->request_->status));
        TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
      }
    } catch (...) {
      // Hand any failure to the work item instead of letting it escape the
      // thread entry point.
      eptr = std::current_exception();
    }
    work->finalize(eptr);
    // Release this thread's handle on the work item while `mutex` is still
    // unlocked.
    work = nullptr;
    // The flag is part of the state that waiters on `queue_consume_cv` are
    // expected to test under `mutex`. Clearing it and notifying without the
    // lock could slip in between a waiter's predicate check and its wait,
    // losing the wakeup. Re-acquire the lock first; it stays held for the
    // next loop iteration.
    lock.lock();
    collective_inprogress = false;
    queue_consume_cv.notify_one();
  }
}