torch::Tensor convert_image_tensor()

in torchaudio/csrc/ffmpeg/buffer.cpp [167:207]


// Copy the pixel data of a packed-format video frame into a uint8 tensor.
//
// Only the first data plane (`pFrame->data[0]`) is read. Each of the
// `pFrame->height` rows contributes `width * channel` bytes; the remainder of
// every `linesize[0]`-byte source row is skipped, so the destination holds the
// pixel bytes back to back.
//
// Supported pixel formats and the channel count used for each:
//   RGB24, BGR24           -> 3
//   ARGB, RGBA, ABGR, BGRA -> 4
//   GRAY8                  -> 1
//
// Returns the tensor allocated as {1, height, width, channel} and then
// permuted with {0, 3, 1, 2}, i.e. indexed as (1, channel, height, width).
//
// Throws std::runtime_error when the frame's pixel format is not one of the
// formats listed above.
torch::Tensor convert_image_tensor(AVFrame* pFrame) {
  // ref:
  // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
  // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
  const AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format);
  const int width = pFrame->width;
  const int height = pFrame->height;
  const uint8_t* src = pFrame->data[0];
  const int linesize = pFrame->linesize[0];

  int channel = 0;
  switch (format) {
    case AV_PIX_FMT_RGB24:
    case AV_PIX_FMT_BGR24:
      channel = 3;
      break;
    case AV_PIX_FMT_ARGB:
    case AV_PIX_FMT_RGBA:
    case AV_PIX_FMT_ABGR:
    case AV_PIX_FMT_BGRA:
      channel = 4;
      break;
    case AV_PIX_FMT_GRAY8:
      channel = 1;
      break;
    default: {
      // av_get_pix_fmt_name() returns NULL for a format it does not know;
      // building a std::string from a null pointer is undefined behaviour, so
      // fall back to reporting the raw numeric value instead.
      const char* name = av_get_pix_fmt_name(format);
      throw std::runtime_error(
          "Unexpected format: " +
          (name ? std::string(name)
                : "unknown (" + std::to_string(static_cast<int>(format)) +
                   ")"));
    }
  }

  torch::Tensor t = torch::empty({1, height, width, channel}, torch::kUInt8);
  uint8_t* dst = t.data_ptr<uint8_t>();
  // Number of meaningful bytes per row; the source row pitch (`linesize`) may
  // differ from this, so rows are copied one at a time.
  const size_t row_bytes = static_cast<size_t>(width) * channel;
  for (int row = 0; row < height; ++row) {
    memcpy(dst, src, row_bytes);
    src += linesize;
    dst += row_bytes;
  }
  return t.permute({0, 3, 1, 2});
}