torchaudio/csrc/rnnt/compute.cpp (23 lines of code) (raw):
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>
/// C++ entry point for the `torchaudio::rnnt_loss` operator.
///
/// This function contains no computation of its own: it looks up the
/// operator registered under the schema name "torchaudio::rnnt_loss" (empty
/// overload name) in the dispatcher and forwards every argument to it
/// unchanged. Which implementation actually runs is decided by the
/// dispatcher and is not visible in this file.
///
/// @param logits          Forwarded as-is. NOTE(review): taken by non-const
///                        reference; whether the selected implementation
///                        mutates it cannot be determined from here.
/// @param targets         Forwarded as-is.
/// @param logit_lengths   Forwarded as-is.
/// @param target_lengths  Forwarded as-is.
/// @param blank           Forwarded as-is (presumably the blank-label index
///                        -- confirm against the kernel implementation).
/// @param clamp           Forwarded as-is (presumably a gradient clamp
///                        value -- confirm against the kernel implementation).
/// @return Whatever the dispatched operator returns: a tensor plus an
///         optional second tensor.
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // The typed handle must carry exactly this function's signature, so derive
  // it from the function itself instead of spelling it out a second time.
  using Signature = decltype(rnnt_loss);

  // Function-local static: the schema lookup runs on the first call only and
  // the resulting handle is reused by every later call.
  static auto dispatch_handle =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("torchaudio::rnnt_loss", "")
          .typed<Signature>();

  return dispatch_handle.call(
      logits, targets, logit_lengths, target_lengths, blank, clamp);
}
// Declares the schema of `rnnt_loss` inside the `torchaudio` operator
// library. Only the schema is defined here; no implementation is attached in
// this block.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  // The adjacent literals below are concatenated by the compiler into a
  // single schema string; they are split one argument per line purely for
  // readability. The argument list mirrors the parameters of the C++
  // `rnnt_loss` wrapper, and the return type is a (Tensor, optional Tensor)
  // pair.
  m.def(
      "rnnt_loss("
      "Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp"
      ") -> (Tensor, Tensor?)");
}