torchaudio/csrc/rnnt/cpu/compute_alphas.cpp (58 lines of code) (raw):
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>
namespace torchaudio {
namespace rnnt {
namespace cpu {
/// Runs only the alpha (forward-variable) recursion of the RNN-T loss on CPU
/// and returns the resulting lattice. The trailing comment in the body notes
/// this exists mainly to make unit-testing easy.
///
/// @param logits          4-D float32 CPU tensor. Sizes 1, 2 and 3 are used as
///                        maxSrcLen, maxTgtLen and numTargets respectively.
/// @param targets         int32 CPU tensor of target labels.
/// @param logit_lengths   1-D int32 CPU tensor; its length is the batch size.
/// @param target_lengths  1-D int32 CPU tensor; its length must be a non-zero
///                        multiple of the batch size (the quotient becomes
///                        Options::nHypos_).
/// @param blank           Blank label index; must lie in [0, logits.size(3)).
/// @param clamp           Forwarded unchanged to Options::clamp_.
/// @return float32 tensor of shape
///         (batch * nHypos, logits.size(1), logits.size(2)), zero-initialised
///         and then written by ComputeAlphas.
/// @throws c10::Error if an input is not on CPU, is non-contiguous, has the
///         wrong dtype or rank, or if the batch sizes / blank index are
///         inconsistent with the other arguments.
torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // ComputeAlphas walks raw pointers as dense host buffers, so a strided view
  // or a non-CPU tensor would be read as garbage instead of raising an error.
  const auto check_dense_cpu = [](const torch::Tensor& t, const char* name) {
    TORCH_CHECK(
        t.device().type() == torch::DeviceType::CPU, name, " must be on CPU");
    TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
  };
  check_dense_cpu(logits, "logits");
  check_dense_cpu(targets, "targets");
  check_dense_cpu(logit_lengths, "logit_lengths");
  check_dense_cpu(target_lengths, "target_lengths");

  // Only float logits / int32 indices are supported; fail early with a clear
  // message rather than after allocating the outputs and workspaces.
  TORCH_CHECK(
      logits.scalar_type() == torch::ScalarType::Float,
      "logits must be float32, got ",
      logits.scalar_type());
  TORCH_CHECK(
      targets.scalar_type() == torch::ScalarType::Int,
      "targets must be int32, got ",
      targets.scalar_type());
  TORCH_CHECK(
      logit_lengths.scalar_type() == torch::ScalarType::Int,
      "logit_lengths must be int32, got ",
      logit_lengths.scalar_type());
  TORCH_CHECK(
      target_lengths.scalar_type() == torch::ScalarType::Int,
      "target_lengths must be int32, got ",
      target_lengths.scalar_type());

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D, got ", logits.dim(), "-D");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  const int64_t batch_size = logit_lengths.size(0);
  // Guards the division below: an empty batch would divide by zero.
  TORCH_CHECK(batch_size > 0, "logit_lengths must not be empty");
  // A non-multiple would silently truncate the number of hypotheses.
  TORCH_CHECK(
      target_lengths.size(0) % batch_size == 0,
      "target_lengths.size(0) (",
      target_lengths.size(0),
      ") must be a multiple of logit_lengths.size(0) (",
      batch_size,
      ")");
  TORCH_CHECK(
      blank >= 0 && blank < logits.size(3),
      "blank must be in [0, ",
      logits.size(3),
      "), got ",
      blank);

  Options options;
  options.batchSize_ = batch_size;
  options.nHypos_ = target_lengths.size(0) / batch_size;
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;
  options.device_ = CPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  // Scratch buffers sized from the options; the Workspace below only borrows
  // their storage, so these tensors must outlive the ComputeAlphas call.
  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}
// Bind compute_alphas as the CPU-dispatch-key kernel of the
// torchaudio::rnnt_loss_alphas operator (schema presumably declared by a
// TORCH_LIBRARY block elsewhere — confirm).
TORCH_LIBRARY_IMPL(torchaudio, CPU, lib) {
  lib.impl("rnnt_loss_alphas", &compute_alphas);
}
} // namespace cpu
} // namespace rnnt
} // namespace torchaudio