in caffe2_customized_ops/video/customized_video_input_op.h [491:622]
template <class Context>
bool CustomizedVideoInputOp<Context>::Prefetch() {
// Get the reader pointer from the input.
// If we use local clips, the DB will store the clip list.
reader_ = &OperatorBase::Input<db::DBReader>(0);
const int channels = 3;
// Call mutable_data() once to allocate the underlying memory.
prefetched_clip_.mutable_data<float>();
prefetched_label_.mutable_data<int>();
// Prefetching is handled with a thread pool of num_decode_threads_ decode threads.
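// Each decode thread gets its own std::mt19937, seeded from a time-seeded
// meta generator, so random cropping/mirroring needs no locking across threads.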
std::mt19937 meta_randgen(time(nullptr));
std::vector<std::mt19937> randgen_per_thread;
for (int i = 0; i < num_decode_threads_; ++i) {
randgen_per_thread.emplace_back(meta_randgen());
}
std::bernoulli_distribution mirror_this_clip(0.5);
const int num_items = batch_size_;
// ------------ only useful for crop_ <= 0
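// When crop_ <= 0 the decoded frame size is unknown until decode time, so each
// item gets a scratch buffer sized for a MAX_IMAGE_SIZE frame; the actual
// height/width are reported back via list_height_out / list_width_out.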
std::vector<float*> list_clip_data;
std::vector<int> list_height_out;
std::vector<int> list_width_out;
list_clip_data.resize(num_items);
list_height_out.resize(num_items);
list_width_out.resize(num_items);
const int MAX_IMAGE_SIZE = 500 * 500;
if (crop_ <= 0) {
for (int item_id = 0; item_id < num_items; ++item_id) {
const int num_clips = 1;
/*
We have to allocate the buffer outside of DecodeAndTransform, because
DecodeAndTransform only writes into the buffer it is given; it cannot
(re)allocate it.
*/
list_clip_data[item_id] =
new float[num_clips * MAX_IMAGE_SIZE * length_ * channels];
list_height_out[item_id] = -1;
list_width_out[item_id] = -1;
} // for
} //if
// ------------------------
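// Dispatch one DecodeAndTransform task per batch item. Each task reads a
// serialized clip from the DBReader and writes the decoded, transformed frames
// either directly into prefetched_clip_ (fixed crop) or into the per-item
// scratch buffer (crop_ <= 0).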
for (int item_id = 0; item_id < num_items; ++item_id) {
std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];
// Get the label data pointer for the item_id-th example.
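// With multiple_label_, each item owns num_of_labels_ consecutive ints;
// otherwise there is a single label per item.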
int* label_data = prefetched_label_.mutable_data<int>() +
(multiple_label_ ? num_of_labels_ : 1) * item_id;
// float* clip_data = prefetched_clip_.mutable_data<float>() +
// crop_ * crop_ * length_ * channels * item_id;
std::string key, value;
// read data
reader_->Read(&key, &value);
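// Destination buffer: write straight into the prefetched clip blob when the
// crop size is fixed; otherwise decode into the scratch buffer, since the
// final spatial size is only known after decoding.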
thread_pool_->runTask(std::bind(
&CustomizedVideoInputOp<Context>::DecodeAndTransform,
this,
std::string(value),
(crop_ > 0) ?
(prefetched_clip_.mutable_data<float>() +
crop_ * crop_ * length_ * channels * item_id) // clip_data
: (list_clip_data[item_id]), // temp list
label_data,
crop_,
mirror_,
mean_,
std_,
randgen,
&mirror_this_clip,
&(list_height_out[item_id]),
&(list_width_out[item_id])
));
} // for over the batch
thread_pool_->waitWorkComplete();
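// All decode tasks have finished. For crop_ > 0 the prefetched clip blob is
// now fully populated; for crop_ <= 0 the frames still live in the scratch
// buffer and are copied below.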
// ------------ only useful for crop_ <= 0
if (crop_ <= 0) { // the batch must contain exactly one item
if (num_items != 1) {
LOG(FATAL) << "When crop_ <= 0, the batch must contain exactly one item.";
}
if (MAX_IMAGE_SIZE < list_height_out[0] * list_width_out[0]) {
LOG(FATAL) << "Buffer is too small.";
}
// reallocate
/*
The network is usually designed for 224x224 input. If the decoded image is
smaller than this size (e.g., an empty video), the network run can crash
(e.g., a conv kernel larger than its spatial input).
*/
const int MIN_SIZE = 224;
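// Clip blob layout is (N, C, T, H, W): batch x channels x length x height x
// width, with height/width taken from the decoded frame and clamped to MIN_SIZE.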
vector<TIndex> data_shape(5);
data_shape[0] = batch_size_;
data_shape[1] = channels;
data_shape[2] = length_;
data_shape[3] = std::max(list_height_out[0], MIN_SIZE); // for safety
data_shape[4] = std::max(list_width_out[0], MIN_SIZE); // for safety
prefetched_clip_.Resize(data_shape);
prefetched_clip_.mutable_data<float>();
if (list_height_out[0] < MIN_SIZE || list_width_out[0] < MIN_SIZE) {
LOG(ERROR) << "Video is too small.";
}
// in case of empty video, initialize an all-zero blob
memset(prefetched_clip_.mutable_data<float>(), 0,
sizeof(float) * prefetched_clip_.size());
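// Copy the decoded frames from the scratch buffer into the resized blob; if
// decoding failed (height/width <= 0), the blob stays all-zero.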
if (list_clip_data[0] != nullptr
&& list_height_out[0] > 0 && list_width_out[0] > 0) {
const int num_clips = batch_size_;
memcpy(
prefetched_clip_.mutable_data<float>(),
list_clip_data[0],
sizeof(float) * num_clips *
list_height_out[0] * list_width_out[0] * length_ * channels
);
}
// free the scratch buffer unconditionally so it is not leaked on decode failure
delete [] list_clip_data[0];
list_clip_data[0] = nullptr;
} // if crop_ <= 0
// ------------------------
// If the context is not CPUContext, we will need to do a copy in the
// prefetch function as well.
if (!std::is_same<Context, CPUContext>::value) {
prefetched_clip_on_device_.CopyFrom(prefetched_clip_, &context_);
prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
}
return true;
}