torchaudio/csrc/ffmpeg/buffer.cpp (219 lines of code) (raw):
#include <torchaudio/csrc/ffmpeg/buffer.h>
#include <stdexcept>
#include <vector>
namespace torchaudio {
namespace ffmpeg {
Buffer::Buffer(int frames_per_chunk, int num_chunks)
: frames_per_chunk(frames_per_chunk), num_chunks(num_chunks) {}
AudioBuffer::AudioBuffer(int frames_per_chunk, int num_chunks)
: Buffer(frames_per_chunk, num_chunks) {}
VideoBuffer::VideoBuffer(int frames_per_chunk, int num_chunks)
: Buffer(frames_per_chunk, num_chunks) {}
////////////////////////////////////////////////////////////////////////////////
// Query
////////////////////////////////////////////////////////////////////////////////
// Report whether a chunk can be popped right now.
//
// With a negative frames_per_chunk any buffered frame is enough; otherwise
// at least frames_per_chunk frames must have been buffered.
bool Buffer::is_ready() const {
  return (frames_per_chunk < 0) ? (num_buffered_frames > 0)
                                : (num_buffered_frames >= frames_per_chunk);
}
////////////////////////////////////////////////////////////////////////////////
// Modifiers - Push Audio
////////////////////////////////////////////////////////////////////////////////
namespace {
// Copy the samples held by a decoded audio frame into a freshly allocated
// Tensor and return it with frames along the first dimension.
//
// Interleaved (packed) formats are copied as a single plane into a
// (num_frames, num_channels) Tensor. Planar formats are copied plane by
// plane into a (num_channels, num_frames) Tensor, which is then transposed,
// so both layouts yield a (num_frames, num_channels) result.
//
// Throws std::runtime_error when the sample format is not one of the
// U8 / S16 / S32 / S64 / FLT / DBL variants handled below.
torch::Tensor convert_audio_tensor(AVFrame* pFrame) {
  // ref: https://ffmpeg.org/doxygen/4.1/filter__audio_8c_source.html#l00215
  AVSampleFormat format = static_cast<AVSampleFormat>(pFrame->format);
  int num_channels = pFrame->channels;
  int bps = av_get_bytes_per_sample(format);

  // Note
  // FFmpeg's `nb_samples` represents the number of samples per channel.
  // This corresponds to `num_frames` in torchaudio's notation.
  // Also torchaudio uses `num_samples` as the number of samples
  // across channels.
  int num_frames = pFrame->nb_samples;

  int is_planar = av_sample_fmt_is_planar(format);
  int num_planes = is_planar ? num_channels : 1;
  int plane_size = bps * num_frames * (is_planar ? 1 : num_channels);
  std::vector<int64_t> shape = is_planar
      ? std::vector<int64_t>{num_channels, num_frames}
      : std::vector<int64_t>{num_frames, num_channels};

  torch::Tensor t;
  uint8_t* ptr = nullptr;
  switch (format) {
    case AV_SAMPLE_FMT_U8:
    case AV_SAMPLE_FMT_U8P: {
      t = torch::empty(shape, torch::kUInt8);
      ptr = t.data_ptr<uint8_t>();
      break;
    }
    case AV_SAMPLE_FMT_S16:
    case AV_SAMPLE_FMT_S16P: {
      t = torch::empty(shape, torch::kInt16);
      ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int16_t>());
      break;
    }
    case AV_SAMPLE_FMT_S32:
    case AV_SAMPLE_FMT_S32P: {
      t = torch::empty(shape, torch::kInt32);
      ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int32_t>());
      break;
    }
    case AV_SAMPLE_FMT_S64:
    case AV_SAMPLE_FMT_S64P: {
      t = torch::empty(shape, torch::kInt64);
      ptr = reinterpret_cast<uint8_t*>(t.data_ptr<int64_t>());
      break;
    }
    case AV_SAMPLE_FMT_FLT:
    case AV_SAMPLE_FMT_FLTP: {
      t = torch::empty(shape, torch::kFloat32);
      ptr = reinterpret_cast<uint8_t*>(t.data_ptr<float>());
      break;
    }
    case AV_SAMPLE_FMT_DBL:
    case AV_SAMPLE_FMT_DBLP: {
      t = torch::empty(shape, torch::kFloat64);
      ptr = reinterpret_cast<uint8_t*>(t.data_ptr<double>());
      break;
    }
    default: {
      // av_get_sample_fmt_name() returns NULL for formats it does not
      // recognize (e.g. AV_SAMPLE_FMT_NONE). Building a std::string from a
      // null pointer is undefined behaviour, so substitute a placeholder.
      const char* name = av_get_sample_fmt_name(format);
      throw std::runtime_error(
          "Unsupported audio format: " +
          std::string(name ? name : "(unknown)"));
    }
  }

  // Planes are laid out back to back in the destination Tensor. For
  // interleaved formats there is exactly one plane holding all channels.
  for (int i = 0; i < num_planes; ++i) {
    memcpy(ptr, pFrame->extended_data[i], plane_size);
    ptr += plane_size;
  }

  // Planar data was stored as (channel, frame); flip it to (frame, channel).
  if (is_planar)
    t = t.t();
  return t;
}
} // namespace
// Append audio frames (indexed along dimension 0 of `t`) to the buffer.
//
// When frames_per_chunk is negative the Tensor is stored untouched.
// Otherwise the incoming frames are re-cut so that every stored Tensor,
// except possibly the newest one, holds exactly frames_per_chunk frames, and
// the oldest Tensors are dropped once more than
// num_chunks * frames_per_chunk frames are buffered.
void AudioBuffer::push_tensor(torch::Tensor t) {
  // Fetch-all mode: no re-chunking and no trimming.
  if (frames_per_chunk < 0) {
    chunks.push_back(t);
    num_buffered_frames += t.size(0);
    return;
  }

  // A decoded audio Tensor usually carries many frames, possibly more than
  // the whole buffer may hold when frames_per_chunk is small. Storing it
  // as-is would let the trimming step below discard it in one go and leave
  // the buffer permanently empty, hence the slicing.
  //
  // If the newest stored Tensor is still short of frames_per_chunk, take it
  // back out and process it again together with the incoming frames.
  const bool tail_is_partial = (num_buffered_frames % frames_per_chunk) != 0;
  if (tail_is_partial) {
    torch::Tensor tail = chunks.back();
    chunks.pop_back();
    num_buffered_frames -= tail.size(0);
    t = torch::cat({tail, t}, 0);
  }

  // Peel off full chunks for as long as more than one chunk's worth remains.
  for (int pending = t.size(0); pending > frames_per_chunk;
       pending = t.size(0)) {
    auto pieces = torch::tensor_split(t, {frames_per_chunk, pending});
    chunks.push_back(pieces[0]);
    num_buffered_frames += frames_per_chunk;
    t = pieces[1];
  }

  // Whatever is left fits in one chunk (it may be shorter than a full one).
  int leftover = t.size(0);
  chunks.push_back(t);
  num_buffered_frames += leftover;

  // Enforce the capacity by discarding the oldest Tensors first.
  int capacity = num_chunks * frames_per_chunk;
  while (num_buffered_frames > capacity) {
    TORCH_WARN_ONCE(
        "The number of buffered frames exceeded the buffer size. "
        "Dropping the old frames. "
        "To avoid this, you can set a higher buffer_chunk_size value.");
    num_buffered_frames -= chunks.front().size(0);
    chunks.pop_front();
  }
}
// Convert a decoded audio frame into a Tensor and store it in this buffer.
void AudioBuffer::push_frame(AVFrame* frame) {
  // The conversion result is passed straight on as a temporary.
  this->push_tensor(convert_audio_tensor(frame));
}
////////////////////////////////////////////////////////////////////////////////
// Modifiers - Push Video
////////////////////////////////////////////////////////////////////////////////
namespace {
// Copy the pixels of a decoded video frame into a freshly allocated uint8
// Tensor and return it with shape (1, channel, height, width).
//
// Only the first data plane (`data[0]` / `linesize[0]`) is read, and the
// accepted pixel formats are the 8-bit packed RGB/BGR (3 channels),
// ARGB/RGBA/ABGR/BGRA (4 channels) and GRAY8 (1 channel) layouts.
//
// Throws std::runtime_error for any other pixel format.
torch::Tensor convert_image_tensor(AVFrame* pFrame) {
  // ref:
  // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
  // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
  AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format);
  int width = pFrame->width;
  int height = pFrame->height;
  uint8_t* buf = pFrame->data[0];
  int linesize = pFrame->linesize[0];

  int channel;
  switch (format) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      channel = 3;
      break;
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      channel = 4;
      break;
    case AV_PIX_FMT_GRAY8:
      channel = 1;
      break;
    default: {
      // av_get_pix_fmt_name() returns NULL for unknown formats (e.g.
      // AV_PIX_FMT_NONE). Building a std::string from a null pointer is
      // undefined behaviour, so substitute a placeholder.
      const char* name = av_get_pix_fmt_name(format);
      throw std::runtime_error(
          "Unexpected format: " + std::string(name ? name : "(unknown)"));
    }
  }

  torch::Tensor t;
  t = torch::empty({1, height, width, channel}, torch::kUInt8);
  auto ptr = t.data_ptr<uint8_t>();

  // Copy row by row: the source advances by `linesize` per row while the
  // destination is tightly packed at width * channel bytes per row.
  int stride = width * channel;
  for (int i = 0; i < height; ++i) {
    memcpy(ptr, buf, stride);
    buf += linesize;
    ptr += stride;
  }

  // (1, height, width, channel) -> (1, channel, height, width)
  return t.permute({0, 3, 1, 2});
}
} // namespace
// Append a video Tensor to the buffer.
//
// The Tensor is expected to hold a single frame. When frames_per_chunk is
// non-negative and the buffered frame count exceeds
// num_chunks * frames_per_chunk, the oldest stored Tensor is dropped.
void VideoBuffer::push_tensor(torch::Tensor t) {
  chunks.push_back(t);
  num_buffered_frames += t.size(0);

  // Trimming only applies when a chunk size was requested.
  if (frames_per_chunk >= 0) {
    int capacity = num_chunks * frames_per_chunk;
    if (num_buffered_frames > capacity) {
      TORCH_WARN_ONCE(
          "The number of buffered frames exceeded the buffer size. "
          "Dropping the old frames. "
          "To avoid this, you can set a higher buffer_chunk_size value.");
      num_buffered_frames -= chunks.front().size(0);
      chunks.pop_front();
    }
  }
}
// Convert a decoded video frame into a Tensor and store it in this buffer.
void VideoBuffer::push_frame(AVFrame* frame) {
  // The conversion result is passed straight on as a temporary.
  this->push_tensor(convert_image_tensor(frame));
}
////////////////////////////////////////////////////////////////////////////////
// Modifiers - Pop
////////////////////////////////////////////////////////////////////////////////
using namespace torch::indexing;
// Remove and return the next chunk, or an empty optional when no frame is
// buffered.
//
// With a negative frames_per_chunk everything buffered is returned in one
// Tensor; otherwise a single chunk is returned.
c10::optional<torch::Tensor> Buffer::pop_chunk() {
  if (num_buffered_frames == 0) {
    return {};
  }
  return c10::optional<torch::Tensor>{
      (frames_per_chunk < 0) ? pop_all() : pop_one_chunk()};
}
// Remove and return the oldest stored Tensor.
//
// Audio Tensors are already cut to frames_per_chunk when they are pushed, so
// the front element is the chunk.
torch::Tensor AudioBuffer::pop_one_chunk() {
  // Copy the handle out before the deque element is destroyed.
  torch::Tensor oldest = chunks.front();
  chunks.pop_front();
  num_buffered_frames -= oldest.size(0);
  return oldest;
}
// Remove the oldest stored Tensors until frames_per_chunk frames have been
// collected (or the buffer runs out) and return them concatenated along
// dimension 0.
//
// The video deque normally holds one frame per Tensor, but the bookkeeping
// below uses each Tensor's real frame count so that `num_buffered_frames`
// stays consistent with what push_tensor() added.
torch::Tensor VideoBuffer::pop_one_chunk() {
  std::vector<torch::Tensor> ret;
  int num_popped = 0;
  // `!chunks.empty()` guards against calling front() on an empty deque
  // should the frame counter and the deque ever disagree.
  while (num_buffered_frames > 0 && !chunks.empty() &&
         num_popped < frames_per_chunk) {
    torch::Tensor& t = chunks.front();
    int n_frames = t.size(0);
    ret.push_back(t);
    chunks.pop_front();
    num_buffered_frames -= n_frames;
    num_popped += n_frames;
  }
  return torch::cat(ret, 0);
}
// Drain the whole buffer and return every stored Tensor concatenated along
// dimension 0.
//
// Shared by audio and video: an audio Tensor holds several frames while a
// video Tensor holds one, so the frame counter is reduced by each Tensor's
// own size.
torch::Tensor Buffer::pop_all() {
  std::vector<torch::Tensor> drained;
  for (; !chunks.empty(); chunks.pop_front()) {
    const torch::Tensor& oldest = chunks.front();
    num_buffered_frames -= static_cast<int>(oldest.size(0));
    // The vector keeps its own handle, so popping the deque afterwards is
    // safe.
    drained.push_back(oldest);
  }
  return torch::cat(drained, 0);
}
void Buffer::flush() {
chunks.clear();
}
} // namespace ffmpeg
} // namespace torchaudio