in torchaudio/csrc/rnnt/cpu/cpu_kernels.h [296:370]
void ComputeGradientsOneSequence(
const Options& options,
TensorView<const DTYPE>& logits,
const int* targets,
int srcLen,
int tgtLen,
TensorView<const CAST_DTYPE>& denom,
TensorView<const CAST_DTYPE>& alpha,
TensorView<const CAST_DTYPE>& beta,
TensorView<DTYPE>& gradients) {
// don't set gradients to zero to here as gradients might reuse memory from
// logits
const int& T = srcLen;
const int& U = tgtLen;
const int& D = options.numTargets_;
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;
CAST_DTYPE cost = -beta({0, 0});
// Note - below gradient is different from numpy_transducer, since we
// compute log_softmax more efficiently within the loss, to save memory The
// details of the below implementation / equations can be found in Sec 3.2
// (function merging) in below paper:
// https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
for (int t = 0; t < T; ++t) {
for (int u = 0; u < U; ++u) {
CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
for (int d = 0; d < D; ++d) {
CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
} else if (d == blank && t < T - 1) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
} else if (u < U - 1 && d == targets[u]) {
gradients({t, u, d}) =
std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
} else {
gradients({t, u, d}) = std::exp(g + beta({t, u}));
}
if (clamp > 0) {
gradients({t, u, d}) =
math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
gradients({t, u, d}) =
math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
}
}
}
}
// zero out the rest of the gradients, necessary when reusing logits memory
// check the memory location to see if it's necessary
if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
const int& maxT = options.maxSrcLen_;
const int& maxU = options.maxTgtLen_;
for (int t = T; t < maxT; ++t) {
for (int u = 0; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
for (int t = 0; t < T; ++t) {
for (int u = U; u < maxU; ++u) {
for (int d = 0; d < D; ++d) {
gradients({t, u, d}) = 0.;
}
}
}
}
}