torchaudio/csrc/rnnt/gpu/kernels.h (92 lines of code) (raw):
#pragma once
#include <cassert>
#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>
namespace torchaudio {
namespace rnnt {
// Computes the gradient of the RNN-T loss with respect to the logits of a
// single lattice cell (bTgt, t, u) and writes the `numTargets` resulting
// values into `gradients`.
//
// Parameters:
//   bTgt         - index of the target sequence; used for `tgtLengths`, the
//                  3D indexer and the 2D target indexer.
//   t, u         - time step and target position of the cell to process.
//   maxSrcLen    - first extent handed to Indexer3D (aliased as maxT).
//   maxTgtLen    - second extent handed to Indexer3D (aliased as maxU);
//                  Indexer2D is built with maxTgtLen - 1.
//   numTargets   - number of entries per cell (aliased as D).
//   blank        - index of the blank symbol along the D axis.
//   clamp        - when > 0, every gradient written for a valid cell is
//                  limited to the closed range [-clamp, clamp].
//   logits       - input values, read at idx(bTgt, t, u) * D + d.
//   targets      - target symbols, read through Indexer2D(bTgt, u).
//   srcLengths   - valid length T, read at index bTgt / H.
//   tgtLengths   - target length; U = tgtLengths[bTgt] + 1.
//   denominators, alphas, betas
//                - per-cell quantities read at idx(bTgt, t, u) (and at the
//                  (t + 1, u) / (t, u + 1) neighbours for betas).
//   gradients    - output buffer; may be the same pointer as `logits`.
//   H            - divisor mapping bTgt to the source index (bTgt / H);
//                  presumably the number of targets per source sequence --
//                  TODO confirm against callers.
//
// NOTE(review): Indexer2D / Indexer3D are declared elsewhere; this code only
// relies on them returning a flat index, with -1 treated as "no such cell".
// NOTE(review): all flat offsets are computed in `int`; confirm callers keep
// B * maxT * maxU * D within int range.
template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
    int bTgt,
    int t,
    int u,
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  if (t >= T || u >= U) { // out of boundary.
    if (gradients == logits && t < maxT && u < maxU) {
      // gradients and logits are pointing to the same memory location, so
      // the cell still holds logit values and must be zeroed explicitly.
      Indexer3D idxr3(maxT, maxU);
      int idx_b_t_u_zero = idxr3(bTgt, t, u);
      if (idx_b_t_u_zero != -1) {
        int start = idx_b_t_u_zero * D;
        for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
          gradients[b_t_u_d] = 0;
        }
      }
    }
    return;
  }

  // The cost of this sequence is the negated beta stored at the first cell
  // of the sequence's (maxT x maxU) block.
  int costIdx = bTgt * maxT * maxU;
  CAST_DTYPE cost = -(betas[costIdx]);

  Indexer2D idxr2(maxU - 1);

  int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
  Indexer3D idxr3(maxT, maxU);
  idx_b_t_u = idxr3(bTgt, t, u);
  idx_b_t_up1 = idxr3(bTgt, t, u + 1);
  idx_b_tp1_u = idxr3(bTgt, t + 1, u);

  if (idx_b_t_u == -1) {
    return;
  }

  // A non-finite cost would propagate inf/nan into every gradient; emit
  // zeros for the whole cell instead.
  if (isinf(cost) || isnan(cost)) {
    for (int d = 0; d < D; ++d) {
      int b_t_u_d = idx_b_t_u * D + d;
      gradients[b_t_u_d] = 0;
    }
    return;
  }

  // Term shared by every entry of this cell.
  CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
  for (int d = 0; d < D; ++d) {
    int b_t_u_d = idx_b_t_u * D + d;
    // The logit is read before the store below, which keeps the in-place
    // (gradients == logits) case correct.
    CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

    if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
    } else if (t < T - 1 && d == blank) {
      // Blank transition to (t + 1, u).
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_tp1_u != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
      }
    } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
      // Transition emitting the next target symbol, to (t, u + 1).
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_t_up1 != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
      }
    } else {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
    }

    if (clamp > 0) {
      // Apply both bounds to the same value and store once. Writing the
      // lower-bound result computed from the unclamped value would discard
      // the upper bound.
      CAST_DTYPE clamped = CAST_DTYPE(gradients[b_t_u_d]);
      clamped = math::min(clamped, clamp);
      clamped = math::max(clamped, -clamp);
      gradients[b_t_u_d] = clamped;
    }
  }
}
} // namespace rnnt
} // namespace torchaudio