bool AVInputOp::GetClipsAndLabelsFromDBValue()

in ops/av_input_op.h [425:555]


// Parses one serialized TensorProtos DB value and decodes it into RGB clips,
// log-mel audio features, labels, and (optionally) video ids.
//
// Expected proto layout, in order:
//   [0] encoded video bytes (TensorProto::STRING or TensorProto::BYTE;
//       when use_local_file_ is set it must be STRING and holds a filename)
//   [1] label(s) as int32 data
//   [2] optional start frame (only when decode_type_ == USE_START_FRM)
//   [3] optional video id as int64 data (only when get_video_id_ is set)
//
// Output contract (buffers are caller-allocated):
//   label_data    - clip_per_video_ ints in the single-label case, or a
//                   clip_per_video_ x num_of_class_ binary matrix in the
//                   multi-label case
//   video_id_data - clip_per_video_ int64 ids (only filled when
//                   get_video_id_ is set)
//   height/width, buffer_rgb, buffer_logmels, number_of_frames and
//   clip_start_frame are filled by DecodeMultipleAVClipsFromVideo.
//
// Returns true on success; returns false if parsing or decoding threw,
// in which case the output buffers must not be trusted.
bool AVInputOp<Context>::GetClipsAndLabelsFromDBValue(
    const std::string& value,
    int& height,
    int& width,
    std::vector<unsigned char*>& buffer_rgb,
    std::vector<float>& buffer_logmels,
    int* label_data,
    int64_t* video_id_data,
    int& number_of_frames,
    int& clip_start_frame,
    std::mt19937* randgen
) {
  try {
    TensorProtos protos;
    int curr_proto_idx = 0;
    CAFFE_ENFORCE(protos.ParseFromString(value));
    const TensorProto& video_proto = protos.protos(curr_proto_idx++);
    const TensorProto& label_proto = protos.protos(curr_proto_idx++);

    int start_frm = 0;
    // start_frm is only valid when sampling 1 clip per video without
    // temporal jittering
    if (decode_type_ == DecodeType::USE_START_FRM) {
      CAFFE_ENFORCE_GE(
          protos.protos_size(),
          curr_proto_idx + 1,
          "Start frm proto not provided");
      const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
      start_frm = start_frm_proto.int32_data(0);
    }

    if (get_video_id_) {
      CAFFE_ENFORCE_GE(
          protos.protos_size(), curr_proto_idx + 1, "Video Id not provided");
      const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
      // Every clip sampled from this video shares the same id.
      for (int i = 0; i < clip_per_video_; i++) {
        video_id_data[i] = video_id_proto.int64_data(0);
      }
    }

    // assign labels
    if (!do_multi_label_) {
      // Single-label: replicate the one label across all clips.
      for (int i = 0; i < clip_per_video_; i++) {
        label_data[i] = label_proto.int32_data(0);
      }
    } else {
      // For multiple label case, output label is a binary vector
      // where presented concepts are marked 1
      memset(label_data, 0, sizeof(int) * num_of_class_ * clip_per_video_);
      for (int i = 0; i < clip_per_video_; i++) {
        for (int j = 0; j < label_proto.int32_data_size(); j++) {
          CAFFE_ENFORCE_LT(
              label_proto.int32_data(j),
              num_of_class_,
              "Label should be less than the number of classes.");
          label_data[i * num_of_class_ + label_proto.int32_data(j)] = 1;
        }
      }
    }

    if (use_local_file_) {
      CAFFE_ENFORCE_EQ(
          video_proto.data_type(),
          TensorProto::STRING,
          "Database with a file_list is expected to be string data");
    }

    // initializing the decoding params
    Params params;
    params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
    params.video_res_type_ = video_res_type_;
    params.crop_size_ = crop_size_;
    params.short_edge_ = short_edge_;
    params.outputWidth_ = scale_w_;
    params.outputHeight_ = scale_h_;
    params.decode_type_ = decode_type_;
    params.num_of_required_frame_ = num_of_required_frame_;
    params.getAudio_ = get_logmels_;
    params.getVideo_ = get_rgb_;
    params.outrate_ = logMelAudioSamplingRate_;

    // Scale jittering: pick a random short-edge size from the configured set.
    if (jitter_scales_.size() > 0) {
      int select_idx =
        std::uniform_int_distribution<>(0, jitter_scales_.size() - 1)(*randgen);
      params.short_edge_ = jitter_scales_[select_idx];
    }

    char* video_buffer = nullptr; // for decoding from buffer
    std::string video_filename; // for decoding from file
    int encoded_size = 0;
    if (video_proto.data_type() == TensorProto::STRING) {
      const string& encoded_video_str = video_proto.string_data(0);
      if (!use_local_file_) {
        encoded_size = encoded_video_str.size();
        video_buffer = const_cast<char*>(encoded_video_str.data());
      } else {
        video_filename = encoded_video_str;
      }
    } else if (video_proto.data_type() == TensorProto::BYTE) {
      if (!use_local_file_) {
        encoded_size = video_proto.byte_data().size();
        video_buffer = const_cast<char*>(video_proto.byte_data().data());
      } else {
        // NOTE(review): unreachable in practice — use_local_file_ already
        // enforces TensorProto::STRING above. TODO: does this work?
        video_filename = video_proto.string_data(0);
      }
    } else {
      CAFFE_ENFORCE(false, "Unknown video data type.");
    }

    DecodeMultipleAVClipsFromVideo(
        video_buffer,
        video_filename,
        encoded_size,
        params,
        start_frm,
        clip_per_video_,
        use_local_file_,
        height,
        width,
        buffer_rgb,
        buffer_logmels,
        number_of_frames,
        clip_start_frame
    );
  } catch (const std::exception& exc) {
    std::cerr << "While calling GetClipsAndLabelsFromDBValue()\n";
    std::cerr << exc.what() << '\n';
    // Propagate the failure to the caller instead of silently reporting
    // success with undefined output buffers.
    return false;
  }
  return true;
}