torchaudio/csrc/rnnt/autograd.cpp:

#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

// Autograd wrapper around the raw RNN-T loss kernel. The forward pass
// computes both the per-sequence costs and the gradients with respect to
// logits in one shot; backward simply rescales the saved gradients by the
// incoming grad_output.
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& logit_lengths,
      const torch::Tensor& target_lengths,
      int64_t blank,
      double clamp) {
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits, targets, logit_lengths, target_lengths, blank, clamp);
    auto costs = std::get<0>(result);
    // The kernel may not produce gradients; fall back to an undefined
    // tensor in that case.
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    // Broadcast the per-sequence grad_output over the (T, U, D)
    // dimensions of the saved logits gradients.
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    // One entry per forward input: only logits receives a gradient;
    // targets, logit_lengths, target_lengths, blank, and clamp do not.
    return {result, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // Redispatch below the autograd keys so the underlying kernel runs
  // without re-entering this wrapper.
  at::AutoDispatchBelowADInplaceOrView guard;
  auto results = RNNTLossFunction::apply(
      logits, targets, logit_lengths, target_lengths, blank, clamp);
  return std::make_tuple(results[0], results[1]);
}

// Register this wrapper as the Autograd-key implementation of the
// torchaudio::rnnt_loss operator.
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
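
The registered op is meant to be reached through the PyTorch dispatcher rather than by calling rnnt_loss_autograd directly. Below is a minimal C++ sketch of how it could be exercised, assuming the torchaudio::rnnt_loss schema is defined elsewhere in the library (e.g., next to the compute kernel) and the extension is linked in. The B/T/U/D sizes, the int32 dtype for targets and lengths, and the blank/clamp values are illustrative assumptions, not requirements taken from this file.

// usage_sketch.cpp -- hypothetical standalone driver, not part of torchaudio.
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

int main() {
  const int64_t B = 2, T = 10, U = 5, D = 20; // batch, time, targets, classes

  // Assumed layout: logits are (B, T, U + 1, D); targets and lengths int32.
  torch::Tensor logits =
      torch::randn({B, T, U + 1, D}, torch::requires_grad());
  torch::Tensor targets = torch::randint(1, D, {B, U}, torch::kInt32);
  torch::Tensor logit_lengths = torch::full({B}, T, torch::kInt32);
  torch::Tensor target_lengths = torch::full({B}, U, torch::kInt32);

  // Resolve the op through the dispatcher; the Autograd key routes the
  // call to rnnt_loss_autograd above. The typed signature must match the
  // schema assumed here: (Tensor, Tensor, Tensor, Tensor, int, float)
  // -> (Tensor, Tensor?).
  auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("torchaudio::rnnt_loss", "")
          .typed<std::tuple<torch::Tensor, c10::optional<torch::Tensor>>(
              const torch::Tensor&,
              const torch::Tensor&,
              const torch::Tensor&,
              const torch::Tensor&,
              int64_t,
              double)>();
  auto result = op.call(
      logits, targets, logit_lengths, target_lengths, /*blank=*/0,
      /*clamp=*/-1.0);
  torch::Tensor costs = std::get<0>(result);

  costs.sum().backward(); // flows through RNNTLossFunction::backward
}

From Python, the same registration is typically reached as torch.ops.torchaudio.rnnt_loss(...) once the torchaudio extension has been imported.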