horovod/torch/ready_event.cc (88 lines of code) (raw):

// Copyright 2018 Uber Technologies, Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #if HAVE_GPU #if TORCH_VERSION >= 1005000000 #include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAException.h> #else #include <THC/THC.h> #endif #include <cassert> #include <mutex> #include <queue> #include <unordered_map> #endif #include "ready_event.h" #include "cuda_util.h" #if TORCH_VERSION < 1005000000 #if HAVE_GPU extern THCState* state; #endif #endif namespace horovod { namespace torch { #if HAVE_GPU struct ReadyEventRegistry { std::unordered_map<int, std::queue<cudaEvent_t>> cuda_events; std::mutex mutex; }; static ReadyEventRegistry ready_event_registry; TorchReadyEvent::TorchReadyEvent(int device) : device_(device) { assert(device_ != CPU_DEVICE_ID); with_device device_context(device_); { std::lock_guard<std::mutex> guard(ready_event_registry.mutex); auto& queue = ready_event_registry.cuda_events[device_]; if (!queue.empty()) { cuda_event_ = queue.front(); queue.pop(); } else { #if TORCH_VERSION >= 1005000000 C10_CUDA_CHECK(cudaEventCreateWithFlags( &cuda_event_, cudaEventBlockingSync | cudaEventDisableTiming)); #else THCudaCheck(cudaEventCreateWithFlags( &cuda_event_, cudaEventBlockingSync | cudaEventDisableTiming)); #endif } } #if TORCH_VERSION >= 1005000000 auto stream = c10::cuda::getCurrentCUDAStream(device_); C10_CUDA_CHECK(cudaEventRecord(cuda_event_, stream)); #else auto stream = THCState_getCurrentStreamOnDevice(state, device_); THCudaCheck(cudaEventRecord(cuda_event_, stream)); #endif } TorchReadyEvent::~TorchReadyEvent() { { std::lock_guard<std::mutex> guard(ready_event_registry.mutex); auto& queue = ready_event_registry.cuda_events[device_]; queue.push(cuda_event_); } } bool TorchReadyEvent::Ready() const { auto status = cudaEventQuery(cuda_event_); if (status == cudaErrorNotReady) { return false; } #if TORCH_VERSION >= 1005000000 C10_CUDA_CHECK(status); #else THCudaCheck(status); #endif return true; } #endif // On GPU this event will signal that GPU computations are done and data is // ready. std::shared_ptr<ReadyEvent> RecordReadyEvent(int device) { if (device == CPU_DEVICE_ID) { return std::shared_ptr<ReadyEvent>(); } else { #if HAVE_GPU return std::make_shared<TorchReadyEvent>(device); #else throw std::logic_error("Internal error. Requested ReadyEvent " "with GPU device but not compiled with CUDA."); #endif } } } // namespace torch } // namespace horovod