torchaudio/csrc/rnnt/gpu/math.cuh (36 lines of code) (raw):

#pragma once

#ifdef USE_CUDA

#include <cmath>

#endif // USE_CUDA

#include <torchaudio/csrc/rnnt/gpu/half.cuh>

namespace torchaudio {
namespace rnnt {

// Small scalar math helpers usable from both host and device code.
namespace math {

// Returns x when `x > y` evaluates to true, otherwise y.
// When the comparison is false (equal operands, or an unordered comparison
// for floating-point types) the second argument is returned.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  return (x > y) ? x : y;
}

// Returns y when `x > y` evaluates to true, otherwise x.
// When the comparison is false (equal operands, or an unordered comparison
// for floating-point types) the first argument is returned.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  return (x > y) ? y : x;
}

// log_sum_exp: lse(x, y) = log(exp(x) + exp(y)).
// Only the primary template is declared here; the float specialization
// follows below.
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

// float specialization of log_sum_exp.
//
// Evaluated as hi + log1p(exp(lo - hi)) with hi >= lo, so the argument of
// expf is never positive and cannot overflow.
//
// Special values:
//   - If both arguments are the same infinity, that infinity is returned.
//     (A naive `lo - hi` would be inf - inf = NaN in that case.)
//   - If exactly one argument is -inf, the other argument is returned.
//   - If either argument is NaN, the result is NaN.
template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  const bool y_is_larger = y > x;
  const float hi = y_is_larger ? y : x;
  const float lo = y_is_larger ? x : y;

  // Equal operands have a difference of exactly zero. Spelling that out
  // avoids computing inf - inf when both operands are the same infinity,
  // and leaves the result for finite equal operands unchanged.
  // NaN never compares equal, so NaN inputs still propagate through the
  // subtraction below.
  const float diff = (hi == lo) ? 0.0f : (lo - hi);

  return hi + log1pf(expf(diff));
}

} // namespace math
} // namespace rnnt
} // namespace torchaudio