// bool CustomizedVideoInputOp::Prefetch()
// in caffe2_customized_ops/video/customized_video_input_op.h [491:622]

bool CustomizedVideoInputOp<Context>::Prefetch() {
  // We will get the reader pointer from input.
  // If we use local clips, db will store the list.
  reader_ = &OperatorBase::Input<db::DBReader>(0);

  const int channels = 3;  // decoded frames are RGB

  // Call mutable_data() once up front to allocate the underlying memory,
  // so the decode workers only ever write into preallocated buffers.
  prefetched_clip_.mutable_data<float>();
  prefetched_label_.mutable_data<int>();

  // Prefetching is handled with a thread pool of "num_decode_threads_"
  // threads; each thread gets its own RNG so decoding results do not
  // depend on thread scheduling.
  std::mt19937 meta_randgen(time(nullptr));
  std::vector<std::mt19937> randgen_per_thread;
  randgen_per_thread.reserve(num_decode_threads_);
  for (int i = 0; i < num_decode_threads_; ++i) {
    randgen_per_thread.emplace_back(meta_randgen());
  }

  std::bernoulli_distribution mirror_this_clip(0.5);

  const int num_items = batch_size_;

  // ------------ only used when crop_ <= 0 (no fixed-size crop) ------------
  // Per-item temporary decode buffers plus the decoded frame dimensions.
  // std::vector<float> is used instead of raw new[]/delete[] so the
  // buffers are released on every exit path. (The previous raw-pointer
  // version leaked the buffer whenever the video was empty, because the
  // delete[] lived inside the memcpy branch that was then skipped.)
  std::vector<std::vector<float>> list_clip_data(num_items);
  std::vector<int> list_height_out(num_items, -1);
  std::vector<int> list_width_out(num_items, -1);
  const int MAX_IMAGE_SIZE = 500 * 500;
  if (crop_ <= 0) {
    for (int item_id = 0; item_id < num_items; ++item_id) {
      const int num_clips = 1;
      // We must allocate outside of DecodeAndTransform, because
      // DecodeAndTransform only fills the buffer and never resizes it.
      list_clip_data[item_id].resize(
          static_cast<size_t>(num_clips) * MAX_IMAGE_SIZE * length_ *
          channels);
    } // for
  } // if
  // ------------------------------------------------------------------------

  for (int item_id = 0; item_id < num_items; ++item_id) {
    std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];

    // Label pointer for the item_id-th example; with multiple labels each
    // example occupies num_of_labels_ slots, otherwise a single slot.
    int* label_data = prefetched_label_.mutable_data<int>() +
        (multiple_label_ ? num_of_labels_ : 1) * item_id;

    // With a positive crop the worker writes directly into the prefetch
    // blob at a fixed per-item offset; otherwise it writes into the
    // per-item temp buffer, because the output dimensions are unknown
    // until the video has been decoded.
    float* clip_data = (crop_ > 0)
        ? prefetched_clip_.mutable_data<float>() +
            crop_ * crop_ * length_ * channels * item_id
        : list_clip_data[item_id].data();

    std::string key, value;
    // read data
    reader_->Read(&key, &value);
    thread_pool_->runTask(std::bind(
        &CustomizedVideoInputOp<Context>::DecodeAndTransform,
        this,
        std::string(value),
        clip_data,
        label_data,
        crop_,
        mirror_,
        mean_,
        std_,
        randgen,
        &mirror_this_clip,
        &(list_height_out[item_id]),
        &(list_width_out[item_id])));
  } // for over the batch
  thread_pool_->waitWorkComplete();

  // ------------ only used when crop_ <= 0 ------------
  if (crop_ <= 0) { // There should be only one item
    if (num_items != 1) {
      LOG(FATAL) << "There should be only one item.";
    }
    if (MAX_IMAGE_SIZE < list_height_out[0] * list_width_out[0]) {
      LOG(FATAL) << "Buffer is too small.";
    }

    // Reallocate the prefetch blob to the decoded size.
    // The network is usually designed for 224x224 input. If the (possibly
    // empty) image is smaller than this size, the network run can crash
    // (e.g., kernel > space), so clamp both dimensions up to MIN_SIZE.
    const int MIN_SIZE = 224;
    vector<TIndex> data_shape(5);
    data_shape[0] = batch_size_;
    data_shape[1] = channels;
    data_shape[2] = length_;
    data_shape[3] = std::max(list_height_out[0], MIN_SIZE); // for safety
    data_shape[4] = std::max(list_width_out[0], MIN_SIZE); // for safety
    prefetched_clip_.Resize(data_shape);
    prefetched_clip_.mutable_data<float>();
    if (list_height_out[0] < MIN_SIZE || list_width_out[0] < MIN_SIZE) {
      LOG(ERROR) << "Video is too small.";
    }

    // In case of an empty video, initialize an all-zero blob.
    memset(prefetched_clip_.mutable_data<float>(), 0,
      sizeof(float) * prefetched_clip_.size());
    if (!list_clip_data[0].empty()
        && list_height_out[0] > 0 && list_width_out[0] > 0) {
      const int num_clips = batch_size_;
      memcpy(
          prefetched_clip_.mutable_data<float>(),
          list_clip_data[0].data(),
          sizeof(float) * num_clips *
          list_height_out[0] * list_width_out[0] * length_ * channels);
    }
    // Temp buffers are freed automatically when list_clip_data goes out of
    // scope — including on the empty-video path, which the raw new[]
    // version leaked.
  } // if crop_ <= 0
  // ---------------------------------------------------

  // If the context is not CPUContext, we also need a device copy in the
  // prefetch function.
  if (!std::is_same<Context, CPUContext>::value) {
    prefetched_clip_on_device_.CopyFrom(prefetched_clip_, &context_);
    prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
  }
  return true;
}