torchaudio/csrc/rnnt/options.h (50 lines of code) (raw):
#pragma once

// <ostream> is needed for the std::ostream used by Options' operator<<.
#include <ostream>

#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif // USE_CUDA

#include <torchaudio/csrc/rnnt/macros.h>
#include <torchaudio/csrc/rnnt/types.h>
namespace torchaudio {
namespace rnnt {

// Configuration for the transducer (RNN-T) loss computation: target device,
// launch/threading settings, and the problem sizes B (batch), H (hypos per
// sample), max_T, max_U and D used by the size helpers below.
// All members are public; the only constructor sets the defaults listed in
// its initializer list, and callers assign the remaining fields directly.
struct Options {
  // the device to compute transducer loss.
  device_t device_;
#ifdef USE_CUDA
  // the stream to launch kernels in when using GPU.
  // Value-initialized (null, i.e. the CUDA default stream) so that a
  // default-constructed Options never carries an indeterminate handle;
  // the constructor below does not otherwise touch this member.
  cudaStream_t stream_{};
#endif
  // The maximum number of threads that can be used.
  int numThreads_;
  // the index for "blank".
  int blank_;
  // whether to backtrack the best path.
  bool backtrack_;
  // gradient clamp value.
  float clamp_;
  // batch size = B.
  int batchSize_;
  // Number of hypos per sample = H
  int nHypos_;
  // the maximum length of src encodings = max_T.
  int maxSrcLen_;
  // the maximum length of tgt encodings = max_U.
  int maxTgtLen_;
  // num_targets = D.
  int numTargets_;

  Options()
      : device_(UNDEFINED),
        numThreads_(0),
        blank_(-1),
        backtrack_(false),
        clamp_(-1), // negative for disabling clamping by default.
        batchSize_(0),
        nHypos_(1),
        maxSrcLen_(0),
        maxTgtLen_(0),
        numTargets_(0) {}

  // Returns B * max_U * H.
  // NOTE: computed in `int`; callers must ensure the product fits.
  int BU() const {
    return batchSize_ * maxTgtLen_ * nHypos_;
  }

  // Returns B * max_T * max_U * H.
  // NOTE: computed in `int`; callers must ensure the product fits.
  int BTU() const {
    return batchSize_ * maxSrcLen_ * maxTgtLen_ * nHypos_;
  }

  // Writes a human-readable summary of the sizing and loss parameters.
  // nHypos_, blank_ and clamp_ are included because they affect the buffer
  // sizes (via BU/BTU) and the computed loss respectively.
  friend std::ostream& operator<<(std::ostream& os, const Options& options) {
    os << "Options("
       << "batchSize_=" << options.batchSize_ << ", "
       << "maxSrcLen_=" << options.maxSrcLen_ << ", "
       << "maxTgtLen_=" << options.maxTgtLen_ << ", "
       << "numTargets_=" << options.numTargets_ << ", "
       << "nHypos_=" << options.nHypos_ << ", "
       << "blank_=" << options.blank_ << ", "
       << "clamp_=" << options.clamp_ << ")";
    return os;
  }
};

} // namespace rnnt
} // namespace torchaudio