torchaudio/csrc/sox/io.cpp (127 lines of code) (raw):
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/effects_chain.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/types.h>
#include <torchaudio/csrc/sox/utils.h>
using namespace torch::indexing;
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {
// Open the audio file at `path` with sox and report its basic properties.
//
// `format`, when provided, is passed to sox as the file type; otherwise
// nullptr is passed for that argument.
//
// Returns, in order: signal rate, signal length divided by the channel
// count, channel count, bits per sample, and the encoding as converted by
// get_encoding().
//
// The opened handle is passed to validate_input_file() before any field is
// read. NOTE(review): that helper is defined elsewhere; presumably it throws
// when the open failed -- confirm.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
    const std::string& path,
    const c10::optional<std::string>& format) {
  const char* const filetype =
      format.has_value() ? format.value().c_str() : nullptr;

  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));
  validate_input_file(sf, path);

  const auto& signal = sf->signal;
  const auto& enc = sf->encoding;

  const auto sample_rate = static_cast<int64_t>(signal.rate);
  // signal.length is divided by the channel count to get a per-channel figure.
  const auto frames_per_channel =
      static_cast<int64_t>(signal.length / signal.channels);
  const auto channel_count = static_cast<int64_t>(signal.channels);
  const auto bit_depth = static_cast<int64_t>(enc.bits_per_sample);

  return std::make_tuple(
      sample_rate,
      frames_per_channel,
      channel_count,
      bit_depth,
      get_encoding(enc.encoding));
}
// Build the list of sox effects that restricts loading to a window of
// samples.
//
// frame_offset: position to start from; treated as 0 when absent. Must be
//     non-negative.
// num_frames: amount to keep; treated as -1 (no limit) when absent. Must be
//     -1 or strictly positive.
//
// Returns an empty list when no trimming is requested (offset 0, no limit).
// Otherwise returns a single effect:
//     {"trim", "<offset>s", "+<num_frames>s"}   when num_frames is given
//     {"trim", "<offset>s"}                     when only an offset is given
//
// Throws std::runtime_error when either argument is out of range.
std::vector<std::vector<std::string>> get_effects(
    const c10::optional<int64_t>& frame_offset,
    const c10::optional<int64_t>& num_frames) {
  const auto offset = frame_offset.value_or(0);
  if (offset < 0) {
    throw std::runtime_error(
        "Invalid argument: frame_offset must be non-negative.");
  }
  const auto frames = num_frames.value_or(-1);
  if (frames == 0 || frames < -1) {
    throw std::runtime_error(
        "Invalid argument: num_frames must be -1 or greater than 0.");
  }

  // Format counts with std::to_string rather than a string stream: stream
  // insertion honours the global C++ locale, so a locale with digit grouping
  // would yield arguments such as "44,100s". std::to_string does not apply
  // grouping to integers.
  const auto in_samples = [](int64_t count) {
    return std::to_string(count) + "s";
  };

  std::vector<std::vector<std::string>> effects;
  if (frames != -1) {
    effects.emplace_back(std::vector<std::string>{
        "trim", in_samples(offset), "+" + in_samples(frames)});
  } else if (offset != 0) {
    effects.emplace_back(std::vector<std::string>{"trim", in_samples(offset)});
  }
  return effects;
}
// Load audio from `path`, optionally restricted to a window of frames.
//
// `frame_offset` and `num_frames` are turned into an effect list by
// get_effects() (which throws std::runtime_error on out-of-range values).
// That list, together with the remaining arguments, is forwarded unchanged
// to sox_effects::apply_effects_file(), whose result is returned as-is.
std::tuple<torch::Tensor, int64_t> load_audio_file(
    const std::string& path,
    const c10::optional<int64_t>& frame_offset,
    const c10::optional<int64_t>& num_frames,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    const c10::optional<std::string>& format) {
  namespace fx = torchaudio::sox_effects;

  auto trim_effects = get_effects(frame_offset, num_frames);
  return fx::apply_effects_file(
      path, trim_effects, normalize, channels_first, format);
}
void save_audio_file(
const std::string& path,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample) {
validate_input_tensor(tensor);
const auto filetype = [&]() {
if (format.has_value())
return format.value();
return get_filetype(path);
}();
if (filetype == "amr-nb") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "amr-nb format only supports single channel audio.");
} else if (filetype == "htk") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "htk format only supports single channel audio.");
} else if (filetype == "gsm") {
const auto num_channels = tensor.size(channels_first ? 0 : 1);
TORCH_CHECK(
num_channels == 1, "gsm format only supports single channel audio.");
TORCH_CHECK(
sample_rate == 8000,
"gsm format only supports a sampling rate of 8kHz.");
}
const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo_for_save(
filetype, tensor.dtype(), compression, encoding, bits_per_sample);
SoxFormat sf(sox_open_write(
path.c_str(),
&signal_info,
&encoding_info,
/*filetype=*/filetype.c_str(),
/*oob=*/nullptr,
/*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
}
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf);
chain.run();
}
// Register the sox I/O entry points defined above as operators in the
// `torchaudio` library. Each m.def() binds an operator name to the function
// that implements it.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
// File metadata query -> get_info_file.
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
// Loading, optionally windowed by frame offset / count -> load_audio_file.
m.def(
"torchaudio::sox_io_load_audio_file",
&torchaudio::sox_io::load_audio_file);
// Saving a tensor to a file -> save_audio_file.
m.def(
"torchaudio::sox_io_save_audio_file",
&torchaudio::sox_io::save_audio_file);
}
} // namespace sox_io
} // namespace torchaudio