torchaudio/csrc/rnnt/workspace.h (184 lines of code) (raw):
#pragma once
#include <cstring>
#include <vector>
#include <torchaudio/csrc/rnnt/options.h>
namespace torchaudio {
namespace rnnt {
// Since CUDA has strict memory alignment, it's better to keep allocated memory
// blocks separate for different data types.
// DtypeWorkspace holds a "view" of workspace for:
// 1. softmax denominators (in log form), size = B * max_T * max_U
// 2. log probability pairs for blank and target, size = B * max_T * max_U * 2
// 3. alphas, size = B * max_T * max_U
// 4. betas, size = B * max_T * max_U
// View over a caller-provided buffer of DTYPE elements, split into four
// consecutive regions (BTU below is the value returned by Options::BTU()):
//
//   [ denominators: BTU ][ log-probs: 2 * BTU ][ alphas: BTU ][ betas: BTU ]
//
// The class only stores the pointer it is given; nothing defined in this class
// allocates or releases the buffer.
//
// NOTE(review): the CHECK_* macros are not defined in this file; what happens
// when a check fails depends on their definition.
// NOTE(review): all sizes are plain `int`, so 5 * BTU must fit in an int.
template <typename DTYPE>
class DtypeWorkspace {
 public:
  // Builds an empty view: value-initialized options, null buffer, size 0.
  DtypeWorkspace() {}

  // Builds a view over `data`, which holds `size` elements. The arguments are
  // validated and stored by Reset().
  DtypeWorkspace(const Options& options, DTYPE* data, int size) {
    Reset(options, data, size);
  }

  ~DtypeWorkspace() {}

  // Returns the number of DTYPE elements required to hold all four regions
  // for `options`. Checks that options.device_ is not UNDEFINED.
  static int ComputeSizeFromOptions(const Options& options) {
    CHECK_NE(options.device_, UNDEFINED);
    int total = ComputeSizeForDenominators(options);
    total += ComputeSizeForLogProbs(options);
    total += ComputeSizeForAlphas(options);
    total += ComputeSizeForBetas(options);
    return total;
  }

  // Declared only; no definition is visible in this file.
  void Free();

  // Points this view at `data` (`size` elements) and remembers `options`.
  // Checks that `size` is at least ComputeSizeFromOptions(options) before any
  // member is modified.
  void Reset(const Options& options, DTYPE* data, int size) {
    const int required = ComputeSizeFromOptions(options);
    CHECK_LE(required, size);
    options_ = options;
    data_ = data;
    size_ = size;
  }

  // Number of elements in the buffer, as passed to Reset() (0 if never set).
  int Size() const {
    return size_;
  }

  // First region: starts at the beginning of the buffer.
  DTYPE* GetPointerToDenominators() const {
    return data_;
  }

  // Second region: starts right after the denominators.
  DTYPE* GetPointerToLogProbs() const {
    DTYPE* const denominators = GetPointerToDenominators();
    return denominators + ComputeSizeForDenominators(options_);
  }

  // Third region: starts right after the log-probs.
  DTYPE* GetPointerToAlphas() const {
    DTYPE* const log_probs = GetPointerToLogProbs();
    return log_probs + ComputeSizeForLogProbs(options_);
  }

  // Fourth region: starts right after the alphas.
  DTYPE* GetPointerToBetas() const {
    DTYPE* const alphas = GetPointerToAlphas();
    return alphas + ComputeSizeForAlphas(options_);
  }

 private:
  // Element count of the denominators region: B * T * U.
  static int ComputeSizeForDenominators(const Options& options) {
    return options.BTU();
  }

  // Element count of the log-probs region: B * T * U * 2.
  static int ComputeSizeForLogProbs(const Options& options) {
    return 2 * options.BTU();
  }

  // Element count of the alphas region: B * T * U.
  static int ComputeSizeForAlphas(const Options& options) {
    return options.BTU();
  }

  // Element count of the betas region: B * T * U.
  static int ComputeSizeForBetas(const Options& options) {
    return options.BTU();
  }

  Options options_{};     // options the region sizes are derived from.
  int size_ = 0;          // number of elements in allocated memory.
  DTYPE* data_ = nullptr; // pointer to the allocated memory.
};
// IntWorkspace holds a "view" of workspace for:
// 1. alpha counters, size = B * max_U
// 2. beta counters, size = B * max_U
// View over a caller-provided buffer of ints, split into two consecutive
// regions (BU below is the value returned by Options::BU()):
//
//   [ alpha counters: BU ][ beta counters: BU ]
//
// Both regions have size 0 unless the build defines USE_CUDA and the options
// select the GPU device. The class only stores the pointer it is given;
// nothing defined in this class allocates or releases the buffer.
//
// NOTE(review): the CHECK_* macros are not defined in this file; what happens
// when a check fails depends on their definition.
class IntWorkspace {
 public:
  // Builds an empty view: value-initialized options, null buffer, size 0.
  IntWorkspace() {}

  // Builds a view over `data`, which holds `size` ints. The arguments are
  // validated and stored by Reset(), which also clears the counters.
  IntWorkspace(const Options& options, int* data, int size) {
    Reset(options, data, size);
  }

  ~IntWorkspace() {}

