torchaudio/csrc/rnnt/cpu/cpu_transducer.h (149 lines of code) (raw):

#pragma once #include <torchaudio/csrc/rnnt/cpu/cpu_kernels.h> #include <torchaudio/csrc/rnnt/workspace.h> namespace torchaudio { namespace rnnt { namespace cpu { // Inputs: // workspace: workspace. // logits: pointer to (B, maxT, maxU, D) logits. // targets: pointer to (B, maxU - 1) targets in the batch. // srcLengths: pointer to (B, ) source lengths in the batch. // tgtLengths: pointer to (B, ) target lengths in the batch. // // Outputs: // costs: pointer to (B, ) costs in the batch. // gradients: pointer to (B, maxT, maxU, D) gradients in the batch. template <typename DTYPE, typename CAST_DTYPE> status_t Compute( const Workspace<CAST_DTYPE>& workspace, const DTYPE* logits, const int* targets, const int* srcLengths, const int* tgtLengths, DTYPE* costs, DTYPE* gradients = nullptr) { const Options& options = workspace.GetOptions(); CHECK_EQ(options.device_, CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; const int& maxU = options.maxTgtLen_; const int& D = options.numTargets_; { // compute denominators. LogSumExp2D<DTYPE, CAST_DTYPE>( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, /*denominators=*/workspace.GetPointerToDenominators()); } { // compute log prob pairs. ComputeLogProbs<DTYPE, CAST_DTYPE>( /*options=*/options, /*logits=*/logits, /*targets=*/targets, /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*denominators=*/workspace.GetPointerToDenominators(), /*log_probs=*/workspace.GetPointerToLogProbs()); } { // compute alphas and betas. ComputeAlphasBetas<DTYPE, CAST_DTYPE>( /*options=*/options, /*log_probs=*/workspace.GetPointerToLogProbs(), /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*alphas=*/workspace.GetPointerToAlphas(), /*betas=*/workspace.GetPointerToBetas(), /*costs=*/costs); } if (gradients != nullptr) { ComputeGradients<DTYPE, CAST_DTYPE>( /*options=*/options, /*logits=*/logits, /*targets=*/targets, /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*denominators=*/workspace.GetPointerToDenominators(), /*alphas=*/workspace.GetPointerToAlphas(), /*betas=*/workspace.GetPointerToBetas(), /*gradients=*/gradients); } return SUCCESS; } template <typename DTYPE, typename CAST_DTYPE> status_t ComputeAlphas( const Workspace<CAST_DTYPE>& workspace, const DTYPE* logits, const int* targets, const int* srcLengths, const int* tgtLengths, DTYPE* alphas) { const Options& options = workspace.GetOptions(); CHECK_EQ(options.device_, CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; const int& maxU = options.maxTgtLen_; const int& D = options.numTargets_; { // compute denominators. LogSumExp2D<DTYPE, CAST_DTYPE>( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, /*denominators=*/workspace.GetPointerToDenominators()); } { // compute log prob pairs. ComputeLogProbs<DTYPE, CAST_DTYPE>( /*options=*/options, /*logits=*/logits, /*targets=*/targets, /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*denominators=*/workspace.GetPointerToDenominators(), /*log_probs=*/workspace.GetPointerToLogProbs()); } { // compute alphas. ComputeAlphas<DTYPE, CAST_DTYPE>( /*options=*/options, /*log_probs=*/workspace.GetPointerToLogProbs(), /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*alphas=*/alphas); } return SUCCESS; } template <typename DTYPE, typename CAST_DTYPE> status_t ComputeBetas( const Workspace<CAST_DTYPE>& workspace, const DTYPE* logits, const int* targets, const int* srcLengths, const int* tgtLengths, DTYPE* costs, DTYPE* betas) { const Options& options = workspace.GetOptions(); CHECK_EQ(options.device_, CPU); const int& B = options.batchSize_; const int& maxT = options.maxSrcLen_; const int& maxU = options.maxTgtLen_; const int& D = options.numTargets_; { // compute denominators. LogSumExp2D<DTYPE, CAST_DTYPE>( /*N=*/B * maxT * maxU, /*D=*/D, /*logits=*/logits, /*denominators=*/workspace.GetPointerToDenominators()); } { // compute log prob pairs. ComputeLogProbs<DTYPE, CAST_DTYPE>( /*options=*/options, /*logits=*/logits, /*targets=*/targets, /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*denominators=*/workspace.GetPointerToDenominators(), /*log_probs=*/workspace.GetPointerToLogProbs()); } { // compute betas. ComputeBetas<DTYPE, CAST_DTYPE>( /*options=*/options, /*log_probs=*/workspace.GetPointerToLogProbs(), /*srcLengths=*/srcLengths, /*tgtLengths=*/tgtLengths, /*costs=*/costs, /*betas=*/betas); } return SUCCESS; } } // namespace cpu } // namespace rnnt } // namespace torchaudio