// NOTE(review): This chunk is a garbled extraction of Caffe2's
// video_input_op.h. The original file's line numbers are fused into the
// text (the leading integers below) and many original lines are missing —
// including the `class VideoInputOp : ...` declaration itself (original
// lines 19-20) and most of the member declarations. Code is kept
// byte-identical; only comments are added. Restore from the upstream file
// before attempting to compile.
1 #ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_ 2 #define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_ 9 #include <caffe2/core/db.h> 10 #include <caffe2/core/logging.h> 11 #include <caffe2/operators/prefetch_op.h> 12 #include <caffe2/utils/math.h> 13 #include <caffe2/utils/thread_pool.h> 14 #include <caffe2/video/video_io.h> 18 template <
class Context>
// Fragment of the class body (public section). The enclosing class
// declaration is missing from this extraction.
21 using OperatorBase::OutputSize;
// Prefetch-operator interface: Prefetch() fills host-side staging tensors;
// CopyPrefetched() moves them into the operator outputs. Definitions
// appear later in this file.
30 bool Prefetch()
override;
31 bool CopyPrefetched()
override;
// Validates the constructor-parsed configuration and logs it (see the
// definition below).
34 void CheckParamsAndPrint();
// Decodes one serialized DB value into raw RGB frame buffers plus label and
// video-id outputs. Signature only partially visible here (several
// parameter lines are missing from the extraction).
36 bool GetClipsAndLabelsFromDBValue(
37 const std::string& value,
40 std::vector<unsigned char*>& buffer_rgb,
// Per-item worker run on the decode thread pool: decode, crop, mirror and
// normalize one video's clips (see definition below).
44 void DecodeAndTransform(
45 const std::string& value,
50 std::mt19937* randgen,
51 std::bernoulli_distribution* mirror_this_clip);
// Per-channel normalization constants. The "inv_std" vectors are converted
// to reciprocals in the constructor (`inv_std_*[i] = 1.f / inv_std_*[i]`),
// and logging prints `1.f / inv_std_*[i]` to recover the std.
65 std::vector<float> mean_rgb_;
66 std::vector<float> inv_std_rgb_;
67 std::vector<float> mean_of_;
68 std::vector<float> inv_std_of_;
78 int sampling_rate_rgb_;
// Color augmentation parameters; the eigvecs/eigvals pair is filled with
// fixed constants in the constructor — presumably PCA color-lighting
// ("AlexNet-style") augmentation, TODO confirm against the transform code.
80 float img_saturation_;
81 float img_brightness_;
84 float color_lighting_std_;
85 std::vector<std::vector<float>> color_lighting_eigvecs_;
86 std::vector<float> color_lighting_eigvals_;
87 int num_of_required_frame_;
89 int sampling_rate_of_;
// Number of spatial crops taken per clip: 1 normally, 14 when multi-crop is
// enabled (set in the constructor).
96 int multi_crop_count_;
101 bool do_flow_aggregation_;
103 bool get_optical_flow_;
105 bool do_multi_label_;
// Decoding work is fanned out across this many pool threads (see
// Prefetch()).
108 int num_decode_threads_;
109 std::shared_ptr<TaskThreadPool> thread_pool_;
112 template <
class Context>
// VideoInputOp<Context>::CheckParamsAndPrint() — the signature line
// (original ~113-114) is missing from this extraction. Validates the
// configuration parsed by the constructor, then logs the effective
// settings. Several CAFFE_ENFORCE_* call heads below are also missing;
// only their argument lines survive (e.g. original 116, 121, 128-129, 133,
// 150, 152, 154, 160, 163-168, 173-174, 180-182).
115 CAFFE_ENFORCE_GT(batch_size_, 0,
"Batch size should be positive.");
117 clip_per_video_, 0,
"Number of clips per video should be positive.");
118 CAFFE_ENFORCE_GT(crop_height_, 0,
"Must provide the cropping height value.");
119 CAFFE_ENFORCE_GT(crop_width_, 0,
"Must provide the cropping width value.");
122 num_of_required_frame_, 0,
"Required number of frames must be positive.");
// Resolution-type-specific checks: minimal-edge resize keeps aspect ratio;
// width/height resize ignores it. Both must leave room for the crop.
124 if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
125 CAFFE_ENFORCE_GT(height_min_, 0,
"Must provide the minimal height value.");
126 CAFFE_ENFORCE_GT(width_min_, 0,
"Must provide the minimal width value.");
130 "The minimal height must be no smaller than the cropping height.");
134 "The minimal width must be no smaller than the cropping width.");
135 }
else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
136 CAFFE_ENFORCE_GT(scale_h_, 0,
"Must provide the scale height value.");
137 CAFFE_ENFORCE_GT(scale_w_, 0,
"Must provide the scale width value.");
141 "The scaled height must be no smaller than the cropping height.");
145 "The scaled width must be no smaller than the cropping width.");
// RGB clip configuration: length, sampling rate, and per-channel mean/std
// vector sizes must agree with channels_rgb_.
149 CAFFE_ENFORCE_GT(length_rgb_, 0,
"Must provide rgb clip length.");
151 sampling_rate_rgb_, 0,
"4 frames for mc2; 2 frames for res3d.");
153 channels_rgb_, mean_rgb_.size(),
"Number rgb channels is wrong!");
155 channels_rgb_, inv_std_rgb_.size(),
"Number rgb channels is wrong!");
// Optical-flow configuration is only validated when flow output is
// requested.
158 if (get_optical_flow_) {
159 CAFFE_ENFORCE_GT(length_of_, 0,
"Must provide optical flow clip length.");
161 sampling_rate_of_, 0,
"4 frames for mc2; 2 frames for res3d.");
165 "Number of optical flow channels is wrong!");
169 "Number of optical flow channels is wrong!");
// Sampling several clips from one video only works with uniform sampling.
172 if (clip_per_video_ > 1) {
175 DecodeType::DO_UNIFORM_SMP,
176 "Only uniformly sampling is supported when sampling multiple clips!");
// Multi-label mode needs a known class count to size the one-hot output.
179 if (do_multi_label_) {
183 "Number of classes must be set when using multiple labels.");
// Log the effective configuration (informational only from here on).
187 LOG(INFO) <<
"Creating a clip input op with the following setting: ";
188 LOG(INFO) <<
" Using " << num_decode_threads_ <<
" CPU threads;";
189 LOG(INFO) <<
" Outputting in batches of " << batch_size_ <<
" videos;";
190 LOG(INFO) <<
" Each video has " << clip_per_video_ <<
" clips;";
191 LOG(INFO) <<
" Scaling image to " << scale_h_ <<
"x" << scale_w_;
192 LOG(INFO) <<
" (Height, Width) is at least (" << height_min_ <<
", " 193 << width_min_ <<
")";
194 LOG(INFO) <<
" Cropping video frame to " << crop_height_ <<
"x" 195 << crop_width_ << (random_mirror_ ?
" with " :
" without ")
196 <<
"random mirroring;";
197 LOG(INFO) <<
" Using " << (random_crop_ ?
"random" :
"center") <<
" crop";
198 LOG(INFO) <<
" Is multi-cropping enabled: " << multi_crop_;
201 LOG(INFO) <<
" Using a clip of " << length_rgb_ <<
" rgb frames " 202 <<
"with " << channels_rgb_ <<
" channels " 203 <<
"and a sampling rate of 1:" << sampling_rate_rgb_;
204 LOG(INFO) <<
" RGB data augmentation. Color jittering: " << color_jitter_
205 <<
". Color lighting: " << color_lighting_;
206 for (
int i = 0; i < channels_rgb_; i++) {
// inv_std_rgb_ holds reciprocals, so 1.f / inv_std recovers the std.
207 LOG(INFO) <<
" RGB " << i <<
"-th channel mean: " << mean_rgb_[i]
208 <<
" std: " << 1.f / inv_std_rgb_[i];
212 if (get_optical_flow_) {
213 LOG(INFO) <<
" Using a clip of " << length_of_ <<
" optical flow frames " 214 <<
"with " << channels_of_ <<
" channels " 215 <<
"and a sampling rate of 1:" << sampling_rate_of_
216 <<
" flow_data_type_: " << flow_data_type_
217 <<
" flow_alg_type_: " << flow_alg_type_;
218 for (
int i = 0; i < channels_of_; i++) {
219 LOG(INFO) <<
" Optical flow" << i
220 <<
"-th channel mean: " << mean_of_[i]
221 <<
" std: " << 1.f / inv_std_of_[i];
225 if (video_res_type_ == VideoResType::ORIGINAL_RES) {
226 LOG(INFO) <<
" Use original resolution";
227 }
else if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
228 LOG(INFO) <<
" Resize with minimal size and keep aspect ratio";
229 }
else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
230 LOG(INFO) <<
" Resize and ignore aspect ratio";
232 LOG(ERROR) <<
" Unknown video resolution type";
235 if (decode_type_ == DecodeType::DO_TMP_JITTER) {
236 LOG(INFO) <<
" Do temporal jittering";
237 }
else if (decode_type_ == DecodeType::USE_START_FRM) {
238 LOG(INFO) <<
" Use start_frm for decoding";
239 }
else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
240 LOG(INFO) <<
" Do uniformly sampling";
242 LOG(ERROR) <<
" Unknown video decoding type";
246 template <
class Context>
// VideoInputOp<Context>::VideoInputOp(operator_def, ws) — constructor.
// The constructor head, base-class initializer, and many member-initializer
// heads (e.g. `batch_size_(`, `clip_per_video_(`, `channels_rgb_(`, ...)
// are missing from this extraction; only their GetSingleArgument /
// GetRepeatedArgument argument lines survive. Each initializer reads an
// operator argument with the default shown in the second parameter.
248 const OperatorDef& operator_def,
253 OperatorBase::template GetSingleArgument<int>(
"batch_size", 0)),
255 OperatorBase::template GetSingleArgument<int>(
"clip_per_video", 1)),
// Per-channel mean/std: the repeated "..._per_channel" argument wins; a
// scalar "mean_rgb"/"std_rgb"/... argument is the single-element fallback.
256 mean_rgb_(OperatorBase::template GetRepeatedArgument<float>(
257 "mean_rgb_per_channel",
258 {OperatorBase::template GetSingleArgument<float>(
"mean_rgb", 128.)})),
259 inv_std_rgb_(OperatorBase::template GetRepeatedArgument<float>(
260 "std_rgb_per_channel",
261 {OperatorBase::template GetSingleArgument<float>(
"std_rgb", 1.)})),
262 mean_of_(OperatorBase::template GetRepeatedArgument<float>(
263 "mean_of_per_channel",
264 {OperatorBase::template GetSingleArgument<float>(
"mean_of", 0.)})),
265 inv_std_of_(OperatorBase::template GetRepeatedArgument<float>(
266 "std_of_per_channel",
267 {OperatorBase::template GetSingleArgument<float>(
"std_of", 1.)})),
269 OperatorBase::template GetSingleArgument<int>(
"channels_rgb", 3)),
271 OperatorBase::template GetSingleArgument<int>(
"channels_of", 2)),
// Crop height/width fall back to a square "crop_size" argument; likewise
// height_min/width_min fall back to "short_edge".
272 crop_height_(OperatorBase::template GetSingleArgument<int>(
274 {OperatorBase::template GetSingleArgument<int>(
"crop_size", 0.)})),
275 crop_width_(OperatorBase::template GetSingleArgument<int>(
277 {OperatorBase::template GetSingleArgument<int>(
"crop_size", 0.)})),
278 scale_h_(OperatorBase::template GetSingleArgument<int>(
"scale_h", 0)),
279 scale_w_(OperatorBase::template GetSingleArgument<int>(
"scale_w", 0)),
280 height_min_(OperatorBase::template GetSingleArgument<int>(
282 {OperatorBase::template GetSingleArgument<int>(
"short_edge", 0)})),
283 width_min_(OperatorBase::template GetSingleArgument<int>(
285 {OperatorBase::template GetSingleArgument<int>(
"short_edge", 0)})),
287 OperatorBase::template GetSingleArgument<int>(
"length_rgb", 0)),
288 sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
291 color_jitter_(OperatorBase::template GetSingleArgument<bool>(
294 img_saturation_(OperatorBase::template GetSingleArgument<float>(
297 img_brightness_(OperatorBase::template GetSingleArgument<float>(
301 OperatorBase::template GetSingleArgument<float>(
"img_contrast", 0.4)),
302 color_lighting_(OperatorBase::template GetSingleArgument<bool>(
305 color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
306 "color_lighting_std",
308 length_of_(OperatorBase::template GetSingleArgument<int>(
"length_of", 0)),
310 OperatorBase::template GetSingleArgument<int>(
"sampling_rate_of", 1)),
312 OperatorBase::template GetSingleArgument<int>(
"frame_gap_of", 1)),
313 random_mirror_(OperatorBase::template GetSingleArgument<bool>(
317 OperatorBase::template GetSingleArgument<int>(
"num_of_class", 0)),
318 use_local_file_(OperatorBase::template GetSingleArgument<bool>(
322 OperatorBase::template GetSingleArgument<bool>(
"random_crop",
true)),
324 OperatorBase::template GetSingleArgument<bool>(
"multi_crop",
false)),
326 OperatorBase::template GetSingleArgument<int>(
"flow_data_type", 0)),
328 OperatorBase::template GetSingleArgument<int>(
"flow_alg_type", 0)),
330 OperatorBase::template GetSingleArgument<int>(
"decode_type", 0)),
332 OperatorBase::template GetSingleArgument<int>(
"video_res_type", 0)),
333 do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
334 "do_flow_aggregation",
336 get_rgb_(OperatorBase::template GetSingleArgument<bool>(
"get_rgb",
true)),
337 get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
340 get_video_id_(OperatorBase::template GetSingleArgument<bool>(
343 do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
346 num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
347 "num_decode_threads",
// Decode pool is sized from num_decode_threads_; constructor body begins.
349 thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
// Fixed 3x3 eigenvector matrix and eigenvalues used for color-lighting
// augmentation — presumably PCA lighting constants; TODO confirm against
// the transform implementation.
351 color_lighting_eigvecs_.push_back(
352 std::vector<float>{-144.7125, 183.396, 102.2295});
353 color_lighting_eigvecs_.push_back(
354 std::vector<float>{-148.104, -1.1475, -207.57});
355 color_lighting_eigvecs_.push_back(
356 std::vector<float>{-148.818, -177.174, 107.1765});
358 color_lighting_eigvals_ = std::vector<float>{0.2175, 0.0188, 0.0045};
// 1 crop normally, 14 when multi-crop is on (the condition line between
// these two assignments is missing from the extraction).
361 multi_crop_count_ = 1;
366 multi_crop_count_ = 14;
// Minimum frames the decoder must produce: max over the RGB and (below)
// optical-flow requirements.
369 num_of_required_frame_ = 0;
384 num_of_required_frame_ = std::max(
385 num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);
392 "The mean and std. vectors for RGB must be of the same size.");
// A single scalar mean/std is broadcast to all 3 RGB channels; std is then
// stored as its reciprocal.
393 if (mean_rgb_.size() == 1) {
394 mean_rgb_.resize(3, mean_rgb_[0]);
395 inv_std_rgb_.resize(3, inv_std_rgb_[0]);
397 CAFFE_ENFORCE_EQ(mean_rgb_.size(), 3,
"RGB should have 3 channels");
398 for (
int i = 0; i < 3; ++i) {
399 inv_std_rgb_[i] = 1.f / inv_std_rgb_[i];
403 if (get_optical_flow_) {
405 num_of_required_frame_ = std::max(
406 num_of_required_frame_,
407 (length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);
412 "The mean and std. vectors for Optical Flow must be of the same size.");
// channels_of_ is derived from the flow data type (the per-case assignments
// are missing from the extraction).
414 switch (flow_data_type_) {
416 case FlowDataType::Flow2C:
420 case FlowDataType::Flow3C:
425 case FlowDataType::FlowWithGray:
430 case FlowDataType::FlowWithRGB:
434 LOG(ERROR) <<
"Unknown optical flow type " << flow_data_type_;
437 LOG(INFO) <<
"channels_of_: " << channels_of_;
438 if (mean_of_.size() == 1) {
439 mean_of_.resize(channels_of_, mean_of_[0]);
440 inv_std_of_.resize(channels_of_, inv_std_of_[0]);
442 for (
int i = 0; i < channels_of_; ++i) {
443 inv_std_of_[i] = 1.f / inv_std_of_[i];
447 CheckParamsAndPrint();
450 operator_def.input_size(), 0,
"Need to have a DBReader blob input");
// Pre-size the host-side staging tensors: data layout is
// (N * clips * crops, C, T, H, W); labels are scalar per clip unless
// multi-label (then one-hot of num_of_class_).
452 vector<TIndex> data_shape(5);
453 vector<TIndex> label_shape(2);
456 data_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
457 data_shape[1] = channels_rgb_;
458 data_shape[2] = length_rgb_;
459 data_shape[3] = crop_height_;
460 data_shape[4] = crop_width_;
461 prefetched_clip_rgb_.
Resize(data_shape);
464 data_shape[1] = channels_of_;
465 data_shape[2] = length_of_;
466 prefetched_clip_of_.
Resize(data_shape);
470 if (do_multi_label_) {
471 label_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
472 label_shape[1] = num_of_class_;
473 prefetched_label_.
Resize(label_shape);
476 vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
479 prefetched_video_id_.
Resize(
480 vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
483 template <
class Context>
// VideoInputOp<Context>::GetClipsAndLabelsFromDBValue — parses one
// serialized TensorProtos DB value into decoded RGB buffers, labels and
// (optionally) a video id. The signature head and several parameter lines
// are missing from this extraction, as is the tail of the function (the
// full DecodeMultipleClipsFromVideo argument list and the return).
485 const std::string& value,
488 std::vector<unsigned char*>& buffer_rgb,
490 int* video_id_data) {
// Protos are consumed positionally: video, label, then optional start
// frame and optional video id.
492 int curr_proto_idx = 0;
493 CAFFE_ENFORCE(protos.ParseFromString(value));
494 const TensorProto& video_proto = protos.protos(curr_proto_idx++);
495 const TensorProto& label_proto = protos.protos(curr_proto_idx++);
500 if (decode_type_ == DecodeType::USE_START_FRM) {
503 protos.protos_size(),
504 "No proto is found for starting frame");
505 const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
506 start_frm = start_frm_proto.int32_data(0);
510 curr_proto_idx, protos.protos_size(),
"No proto is found for video id");
511 const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
// The single video id is replicated to every clip/crop slot.
512 for (
int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
513 video_id_data[i] = video_id_proto.int64_data(0);
// Single-label: replicate the scalar label per clip/crop. Multi-label:
// zero the one-hot block (the memset head is on the missing original line
// 526) and set a 1 at each listed class index.
517 if (!do_multi_label_) {
518 for (
int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
519 label_data[i] = label_proto.int32_data(0);
527 sizeof(
int) * num_of_class_ * multi_crop_count_ * clip_per_video_);
528 for (
int i = 0; i < clip_per_video_; i++) {
529 for (
int j = 0; j < multi_crop_count_; ++j) {
530 for (
int k = 0; k < label_proto.int32_data_size(); k++) {
532 [(i * multi_crop_count_ + j) * num_of_class_ +
533 label_proto.int32_data(k)] = 1;
// When reading from a local file list, the DB value must carry a filename
// string rather than encoded bytes.
539 if (use_local_file_) {
541 video_proto.data_type(),
543 "Database with a file_list is expected to be string data");
// Decode parameters forwarded to the video decoder.
548 params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
549 params.video_res_type_ = video_res_type_;
550 params.crop_height_ = crop_height_;
551 params.crop_width_ = crop_width_;
552 params.height_min_ = height_min_;
553 params.width_min_ = width_min_;
554 params.scale_w_ = scale_w_;
555 params.scale_h_ = scale_h_;
556 params.decode_type_ = decode_type_;
557 params.num_of_required_frame_ = num_of_required_frame_;
// Source selection: either an in-memory encoded buffer (STRING/BYTE data
// read directly from the proto) or a filename for local-file mode.
559 char* video_buffer =
nullptr;
560 std::string video_filename;
561 int encoded_size = 0;
562 if (video_proto.data_type() == TensorProto::STRING) {
563 const string& encoded_video_str = video_proto.string_data(0);
564 if (!use_local_file_) {
565 encoded_size = encoded_video_str.size();
// const_cast: the decoder API takes char*; the buffer is treated as
// read-only input — presumably safe, TODO confirm decoder contract.
566 video_buffer =
const_cast<char*
>(encoded_video_str.data());
568 video_filename = encoded_video_str;
570 }
else if (video_proto.data_type() == TensorProto::BYTE) {
571 if (!use_local_file_) {
572 encoded_size = video_proto.byte_data().size();
573 video_buffer =
const_cast<char*
>(video_proto.byte_data().data());
576 video_filename = video_proto.string_data(0);
579 LOG(FATAL) <<
"Unknown video data type.";
// Decode call — its argument list and the function's return statement are
// missing from this extraction (original lines 583-596).
582 DecodeMultipleClipsFromVideo(
597 template <
class Context>
// VideoInputOp<Context>::DecodeAndTransform — per-item worker executed on
// the decode thread pool. Decodes one DB value, then crops / mirrors /
// normalizes each clip into the caller-provided output slices. Several
// parameter lines, the crop-mode branch heads, and the argument lists of
// the ClipTransform* calls are missing from this extraction.
599 const std::string& value,
600 float* clip_rgb_data,
604 std::mt19937* randgen,
605 std::bernoulli_distribution* mirror_this_clip) {
606 std::vector<unsigned char*> buffer_rgb;
611 CHECK(GetClipsAndLabelsFromDBValue(
612 value, height, width, buffer_rgb, label_data, video_id_data));
// Per-clip strides into the flat output buffers (elements, not bytes).
613 int clip_offset_rgb = multi_crop_count_ * channels_rgb_ * length_rgb_ *
614 crop_height_ * crop_width_;
615 int clip_crop_offset_of =
616 channels_of_ * length_of_ * crop_height_ * crop_width_;
617 int clip_offset_of = multi_crop_count_ * clip_crop_offset_of;
// buffer_rgb may hold fewer clips than requested; process the min.
618 for (
int i = 0; i < std::min(clip_per_video_,
int(buffer_rgb.size())); i++) {
// Crop origin: random within the frame, or centered (the branch heads
// selecting between these are on missing lines).
625 std::uniform_int_distribution<>(0, height - crop_height_)(*randgen);
626 w_off = std::uniform_int_distribution<>(0, width - crop_width_)(*randgen);
629 h_off = (height - crop_height_) / 2;
630 w_off = (width - crop_width_) / 2;
// Multi-crop offset tables (7 spatial positions; mirrored variants double
// this to the 14 crops set in the constructor). Some entries are on
// missing lines.
637 int multi_crop_w_off[7] = {0,
638 (width - crop_width_) / 2,
640 (width - crop_width_) / 2,
642 (width - crop_width_) / 2,
643 width - crop_width_};
644 int multi_crop_h_off[7] = {0,
647 (height - crop_height_) / 2,
648 height - crop_height_,
649 height - crop_height_,
650 height - crop_height_};
653 bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
// RGB path: the transform call head and most arguments (original 655-680)
// are missing; only the lighting constants and output pointer survive.
654 if (get_rgb_ && clip_rgb_data) {
676 color_lighting_eigvecs_,
677 color_lighting_eigvals_,
681 clip_rgb_data + (i * clip_offset_rgb));
683 if (get_optical_flow_ && clip_of_data) {
685 for (
int j = 0; j < multi_crop_count_; ++j) {
686 if (multi_crop_count_ == 1) {
687 rect = cv::Rect(w_off, h_off, crop_width_, crop_height_);
// Multi-crop: first half of j uses unmirrored crops, second half mirrored;
// k indexes into the offset tables (rect assignment lines are missing).
689 mirror_me = j / (multi_crop_count_ / 2);
690 int k = j % (multi_crop_count_ / 2);
697 ClipTransformOpticalFlow(
712 do_flow_aggregation_,
715 clip_of_data + (i * clip_offset_of) + j * clip_crop_offset_of);
// Release the decoded frame buffers — the actual deallocation statement
// (presumably a delete[] of buff) is on missing lines after 722; confirm
// against upstream before relying on this for leak analysis.
720 if (buffer_rgb.size() > 0) {
721 for (
int i = 0; i < buffer_rgb.size(); i++) {
722 unsigned char* buff = buffer_rgb[i];
729 template <
class Context>
// VideoInputOp<Context>::Prefetch() — fills the host-side staging tensors
// for one batch. Reads serialized items from the DBReader input, fans the
// decode work out over thread_pool_, waits for completion, then (for
// non-CPU contexts) copies staging tensors to device-side twins. The
// signature line and the runTask bind-argument list are missing from this
// extraction.
733 reader_ = &OperatorBase::Input<db::DBReader>(0);
// One RNG per pool thread, seeded from a time-seeded meta generator so
// workers don't share generator state.
742 std::mt19937 meta_randgen(time(
nullptr));
743 std::vector<std::mt19937> randgen_per_thread;
744 for (
int i = 0; i < num_decode_threads_; ++i) {
745 randgen_per_thread.emplace_back(meta_randgen());
748 std::bernoulli_distribution mirror_this_clip(0.5);
749 for (
int item_id = 0; item_id < batch_size_; ++item_id) {
// Items are assigned round-robin to the per-thread RNGs; this keys RNG
// choice off item_id, matching how tasks are dispatched below.
750 std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];
// Compute this item's output slices inside the flat staging tensors.
752 int frame_size = crop_height_ * crop_width_;
754 float* clip_rgb_data = prefetched_clip_rgb_.
mutable_data<
float>() +
755 frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
759 float* clip_of_data = prefetched_clip_of_.
mutable_data<
float>() +
760 frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
764 int* label_data = prefetched_label_.
mutable_data<
int>() +
765 (do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
769 int* video_id_data = prefetched_video_id_.
mutable_data<
int>() +
770 item_id * clip_per_video_ * multi_crop_count_;
772 std::string key, value;
774 reader_->
Read(&key, &value);
// Dispatch DecodeAndTransform for this item; the bind argument list
// (original 777-786) is missing from the extraction.
776 thread_pool_->runTask(std::bind(
787 thread_pool_->waitWorkComplete();
// For GPU (or other non-CPU) contexts, stage the CPU tensors onto the
// device so CopyPrefetched can serve them without a host round trip.
791 if (!std::is_same<Context, CPUContext>::value) {
793 prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
795 if (get_optical_flow_) {
796 prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
798 prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
800 prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
806 template <
class Context>
// VideoInputOp<Context>::CopyPrefetched() — copies the prefetched tensors
// into the operator outputs in fixed order: RGB clip, optional optical
// flow, label, optional video id. For CPUContext the host staging tensor
// is the source; otherwise the device-side twin filled by Prefetch() is.
// The signature line, the `int index = 0;` line, and the `} else {` lines
// are missing from this extraction.
810 auto* clip_rgb_output = OperatorBase::Output<Tensor<Context>>(index++);
811 if (std::is_same<Context, CPUContext>::value) {
812 clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
814 clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
817 if (get_optical_flow_) {
818 auto* clip_of_output = OperatorBase::Output<Tensor<Context>>(index++);
819 if (std::is_same<Context, CPUContext>::value) {
820 clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
822 clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
825 auto* label_output = OperatorBase::Output<Tensor<Context>>(index++);
826 if (std::is_same<Context, CPUContext>::value) {
827 label_output->CopyFrom(prefetched_label_, &context_);
829 label_output->CopyFrom(prefetched_label_on_device_, &context_);
// video id is the last output; index is not incremented past it.
832 auto* video_id_output = OperatorBase::Output<Tensor<Context>>(index);
833 if (std::is_same<Context, CPUContext>::value) {
834 video_id_output->CopyFrom(prefetched_video_id_, &context_);
836 video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
844 #endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_ void Read(string *key, string *value) const
// NOTE(review): the lines below are Doxygen tooltip residue captured along
// with the source during extraction; kept as comments so they do not read
// as code.
// Read(string* key, string* value) const: read a set of key and value from
// the db and move to next.
// DBReader: a reader wrapper for DB that also allows us to serialize it.
// CPUContext: the CPU Context, representing the bare minimum of what a
// Context class in Caffe2 should implement.
// T* mutable_data(): returns a typed pointer of the underlying storage.
// Workspace: a class that holds all the related objects created during
// runtime: (1) all blobs...
// void Resize(Ts... dim_source): resizes a tensor.
// (A global dictionary that holds information about what Caffe2 modules
// have been loaded in the current runtime.)