  // Returns the number of ints required to hold both counter regions.
  static int ComputeSizeFromOptions(const Options& options) {
    const int alpha_count = ComputeSizeForAlphaCounters(options);
    const int beta_count = ComputeSizeForBetaCounters(options);
    return alpha_count + beta_count;
  }

  // Points this view at `data` (`size` ints), remembers `options`, and then
  // zeroes the counters. Checks that `size` is at least
  // ComputeSizeFromOptions(options) before any member is modified.
  void Reset(const Options& options, int* data, int size) {
    const int required = ComputeSizeFromOptions(options);
    CHECK_LE(required, size);
    options_ = options;
    data_ = data;
    size_ = size;
    ResetAlphaBetaCounters();
  }

  // Number of ints in the buffer, as passed to Reset() (0 if never set).
  int Size() const {
    return size_;
  }

  // First region: starts at the beginning of the buffer.
  // Checks that the stored options select the GPU device.
  int* GetPointerToAlphaCounters() const {
    CHECK_EQ(options_.device_, GPU);
    return data_;
  }

  // Second region: starts right after the alpha counters.
  // Checks that the stored options select the GPU device.
  int* GetPointerToBetaCounters() const {
    CHECK_EQ(options_.device_, GPU);
    int* const alpha_counters = GetPointerToAlphaCounters();
    return alpha_counters + ComputeSizeForAlphaCounters(options_);
  }

 private:
  // Zero-fills both counter regions with cudaMemset. Does nothing when the
  // build lacks USE_CUDA, when no buffer is attached, or when the device is
  // not GPU. The values returned by cudaMemset are not inspected.
  void ResetAlphaBetaCounters() {
#ifdef USE_CUDA
    if (data_ == nullptr || options_.device_ != GPU) {
      return;
    }
    const auto alpha_bytes =
        ComputeSizeForAlphaCounters(options_) * sizeof(int);
    cudaMemset(GetPointerToAlphaCounters(), 0, alpha_bytes);

    const auto beta_bytes = ComputeSizeForBetaCounters(options_) * sizeof(int);
    cudaMemset(GetPointerToBetaCounters(), 0, beta_bytes);
#endif // USE_CUDA
  }

  // Element count of the alpha-counter region: B * U on GPU, otherwise 0.
  static int ComputeSizeForAlphaCounters(const Options& options) {
#ifdef USE_CUDA
    return (options.device_ == GPU) ? options.BU() : 0;
#else
    return 0;
#endif // USE_CUDA
  }

  // Element count of the beta-counter region: B * U on GPU, otherwise 0.
  static int ComputeSizeForBetaCounters(const Options& options) {
#ifdef USE_CUDA
    return (options.device_ == GPU) ? options.BU() : 0;
#else
    return 0;
#endif // USE_CUDA
  }

  Options options_{};   // options the region sizes are derived from.
  int size_ = 0;        // number of elements in allocated memory.
  int* data_ = nullptr; // pointer to the allocated memory.
};
// Workspace<DTYPE> holds:
// 1. DtypeWorkspace<DTYPE>
// 2. IntWorkspace
// Bundles a DtypeWorkspace<DTYPE> and an IntWorkspace that share one Options
// value, and forwards the per-region pointer getters to whichever of the two
// owns that region.
template <typename DTYPE>
class Workspace {
 public:
  // Builds an empty workspace: value-initialized options and two
  // default-constructed sub-workspaces.
  Workspace() {}

  // Builds a workspace over the two given buffers; see Reset().
  Workspace(
      const Options& options,
      DTYPE* dtype_data,
      int dtype_size,
      int* int_data,
      int int_size) {
    Reset(options, dtype_data, dtype_size, int_data, int_size);
  }

  ~Workspace() {}

  // Stores a copy of `options`, then resets the DTYPE sub-workspace with
  // (`dtype_data`, `dtype_size`) followed by the int sub-workspace with
  // (`int_data`, `int_size`). Both sub-workspaces receive the stored copy.
  void Reset(
      const Options& options,
      DTYPE* dtype_data,
      int dtype_size,
      int* int_data,
      int int_size) {
    options_ = options;
    dtype_workspace_.Reset(options_, dtype_data, dtype_size);
    int_workspace_.Reset(options_, int_data, int_size);
  }

  // The options stored by the most recent Reset().
  const Options& GetOptions() const {
    return options_;
  }

  // --- Regions that live in the DTYPE buffer. ---

  DTYPE* GetPointerToDenominators() const {
    return dtype_workspace_.GetPointerToDenominators();
  }

  DTYPE* GetPointerToLogProbs() const {
    return dtype_workspace_.GetPointerToLogProbs();
  }

  DTYPE* GetPointerToAlphas() const {
    return dtype_workspace_.GetPointerToAlphas();
  }

  DTYPE* GetPointerToBetas() const {
    return dtype_workspace_.GetPointerToBetas();
  }

  // --- Regions that live in the int buffer. ---

  int* GetPointerToAlphaCounters() const {
    return int_workspace_.GetPointerToAlphaCounters();
  }

  int* GetPointerToBetaCounters() const {
    return int_workspace_.GetPointerToBetaCounters();
  }

 private:
  Options options_{};                     // copy handed to both sub-workspaces.
  DtypeWorkspace<DTYPE> dtype_workspace_; // view over the DTYPE buffer.
  IntWorkspace int_workspace_;            // view over the int buffer.
};
} // namespace rnnt
} // namespace torchaudio