void ComputeGradientsOneSequence()

in torchaudio/csrc/rnnt/cpu/cpu_kernels.h [296:370]


void ComputeGradientsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<const CAST_DTYPE>& alpha,
    TensorView<const CAST_DTYPE>& beta,
    TensorView<DTYPE>& gradients) {
  // Don't zero out the gradients here, because gradients may reuse the memory
  // that holds logits.

  const int& T = srcLen;
  const int& U = tgtLen;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  CAST_DTYPE cost = -beta({0, 0});

  // Note: the gradient below differs from numpy_transducer, because the
  // log_softmax is computed inside the loss (rather than as a separate op) to
  // save memory. The details of the implementation / equations below can be
  // found in Sec 3.2 (function merging) of:
  // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
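  //
  // Writing den(t, u) = logsumexp_d logits(t, u, d) and cost = -beta(0, 0)
  // (the per-sequence loss, since beta(0, 0) is the total log-probability of
  // the target sequence), the merged softmax + loss gradient computed below is
  //
  //   grad(t, u, d) = exp(alpha(t, u) + beta(t, u) + cost + logits(t, u, d) - den(t, u))
  //                 - exp(alpha(t, u) + beta(next)  + cost + logits(t, u, d) - den(t, u))
  //
  // where beta(next) is beta(t + 1, u) for a blank with t < T - 1,
  // beta(t, u + 1) for d == targets[u] with u < U - 1, and 0 for the final
  // blank at (T - 1, U - 1); for every other (t, u, d) the second term is
  // dropped. The first term is simply posterior(t, u) * softmax(d | t, u).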

  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
      for (int d = 0; d < D; ++d) {
        CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
          gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
        } else if (d == blank && t < T - 1) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
        } else if (u < U - 1 && d == targets[u]) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
        } else {
          gradients({t, u, d}) = std::exp(g + beta({t, u}));
        }

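        // Optionally clamp the gradient entry to the symmetric range
        // [-clamp, clamp].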
        if (clamp > 0) {
          gradients({t, u, d}) =
              math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
          gradients({t, u, d}) =
              math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
        }
      }
    }
  }

  // Zero out the padded region of the gradients (frames t >= T, and label
  // positions u >= U within the valid frames). This is only needed when
  // gradients reuses the logits memory, which would otherwise leave stale
  // logits values in the padded region, so compare the data pointers to see
  // whether that is the case.
  if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
    const int& maxT = options.maxSrcLen_;
    const int& maxU = options.maxTgtLen_;
    for (int t = T; t < maxT; ++t) {
      for (int u = 0; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
    for (int t = 0; t < T; ++t) {
      for (int u = U; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
  }
}
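
For orientation, below is a minimal, hypothetical caller sketch (not taken from the repository). It assumes DTYPE and CAST_DTYPE resolve to float, that TensorView can wrap a raw buffer via a (shape, data-pointer) constructor (verify against the TensorView definition earlier in cpu_kernels.h), and that denom, alpha, and beta have already been filled by the corresponding forward kernels. The pointer names (logits_ptr, targets_ptr, denom_ptr, alpha_ptr, beta_ptr, grad_ptr) are placeholders.

// Hypothetical single-sequence setup; shapes follow the indexing used above:
// logits / gradients are (T, U, D), denom / alpha / beta are (T, U).
Options options;
options.maxSrcLen_ = T;   // source length (no padding for a single sequence)
options.maxTgtLen_ = U;   // target-axis length of the lattice (labels + 1)
options.numTargets_ = D;  // vocabulary size, including blank
options.blank_ = 0;       // index of the blank label
options.clamp_ = 0.0;     // <= 0 disables gradient clamping

// Assumed TensorView(shape, data) constructor -- check cpu_kernels.h.
TensorView<const float> logits_view({T, U, D}, logits_ptr);
TensorView<const float> denom_view({T, U}, denom_ptr);
TensorView<const float> alpha_view({T, U}, alpha_ptr);
TensorView<const float> beta_view({T, U}, beta_ptr);
TensorView<float> grad_view({T, U, D}, grad_ptr);

ComputeGradientsOneSequence(
    options,
    logits_view,
    targets_ptr, // const int*, holding the U - 1 target labels
    /*srcLen=*/T,
    /*tgtLen=*/U,
    denom_view,
    alpha_view,
    beta_view,
    grad_view);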