torchaudio/csrc/rnnt/cpu/kernel_utils.h (51 lines of code) (raw):

#pragma once #include <cassert> #include <torchaudio/csrc/rnnt/cpu/math.h> namespace torchaudio { namespace rnnt { inline HOST_AND_DEVICE bool in_range( int start, int end, // inclusive int val) { return start <= val && val <= end; } #define LOG_PROBS_SKIP_IDX 0 #define LOG_PROBS_EMIT_IDX 1 struct Indexer2D { const int& size2_; FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {} FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) { return index1 * size2_ + index2; } }; struct Indexer3D { const int& size2_; const int& size3_; FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3) : size2_(size2), size3_(size3) {} FORCE_INLINE HOST_AND_DEVICE int operator()( int index1, int index2, int index3) { return (index1 * size2_ + index2) * size3_ + index3; } }; struct Indexer4D { const int& size2_; const int& size3_; const int& size4_; HOST_AND_DEVICE Indexer4D( const int& size2, const int& size3, const int& size4) : size2_(size2), size3_(size3), size4_(size4) {} HOST_AND_DEVICE int operator()( int index1, int index2, int index3, int index4) { return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4; } }; } // namespace rnnt } // namespace torchaudio