// torchaudio/csrc/rnnt/cpu/cpu_kernels.h
#pragma once

#include <torchaudio/csrc/rnnt/cpu/math.h>
#include <torchaudio/csrc/rnnt/options.h>
#include <torchaudio/csrc/rnnt/types.h>

#include <algorithm> // std::max
#include <cmath> // std::exp, std::log
#include <cstring> // std::memset
#include <limits>
#include <vector>

namespace torchaudio {
namespace rnnt {
namespace cpu {
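
// LogProbs: the two transition log-probabilities out of lattice node (t, u).
// skip() is the log-probability of emitting blank (advancing t); emit() is
// the log-probability of emitting the next target label (advancing u).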
template <typename DTYPE>
struct LogProbs {
  DTYPE skip_; // blank.
  DTYPE emit_; // target.

  LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {}

  DTYPE& skip() {
    return skip_;
  }
  DTYPE& emit() {
    return emit_;
  }
  const DTYPE& skip() const {
    return skip_;
  }
  const DTYPE& emit() const {
    return emit_;
  }
};

// TensorView: view a block of allocated memory as a tensor.
template <typename DTYPE>
class TensorView {
 public:
  TensorView(const std::vector<int>& dims, DTYPE* data)
      : dims_(dims), data_(data) {
    strides_.resize(dims.size());
    strides_.back() = 1;
    // Cast before subtracting: dims.size() is unsigned, so size() - 2 would
    // wrap around for a 1-D view.
    for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
      strides_[i] = strides_[i + 1] * dims[i + 1];
    }
  }

  DTYPE& operator()(const std::vector<int>& indices) {
    CHECK_EQ(indices.size(), dims_.size());
    int index = indices.back();
    for (int i = static_cast<int>(indices.size()) - 2; i >= 0; --i) {
      index += indices[i] * strides_[i];
    }
    return data_[index];
  }

  void SetZero() {
    int size = dims_[0] * strides_[0];
    std::memset(data_, 0, sizeof(DTYPE) * size);
  }

 private:
  std::vector<int> dims_;
  std::vector<int> strides_;
  DTYPE* data_;
};
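
// Illustrative use of TensorView (a non-owning view; strides assume a
// contiguous row-major buffer):
//
//   std::vector<float> buf(2 * 3);
//   TensorView<float> view({2, 3}, buf.data());
//   view({1, 2}) = 5.0f; // writes buf[1 * 3 + 2].

// LogSumExp2D: for each of the N rows of the (N, D) logits matrix, compute
// log(sum_d exp(logits[n][d])) using the max-subtraction trick for numerical
// stability. The outputs are the log-softmax denominators consumed below.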
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) {
  for (int i = 0; i < N * D; i += D) {
    CAST_DTYPE max = logits[i];
    for (int j = 1; j < D; ++j) {
      max = std::max(max, CAST_DTYPE(logits[i + j]));
    }
    CAST_DTYPE sum = 0;
    for (int j = 0; j < D; ++j) {
      sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max);
    }
    outputs[i / D] = max + std::log(sum);
  }
  return SUCCESS;
}
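
// ComputeLogProbsOneSequence: fill the (T, U) lattice of one sequence with
// normalized log-probabilities:
//   emit(t, u) = logits(t, u, targets[u]) - denom(t, u)
//   skip(t, u) = logits(t, u, blank) - denom(t, u)
// Subtracting the log-sum-exp denominator is equivalent to taking
// log_softmax over the D targets.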
template <typename DTYPE, typename CAST_DTYPE>
void ComputeLogProbsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<LogProbs<CAST_DTYPE>>& logProbs) {
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& blank = options.blank_;

  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      if (u < U - 1) {
        logProbs({t, u}).emit() =
            CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u});
      }
      logProbs({t, u}).skip() =
          CAST_DTYPE(logits({t, u, blank})) - denom({t, u});
    }
  }
}
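
// ComputeLogProbs: batch wrapper. Carves the flat logits / denominators /
// logProbs buffers into per-sequence views, then processes each sequence
// independently (parallelizable over the batch).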
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeLogProbs(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(TensorView<const DTYPE>(
        {maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs) + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // one sequence per thread; uses max B threads.
    ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*logProbs=*/seqlogProbs[b]);
  }

  return SUCCESS;
}
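
// ComputeAlphaOneSequence: forward (alpha) recursion over the RNN-T lattice.
// alpha(t, u) is the log-probability of reaching node (t, u) from (0, 0):
//   alpha(t, u) = lse(alpha(t - 1, u) + skip(t - 1, u),
//                     alpha(t, u - 1) + emit(t, u - 1))
// The returned forward score, alpha(T - 1, U - 1) + skip(T - 1, U - 1),
// is the log-likelihood of the target sequence.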
template <typename DTYPE>
DTYPE ComputeAlphaOneSequence(
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha) {
  const int& T = srcLen;
  const int& U = tgtLen;

  alpha({0, 0}) = DTYPE(0);

  for (int t = 1; t < T; ++t) { // u == 0.
    alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip();
  }

  for (int u = 1; u < U; ++u) { // t == 0.
    alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit();
  }

  for (int t = 1; t < T; ++t) {
    for (int u = 1; u < U; ++u) {
      alpha({t, u}) = math::lse(
          alpha({t - 1, u}) + logProbs({t - 1, u}).skip(),
          alpha({t, u - 1}) + logProbs({t, u - 1}).emit());
    }
  }

  DTYPE forward_score = alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip();
  return forward_score;
}
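
// ComputeBetaOneSequence: backward (beta) recursion over the RNN-T lattice.
// beta(t, u) is the log-probability of reaching the terminal node from (t, u):
//   beta(t, u) = lse(beta(t + 1, u) + skip(t, u),
//                    beta(t, u + 1) + emit(t, u))
// The returned backward score beta(0, 0) matches the forward score up to
// floating-point error.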
template <typename DTYPE>
DTYPE ComputeBetaOneSequence(
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& beta) {
  const int& T = srcLen;
  const int& U = tgtLen;

  beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip();

  for (int t = T - 2; t >= 0; --t) { // u == U - 1.
    beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip();
  }

  for (int u = U - 2; u >= 0; --u) { // t == T - 1.
    beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit();
  }

  for (int t = T - 2; t >= 0; --t) {
    for (int u = U - 2; u >= 0; --u) {
      beta({t, u}) = math::lse(
          beta({t + 1, u}) + logProbs({t, u}).skip(),
          beta({t, u + 1}) + logProbs({t, u}).emit());
    }
  }

  DTYPE backward_score = beta({0, 0});
  return backward_score;
}
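
// Dispatch helper for ComputeAlphasBetas: odd thread indices run the alpha
// recursion, even thread indices run the beta recursion, so each sequence's
// alphas and betas can be computed on two threads in parallel.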
template <typename DTYPE>
DTYPE ComputeAlphaOrBetaOneSequence(
    int thread,
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha,
    TensorView<DTYPE>& beta) {
  if (thread & 1) {
    return ComputeAlphaOneSequence<DTYPE>(
        /*options=*/options,
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*alpha=*/alpha);
  } else {
    return ComputeBetaOneSequence<DTYPE>(
        /*options=*/options,
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*beta=*/beta);
  }
}
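
// ComputeAlphasBetas: compute alphas and betas for the whole batch (up to
// two threads per sequence) and write per-sequence costs, i.e. the negative
// log-likelihood -beta(0, 0).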
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphasBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas,
    CAST_DTYPE* betas,
    DTYPE* costs) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);

  //#pragma omp parallel for
  for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
    int i = (t >> 1);
    scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
        /*thread=*/t,
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i],
        /*beta=*/seq_betas[i]);
  }

  for (int b = 0; b < B; ++b) {
    // Even slots hold each sequence's beta(0, 0), i.e. its log-likelihood.
    costs[b] = -scores[b << 1];
  }
}
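
// ComputeGradientsOneSequence: gradient of the loss -log P(y | x) with
// respect to the raw logits of one sequence. For each node, with
//   g = logits(t, u, d) - denom(t, u) + alpha(t, u) - beta(0, 0),
// the gradient is exp(g + beta(t, u)) (the softmax mass of label d at the
// node, weighted by the node's posterior) minus the probability of actually
// taking the corresponding blank / target transition, when d matches one.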
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradientsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<const CAST_DTYPE>& alpha,
    TensorView<const CAST_DTYPE>& beta,
    TensorView<DTYPE>& gradients) {
  // Don't zero out the gradients here, as gradients might reuse the memory
  // of logits.
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  CAST_DTYPE cost = -beta({0, 0});

  // Note: the gradient below differs from numpy_transducer, since we compute
  // log_softmax more efficiently within the loss to save memory. The details
  // of the implementation / equations can be found in Sec. 3.2 (function
  // merging) of:
  // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
      for (int d = 0; d < D; ++d) {
        CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
          gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
        } else if (d == blank && t < T - 1) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
        } else if (u < U - 1 && d == targets[u]) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
        } else {
          gradients({t, u, d}) = std::exp(g + beta({t, u}));
        }
        if (clamp > 0) {
          gradients({t, u, d}) =
              math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
          gradients({t, u, d}) =
              math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
        }
      }
    }
  }

  // Zero out the rest of the gradients, which is necessary when reusing the
  // logits memory; check the memory location to see whether it's needed.
  if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
    const int& maxT = options.maxSrcLen_;
    const int& maxU = options.maxTgtLen_;
    for (int t = T; t < maxT; ++t) {
      for (int u = 0; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
    for (int t = 0; t < T; ++t) {
      for (int u = U; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
  }
}
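
// ComputeGradients: batch wrapper over ComputeGradientsOneSequence, one
// sequence per (potential) thread.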
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradients(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<const CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<const CAST_DTYPE>> seq_betas;
  std::vector<TensorView<DTYPE>> seq_gradients;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(TensorView<const DTYPE>(
        {maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<const CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<const CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
    seq_gradients.push_back(
        TensorView<DTYPE>({maxT, maxU, D}, gradients + b * maxT * maxU * D));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // one sequence per thread; uses max B threads.
    ComputeGradientsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*alpha=*/seq_alphas[b],
        /*beta=*/seq_betas[b],
        /*gradients=*/seq_gradients[b]);
  }
}
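
// ComputeAlphas: standalone batched alpha recursion (alphas only, no betas
// or costs).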
template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // one sequence per thread; uses max B threads.
    // Instantiate with CAST_DTYPE, not DTYPE: the views passed in hold
    // CAST_DTYPE values.
    ComputeAlphaOneSequence<CAST_DTYPE>(
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i]);
  }
}
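
// ComputeBetas: standalone batched beta recursion, also emitting the
// per-sequence costs.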
template <typename DTYPE, typename CAST_DTYPE>
void ComputeBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* costs,
    CAST_DTYPE* betas) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // one sequence per thread; uses max B threads.
    // Instantiate with CAST_DTYPE, not DTYPE: the views passed in hold
    // CAST_DTYPE values. Record the cost as the negative backward score,
    // consistent with ComputeAlphasBetas; previously `costs` was left unused.
    costs[i] = -ComputeBetaOneSequence<CAST_DTYPE>(
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*beta=*/seq_betas[i]);
  }
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio