torchaudio/csrc/kaldi.cpp (81 lines of code) (raw):
#include <torch/script.h>
#include "feat/pitch-functions.h"
namespace torchaudio {
namespace kaldi {
namespace {
// Rescales a floating-point waveform to the int16 value range (still stored
// as floating point), which is what the Kaldi routines are fed with.
//
// Strictly positive samples are multiplied by 32767 and all other samples by
// 32768 (so zeros stay zero). The input is presumably normalized to [-1, 1] --
// this is not checked here.
//
// The input tensor is never modified; a new tensor is returned.
//
// NOTE: the earlier implementation called the out-of-place `index_put` and
// dropped its result, so it returned the input unscaled. Using the in-place
// `index_put_` on a shallow copy would instead have written through to the
// caller's tensor, hence the element-wise, out-of-place `torch::where`.
torch::Tensor denormalize(const torch::Tensor& t) {
  return torch::where(t > 0, t * 32767, t * 32768);
}
// Runs Kaldi pitch extraction on a single waveform.
//
// `wave` is wrapped in a Kaldi vector, passed to ::kaldi::ComputeKaldiPitch
// together with `opts`, and the tensor held by the resulting Kaldi matrix is
// returned.
torch::Tensor compute_kaldi_pitch(
    const torch::Tensor& wave,
    const ::kaldi::PitchExtractionOptions& opts) {
  // Destination matrix filled in by Kaldi.
  ::kaldi::Matrix<::kaldi::BaseFloat> features;
  // Kaldi-side view of the input samples.
  ::kaldi::VectorBase<::kaldi::BaseFloat> signal(wave);
  ::kaldi::ComputeKaldiPitch(opts, signal, &features);
  return features.tensor_;
}
} // namespace
// Computes Kaldi pitch features for a batch of waveforms.
//
// `wave` must be a 2-D, float32, CPU tensor with at least one entry along
// dim 0. Each index along dim 0 is processed independently (in parallel via
// at::parallel_for) and the per-item results are stacked along a new leading
// dimension.
//
// All remaining arguments are copied into the identically-purposed fields of
// ::kaldi::PitchExtractionOptions; floating-point values are narrowed to
// ::kaldi::BaseFloat and integer values to ::kaldi::int32.
//
// Raises (via TORCH_CHECK) when any of the input requirements above is not
// met.
torch::Tensor ComputeKaldiPitch(
    const torch::Tensor& wave,
    double sample_frequency,
    double frame_length,
    double frame_shift,
    double min_f0,
    double max_f0,
    double soft_min_f0,
    double penalty_factor,
    double lowpass_cutoff,
    double resample_frequency,
    double delta_pitch,
    double nccf_ballast,
    int64_t lowpass_filter_width,
    int64_t upsample_filter_width,
    int64_t max_frames_latency,
    int64_t frames_per_chunk,
    bool simulate_first_pass_online,
    int64_t recompute_frame,
    bool snip_edges) {
  TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimensional.");
  TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU.");
  TORCH_CHECK(
      wave.dtype() == torch::kFloat32, "Input tensor must be float32 type.");
  // torch::stack rejects an empty list, so fail early with a message that
  // points at the actual problem instead of at the stacking step.
  TORCH_CHECK(
      wave.size(0) > 0, "Input tensor must contain at least one waveform.");

  ::kaldi::PitchExtractionOptions opts;
  opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency);
  opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift);
  opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length);
  opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0);
  opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0);
  opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0);
  opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor);
  opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff);
  opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency);
  opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch);
  opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width);
  opts.upsample_filter_width =
      static_cast<::kaldi::int32>(upsample_filter_width);
  opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency);
  opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk);
  opts.simulate_first_pass_online = simulate_first_pass_online;
  opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame);
  opts.snip_edges = snip_edges;

  // Kaldi's float type expects value range of int16 expressed as float
  torch::Tensor wave_ = denormalize(wave);

  auto batch_size = wave_.size(0);
  std::vector<torch::Tensor> results(batch_size);
  // Every iteration writes only to its own pre-allocated slot of `results`.
  at::parallel_for(0, batch_size, 1, [&](int64_t begin, int64_t end) {
    for (auto i = begin; i < end; ++i) {
      results[i] = compute_kaldi_pitch(wave_.index({i}), opts);
    }
  });
  return torch::stack(results, 0);
}
// Adds an operator definition to the `torchaudio` library: the name
// "torchaudio::kaldi_ComputeKaldiPitch" is bound to
// torchaudio::kaldi::ComputeKaldiPitch.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::kaldi_ComputeKaldiPitch",
&torchaudio::kaldi::ComputeKaldiPitch);
}
} // namespace kaldi
} // namespace torchaudio