in torchvision/csrc/io/video_reader/video_reader.cpp [425:566]
torch::List<torch::Tensor> probeVideo(
bool isReadFile,
const torch::Tensor& input_video,
std::string videoPath) {
DecoderParameters params = getDecoderParams(
0, // videoStartUs
-1, // videoEndUs
0, // seekFrameMargin
1, // getPtsOnly
1, // readVideoStream
0, // width
0, // height
0, // minDimension
0, // maxDimension
1, // readAudioStream
0, // audioSamples
0 // audioChannels
);
SyncDecoder decoder;
DecoderInCallback callback = nullptr;
std::string logMessage, logType;
if (isReadFile) {
params.uri = videoPath;
logType = "file";
logMessage = videoPath;
} else {
callback = MemoryBuffer::getCallback(
input_video.data_ptr<uint8_t>(), input_video.size(0));
logType = "memory";
logMessage = std::to_string(input_video.size(0));
}
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] has started";
const auto now = std::chrono::system_clock::now();
bool succeeded;
bool gotAudio = false, gotVideo = false;
DecoderMetadata audioMetadata, videoMetadata;
std::vector<DecoderMetadata> metadata;
if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
for (const auto& header : metadata) {
if (header.format.type == TYPE_VIDEO) {
gotVideo = true;
videoMetadata = header;
} else if (header.format.type == TYPE_AUDIO) {
gotAudio = true;
audioMetadata = header;
}
}
const auto then = std::chrono::system_clock::now();
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] has finished, "
<< std::chrono::duration_cast<std::chrono::microseconds>(then - now)
.count()
<< " us";
} else {
LOG(ERROR) << "Decoder initialization has failed";
}
decoder.shutdown();
// video section
torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
if (succeeded && gotVideo) {
videoTimeBase = torch::zeros({2}, torch::kInt);
int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
const auto& header = videoMetadata;
videoTimeBaseData[0] = header.num;
videoTimeBaseData[1] = header.den;
videoFps = torch::zeros({1}, torch::kFloat);
float* videoFpsData = videoFps.data_ptr<float>();
videoFpsData[0] = header.fps;
videoDuration = torch::zeros({1}, torch::kLong);
int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
AVRational avr = AVRational{(int)header.num, (int)header.den};
videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
VLOG(2) << "Prob fps: " << header.fps << ", duration: " << header.duration
<< ", num: " << header.num << ", den: " << header.den;
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] filled video tensors";
} else {
LOG(ERROR) << "Miss video stream";
}
// audio section
torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
if (succeeded && gotAudio) {
audioTimeBase = torch::zeros({2}, torch::kInt);
int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
const auto& header = audioMetadata;
const auto& media = header.format;
const auto& format = media.format.audio;
audioTimeBaseData[0] = header.num;
audioTimeBaseData[1] = header.den;
audioSampleRate = torch::zeros({1}, torch::kInt);
int* audioSampleRateData = audioSampleRate.data_ptr<int>();
audioSampleRateData[0] = format.samples;
audioDuration = torch::zeros({1}, torch::kLong);
int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
AVRational avr = AVRational{(int)header.num, (int)header.den};
audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
VLOG(2) << "Prob sample rate: " << format.samples
<< ", duration: " << header.duration << ", num: " << header.num
<< ", den: " << header.den;
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] filled audio tensors";
} else {
VLOG(1) << "Miss audio stream";
}
torch::List<torch::Tensor> result;
result.push_back(std::move(videoTimeBase));
result.push_back(std::move(videoFps));
result.push_back(std::move(videoDuration));
result.push_back(std::move(audioTimeBase));
result.push_back(std::move(audioSampleRate));
result.push_back(std::move(audioDuration));
VLOG(1) << "Video probing from " << logType << " [" << logMessage
<< "] is about to return";
return result;
}