in torchaudio/csrc/ffmpeg/buffer.cpp [30:101]
// Copy the audio samples held by `pFrame` into a newly allocated tensor.
//
// The element type of the tensor follows the frame's sample format
// (u8 -> kUInt8, s16 -> kInt16, s32 -> kInt32, s64 -> kInt64,
// flt -> kFloat32, dbl -> kFloat64), for both the packed and the planar
// variant of each format.
//
// The returned tensor has shape (num_frames, num_channels). For planar
// formats the data is copied plane by plane into a
// (num_channels, num_frames) tensor and the transposed tensor is returned.
//
// Throws std::runtime_error when the sample format is not one of the
// formats listed above.
torch::Tensor convert_audio_tensor(AVFrame* pFrame) {
  // ref: https://ffmpeg.org/doxygen/4.1/filter__audio_8c_source.html#l00215
  const AVSampleFormat format = static_cast<AVSampleFormat>(pFrame->format);
  const int num_channels = pFrame->channels;

  // Note
  // FFmpeg's `nb_samples` represents the number of samples per channel.
  // This corresponds to `num_frames` in torchaudio's notation.
  // Also torchaudio uses `num_samples` as the number of samples
  // across channels.
  const int num_frames = pFrame->nb_samples;

  const bool is_planar = av_sample_fmt_is_planar(format) != 0;
  const std::vector<int64_t> shape = is_planar
      ? std::vector<int64_t>{num_channels, num_frames}
      : std::vector<int64_t>{num_frames, num_channels};

  torch::Tensor t;
  switch (format) {
    case AV_SAMPLE_FMT_U8:
    case AV_SAMPLE_FMT_U8P:
      t = torch::empty(shape, torch::kUInt8);
      break;
    case AV_SAMPLE_FMT_S16:
    case AV_SAMPLE_FMT_S16P:
      t = torch::empty(shape, torch::kInt16);
      break;
    case AV_SAMPLE_FMT_S32:
    case AV_SAMPLE_FMT_S32P:
      t = torch::empty(shape, torch::kInt32);
      break;
    case AV_SAMPLE_FMT_S64:
    case AV_SAMPLE_FMT_S64P:
      t = torch::empty(shape, torch::kInt64);
      break;
    case AV_SAMPLE_FMT_FLT:
    case AV_SAMPLE_FMT_FLTP:
      t = torch::empty(shape, torch::kFloat32);
      break;
    case AV_SAMPLE_FMT_DBL:
    case AV_SAMPLE_FMT_DBLP:
      t = torch::empty(shape, torch::kFloat64);
      break;
    default: {
      // The name lookup may yield a null pointer for a format FFmpeg does not
      // recognize; building a std::string from a null pointer is undefined
      // behavior, so substitute a placeholder.
      const char* name = av_get_sample_fmt_name(format);
      throw std::runtime_error(
          "Unsupported audio format: " +
          std::string(name != nullptr ? name : "unknown"));
    }
  }

  // Packed formats keep every channel interleaved in a single plane;
  // planar formats use one plane per channel.
  const int num_planes = is_planar ? num_channels : 1;
  // Computed in size_t so that large frames do not overflow `int`.
  const size_t plane_size = static_cast<size_t>(av_get_bytes_per_sample(format)) *
      static_cast<size_t>(num_frames) *
      static_cast<size_t>(is_planar ? 1 : num_channels);

  // Skip the copy for empty planes: there is nothing to transfer and the
  // destination pointer of an empty tensor is not guaranteed to be non-null.
  if (plane_size > 0) {
    uint8_t* dst = static_cast<uint8_t*>(t.data_ptr());
    for (int i = 0; i < num_planes; ++i) {
      memcpy(dst, pFrame->extended_data[i], plane_size);
      dst += plane_size;
    }
  }

  // Planar data was laid out as (num_channels, num_frames); transpose so both
  // layouts are returned as (num_frames, num_channels).
  if (is_planar) {
    t = t.t();
  }
  return t;
}