torch::List probeVideo()

in torchvision/csrc/io/video_reader/video_reader.cpp [425:566]


// Probes a media container and returns stream-level metadata as tensors.
//
// The decoder is configured with getPtsOnly = 1 and with both the video and
// the audio stream enabled. Only the metadata vector filled in by
// SyncDecoder::init() is consumed here; the decoder is shut down right after.
//
// Parameters:
//   isReadFile  - when true the container is opened through `videoPath`
//                 (assigned to params.uri); otherwise it is read from the
//                 bytes of `input_video` through a MemoryBuffer callback.
//   input_video - only accessed when `isReadFile` is false. It is read via
//                 data_ptr<uint8_t>() with a length of size(0).
//                 NOTE(review): presumably a 1-D contiguous uint8 tensor;
//                 that is not validated here - confirm against callers.
//   videoPath   - only used when `isReadFile` is true.
//
// Returns a list of six tensors, always in this order:
//   [0] video time base  : kInt,   {num, den}
//   [1] video fps        : kFloat, 1 element
//   [2] video duration   : kLong,  1 element, header.duration rescaled from
//                          timeBaseQ into the stream's own {num, den}
//   [3] audio time base  : kInt,   {num, den}
//   [4] audio sample rate: kInt,   1 element (the audio format's `samples`)
//   [5] audio duration   : kLong,  1 element, rescaled like [2]
// The tensors of a stream that was not found are left empty (shape {0});
// all six are empty when decoder initialization fails.
torch::List<torch::Tensor> probeVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath) {
  DecoderParameters params = getDecoderParams(
      0, // videoStartUs
      -1, // videoEndUs
      0, // seekFrameMargin
      1, // getPtsOnly
      1, // readVideoStream
      0, // width
      0, // height
      0, // minDimension
      0, // maxDimension
      1, // readAudioStream
      0, // audioSamples
      0 // audioChannels
  );

  SyncDecoder decoder;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    // videoPath is a by-value parameter that is not read again below, so its
    // buffer can be handed over to the log label instead of copied.
    logMessage = std::move(videoPath);
  } else {
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] has started";

  // steady_clock is monotonic; a wall-clock adjustment during probing cannot
  // distort (or make negative) the elapsed time reported below.
  const auto probeStart = std::chrono::steady_clock::now();

  bool gotAudio = false, gotVideo = false;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  const bool succeeded = decoder.init(params, std::move(callback), &metadata);
  if (succeeded) {
    // If several headers share a type, the last one seen is kept.
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        gotVideo = true;
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        gotAudio = true;
        audioMetadata = header;
      }
    }
    const auto probeEnd = std::chrono::steady_clock::now();
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] has finished, "
            << std::chrono::duration_cast<std::chrono::microseconds>(
                   probeEnd - probeStart)
                   .count()
            << " us";
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }

  // The metadata has been copied out above, so the decoder can be released
  // before any tensor is built. This runs whether or not init succeeded.
  decoder.shutdown();

  // video section
  // Empty placeholders are returned as-is when no video stream is available.
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotVideo) {
    videoTimeBase = torch::zeros({2}, torch::kInt);
    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
    const auto& header = videoMetadata;

    videoTimeBaseData[0] = header.num;
    videoTimeBaseData[1] = header.den;

    videoFps = torch::zeros({1}, torch::kFloat);
    float* videoFpsData = videoFps.data_ptr<float>();
    videoFpsData[0] = header.fps;

    videoDuration = torch::zeros({1}, torch::kLong);
    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
    // Express the duration in the stream's own time base (num / den).
    const AVRational streamTimeBase{
        static_cast<int>(header.num), static_cast<int>(header.den)};
    videoDurationData[0] =
        av_rescale_q(header.duration, timeBaseQ, streamTimeBase);

    VLOG(2) << "Prob fps: " << header.fps << ", duration: " << header.duration
            << ", num: " << header.num << ", den: " << header.den;

    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled video tensors";
  } else {
    LOG(ERROR) << "Miss video stream";
  }

  // audio section
  // Empty placeholders are returned as-is when no audio stream is available.
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);

  if (succeeded && gotAudio) {
    audioTimeBase = torch::zeros({2}, torch::kInt);
    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
    const auto& header = audioMetadata;
    const auto& media = header.format;
    const auto& format = media.format.audio;

    audioTimeBaseData[0] = header.num;
    audioTimeBaseData[1] = header.den;

    audioSampleRate = torch::zeros({1}, torch::kInt);
    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
    audioSampleRateData[0] = format.samples;

    audioDuration = torch::zeros({1}, torch::kLong);
    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
    // Express the duration in the stream's own time base (num / den).
    const AVRational streamTimeBase{
        static_cast<int>(header.num), static_cast<int>(header.den)};
    audioDurationData[0] =
        av_rescale_q(header.duration, timeBaseQ, streamTimeBase);

    VLOG(2) << "Prob sample rate: " << format.samples
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;

    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled audio tensors";
  } else {
    // An absent audio stream is only reported at verbose level.
    VLOG(1) << "Miss audio stream";
  }

  // The positions in this list are the function's output contract; keep the
  // order: video {time base, fps, duration}, audio {time base, rate, duration}.
  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));

  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] is about to return";

  return result;
}