torchaudio/csrc/rnnt/gpu/compute.cu (133 lines of code) (raw):

#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss (GPU backend).
//
// Validates the inputs, packs launch parameters into an Options struct,
// allocates the integer/float scratch workspaces on the logits' device,
// and dispatches the templated Compute<> kernel driver for the logits'
// dtype (float32 or float16, both accumulated in float32).
//
// Args:
//   logits:         4-D (batch, max time, max target + 1, num classes),
//                   float32 or float16, contiguous, on a CUDA device.
//   targets:        2-D (batch, max target length), int32, contiguous.
//   logit_lengths:  1-D (batch,), int32 — valid time steps per utterance.
//   target_lengths: 1-D (batch * nHypos,), int32 — valid target lengths.
//   blank:          index of the blank label, in [0, num classes).
//   clamp:          gradient clamping value forwarded to the kernels.
//
// Returns: (costs, gradients)
//   costs:     1-D (batch * nHypos,), same dtype/device as logits.
//   gradients: same shape/dtype as logits (always materialized here).
//
// Raises c10::Error (via TORCH_CHECK) on any validation failure.
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // --- device / dtype / layout validation -------------------------------
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  // Kernels index with raw pointers, so all inputs must be contiguous.
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  // --- shape validation --------------------------------------------------
  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  // The padded dimensions of logits/targets must exactly match the
  // maximum lengths (time = max(logit_lengths), target = max + 1 for blank).
  TORCH_CHECK(
      logits.size(1) == at::max(logit_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(target_lengths).item().toInt(),
      "target length mismatch");

  // --- launch configuration ---------------------------------------------
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  // Multiple hypotheses per utterance are encoded as extra rows in
  // target_lengths; batch sizes were validated equal above.
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  TORCH_CHECK(
      logits.device().type() == torch::DeviceType::CUDA,
      "logits must be on a CUDA device");
  // Bind to the device that owns the inputs BEFORE fetching the stream;
  // fetching the stream first could capture another device's stream when
  // the caller's current device differs from logits' device.
  cudaSetDevice(logits.get_device());
  options.stream_ = at::cuda::getCurrentCUDAStream(logits.get_device());
  options.device_ = GPU;

  // --- output & scratch allocation --------------------------------------
  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // --- dtype dispatch ----------------------------------------------------
  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      // Half-precision inputs; accumulation is done in float32 (CAST_DTYPE).
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      // Unreachable: the dtype TORCH_CHECK above admits only Float/Half.
      // Fail loudly rather than returning uninitialized costs if that
      // check and this switch ever fall out of sync.
      TORCH_CHECK(
          false, "unsupported logits dtype; expected float32 or float16");
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio