in torchvision/csrc/io/video_reader/video_reader.cpp [182:423]
// Decode a video either from a file on disk (isReadFile == true, using
// videoPath) or from an in-memory buffer (input_video, a 1-D uint8 tensor).
//
// Returns a 10-element list of tensors, in this fixed order:
//   [0] videoFrame     uint8  {N, H, W, 3}  RGB24 frames (empty when
//                                           getPtsOnly != 0 or no video)
//   [1] videoFramePts  int64  {N}           per-frame presentation timestamps
//   [2] videoTimeBase  int32  {2}           {numerator, denominator}
//   [3] videoFps       float  {1}
//   [4] videoDuration  int64  {1}           duration rescaled to video timebase
//   [5] audioFrame     float  {S, C}        interleaved samples
//   [6] audioFramePts  int64  {M}
//   [7] audioTimeBase  int32  {2}
//   [8] audioSampleRate int32 {1}
//   [9] audioDuration  int64  {1}
// Tensors for streams that were not requested, are missing, or could not be
// decoded are returned as empty ({0}) tensors.
//
// The [start, end] pts ranges are given in the corresponding stream timebase
// (num/den); they are converted to microseconds before being handed to the
// decoder.
torch::List<torch::Tensor> readVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // Convert the requested pts window (in stream timebase units) to
  // microseconds, which is what the decoder consumes.
  int64_t videoStartUs, videoEndUs;
  offsetsToUs(
      seekFrameMargin,
      readVideoStream,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen,
      videoStartUs,
      videoEndUs);
  DecoderParameters params = getDecoderParams(
      videoStartUs, // videoStartPts
      videoEndUs, // videoEndPts
      seekFrameMargin, // seekFrameMargin
      getPtsOnly, // getPtsOnly
      readVideoStream, // readVideoStream
      width, // width
      height, // height
      minDimension, // minDimension
      maxDimension, // maxDimension
      readAudioStream, // readAudioStream
      audioSamples, // audioSamples
      audioChannels // audioChannels
  );
  SyncDecoder decoder;
  std::vector<DecoderOutputMessage> audioMessages, videoMessages;
  // For in-memory decoding the decoder pulls bytes through this callback;
  // for file decoding it opens params.uri itself and the callback stays null.
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }
  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has started";
  const auto now = std::chrono::system_clock::now();
  bool succeeded;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    // Remember the per-stream headers; they carry timebase/fps/duration
    // needed when building the output tensors below.
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        audioMetadata = header;
      }
    }
    int res;
    DecoderOutputMessage msg;
    while (0 == (res = decoder.decode(&msg, decoderTimeoutMs))) {
      // Route the message by stream type BEFORE moving it: the previous
      // code re-inspected msg.header after std::move(msg), i.e. read a
      // moved-from object (it only worked because the header is trivially
      // copyable). A message belongs to exactly one stream, so else-if is
      // behaviorally equivalent.
      if (msg.header.format.type == TYPE_VIDEO) {
        videoMessages.push_back(std::move(msg));
      } else if (msg.header.format.type == TYPE_AUDIO) {
        audioMessages.push_back(std::move(msg));
      }
      // Drop any payload that was not consumed above (e.g. other stream
      // types) before reusing msg in the next iteration.
      msg.payload.reset();
    }
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }
  const auto then = std::chrono::system_clock::now();
  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] has finished, "
          << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
                 .count()
          << " us";
  decoder.shutdown();
  // video section — defaults are empty tensors, overwritten only when a
  // video stream was requested and frames were actually decoded.
  torch::Tensor videoFrame = torch::zeros({0}, torch::kByte);
  torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && readVideoStream == 1) {
    if (!videoMessages.empty()) {
      const auto& header = videoMetadata;
      const auto& format = header.format.format.video;
      int numVideoFrames = videoMessages.size();
      int outHeight = format.height;
      int outWidth = format.width;
      int numChannels = 3; // decoder guarantees the default AV_PIX_FMT_RGB24
      size_t expectedWrittenBytes = 0;
      if (getPtsOnly == 0) {
        videoFrame = torch::zeros(
            {numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte);
        expectedWrittenBytes =
            (size_t)numVideoFrames * outHeight * outWidth * numChannels;
      }
      videoFramePts = torch::zeros({numVideoFrames}, torch::kLong);
      VLOG(2) << "video duration: " << header.duration
              << ", fps: " << header.fps << ", num: " << header.num
              << ", den: " << header.den << ", num frames: " << numVideoFrames;
      auto numberWrittenBytes = fillVideoTensor(
          videoMessages, videoFrame, videoFramePts, header.num, header.den);
      CHECK_EQ(numberWrittenBytes, expectedWrittenBytes);
      videoTimeBase = torch::zeros({2}, torch::kInt);
      int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
      videoTimeBaseData[0] = header.num;
      videoTimeBaseData[1] = header.den;
      videoFps = torch::zeros({1}, torch::kFloat);
      float* videoFpsData = videoFps.data_ptr<float>();
      videoFpsData[0] = header.fps;
      videoDuration = torch::zeros({1}, torch::kLong);
      int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
      // Rescale duration from the decoder's internal (microsecond) timebase
      // into the stream timebase so it matches the returned pts values.
      AVRational vr = AVRational{(int)header.num, (int)header.den};
      videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, vr);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled video tensors";
    } else {
      VLOG(1) << "Miss video stream";
    }
  }
  // audio section — same pattern as the video section above.
  torch::Tensor audioFrame = torch::zeros({0}, torch::kFloat);
  torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && readAudioStream == 1) {
    if (!audioMessages.empty()) {
      const auto& header = audioMetadata;
      const auto& format = header.format.format.audio;
      int64_t outAudioChannels = format.channels;
      int bytesPerSample =
          av_get_bytes_per_sample(static_cast<AVSampleFormat>(format.format));
      int numAudioFrames = audioMessages.size();
      int64_t numAudioSamples = 0;
      if (getPtsOnly == 0) {
        // Total payload size must be an exact multiple of one sample frame
        // (channels * bytesPerSample); derive the sample count from it.
        int64_t frameSizeTotal = 0;
        for (auto const& audioMessage : audioMessages) {
          frameSizeTotal += audioMessage.payload->length();
        }
        CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
        numAudioSamples = frameSizeTotal / (outAudioChannels * bytesPerSample);
        audioFrame =
            torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
      }
      audioFramePts = torch::zeros({numAudioFrames}, torch::kLong);
      VLOG(2) << "audio duration: " << header.duration
              << ", channels: " << format.channels
              << ", sample rate: " << format.samples << ", num: " << header.num
              << ", den: " << header.den;
      auto numberWrittenBytes = fillAudioTensor(
          audioMessages, audioFrame, audioFramePts, header.num, header.den);
      CHECK_EQ(
          numberWrittenBytes,
          numAudioSamples * outAudioChannels * sizeof(float));
      audioTimeBase = torch::zeros({2}, torch::kInt);
      int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
      audioTimeBaseData[0] = header.num;
      audioTimeBaseData[1] = header.den;
      audioSampleRate = torch::zeros({1}, torch::kInt);
      int* audioSampleRateData = audioSampleRate.data_ptr<int>();
      audioSampleRateData[0] = format.samples;
      audioDuration = torch::zeros({1}, torch::kLong);
      int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
      AVRational ar = AVRational{(int)header.num, (int)header.den};
      audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, ar);
      VLOG(1) << "Video decoding from " << logType << " [" << logMessage
              << "] filled audio tensors";
    } else {
      VLOG(1) << "Miss audio stream";
    }
  }
  // Pack the ten output tensors in the documented fixed order.
  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoFrame));
  result.push_back(std::move(videoFramePts));
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioFrame));
  result.push_back(std::move(audioFramePts));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));
  VLOG(1) << "Video decoding from " << logType << " [" << logMessage
          << "] about to return";
  return result;
}