Caffe2 - C++ API
A deep learning, cross-platform ML framework
video_input_op.h
#ifndef CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
#define CAFFE2_VIDEO_VIDEO_INPUT_OP_H_

#include <istream>
#include <ostream>
#include <random>
#include <string>

#include <caffe2/core/db.h>
#include <caffe2/core/logging.h>
#include <caffe2/operators/prefetch_op.h>
#include <caffe2/utils/math.h>
#include <caffe2/utils/thread_pool.h>
#include <caffe2/video/video_io.h>

namespace caffe2 {

template <class Context>
class VideoInputOp final : public PrefetchOperator<Context> {
 public:
  using OperatorBase::OutputSize;
  using PrefetchOperator<Context>::context_;
  using PrefetchOperator<Context>::prefetch_thread_;
  explicit VideoInputOp(const OperatorDef& operator_def, Workspace* ws);
  ~VideoInputOp() {
    PrefetchOperator<Context>::Finalize();
  }

  // override methods
  bool Prefetch() override;
  bool CopyPrefetched() override;

 private:
  void CheckParamsAndPrint();

  bool GetClipsAndLabelsFromDBValue(
      const std::string& value,
      int& height,
      int& width,
      std::vector<unsigned char*>& buffer_rgb,
      int* label_data,
      int* video_id_data);

  void DecodeAndTransform(
      const std::string& value,
      float* clip_rgb_data,
      float* clip_of_data,
      int* label_data,
      int* video_id_data,
      std::mt19937* randgen,
      std::bernoulli_distribution* mirror_this_clip);

  const db::DBReader* reader_;
  CPUContext cpu_context_;
  TensorCPU prefetched_clip_rgb_;
  TensorCPU prefetched_clip_of_;
  TensorCPU prefetched_label_;
  TensorCPU prefetched_video_id_;
  Tensor<Context> prefetched_clip_rgb_on_device_;
  Tensor<Context> prefetched_clip_of_on_device_;
  Tensor<Context> prefetched_label_on_device_;
  Tensor<Context> prefetched_video_id_on_device_;
  int batch_size_;
  int clip_per_video_;
  std::vector<float> mean_rgb_;
  std::vector<float> inv_std_rgb_;
  std::vector<float> mean_of_;
  std::vector<float> inv_std_of_;
  int channels_rgb_;
  int channels_of_;
  int crop_height_;
  int crop_width_;
  int scale_h_;
  int scale_w_;
  int height_min_;
  int width_min_;
  int length_rgb_;
  int sampling_rate_rgb_;
  bool color_jitter_;
  float img_saturation_;
  float img_brightness_;
  float img_contrast_;
  bool color_lighting_;
  float color_lighting_std_;
  std::vector<std::vector<float>> color_lighting_eigvecs_;
  std::vector<float> color_lighting_eigvals_;
  int num_of_required_frame_;
  int length_of_;
  int sampling_rate_of_;
  int frame_gap_of_;
  bool random_mirror_;
  int num_of_class_;
  bool use_local_file_;
  bool random_crop_;
  bool multi_crop_;
  int multi_crop_count_;
  int flow_data_type_;
  int flow_alg_type_;
  int decode_type_;
  int video_res_type_;
  bool do_flow_aggregation_;
  bool get_rgb_;
  bool get_optical_flow_;
  bool get_video_id_;
  bool do_multi_label_;

  // thread pool for parse + decode
  int num_decode_threads_;
  std::shared_ptr<TaskThreadPool> thread_pool_;
};

template <class Context>
void VideoInputOp<Context>::CheckParamsAndPrint() {
  // check whether the input parameters are valid or not
  CAFFE_ENFORCE_GT(batch_size_, 0, "Batch size should be positive.");
  CAFFE_ENFORCE_GT(
      clip_per_video_, 0, "Number of clips per video should be positive.");
  CAFFE_ENFORCE_GT(crop_height_, 0, "Must provide the cropping height value.");
  CAFFE_ENFORCE_GT(crop_width_, 0, "Must provide the cropping width value.");

  CAFFE_ENFORCE_GT(
      num_of_required_frame_, 0, "Required number of frames must be positive.");

  if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(height_min_, 0, "Must provide the minimal height value.");
    CAFFE_ENFORCE_GT(width_min_, 0, "Must provide the minimal width value.");
    CAFFE_ENFORCE_GE(
        height_min_,
        crop_height_,
        "The minimal height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        width_min_,
        crop_width_,
        "The minimal width must be no smaller than the cropping width.");
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    CAFFE_ENFORCE_GT(scale_h_, 0, "Must provide the scale height value.");
    CAFFE_ENFORCE_GT(scale_w_, 0, "Must provide the scale width value.");
    CAFFE_ENFORCE_GE(
        scale_h_,
        crop_height_,
        "The scaled height must be no smaller than the cropping height.");
    CAFFE_ENFORCE_GE(
        scale_w_,
        crop_width_,
        "The scaled width must be no smaller than the cropping width.");
  }

  if (get_rgb_) {
    CAFFE_ENFORCE_GT(length_rgb_, 0, "Must provide rgb clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_rgb_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, mean_rgb_.size(), "Number of rgb channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_rgb_, inv_std_rgb_.size(), "Number of rgb channels is wrong!");
  }

  if (get_optical_flow_) {
    CAFFE_ENFORCE_GT(length_of_, 0, "Must provide optical flow clip length.");
    CAFFE_ENFORCE_GT(
        sampling_rate_of_, 0, "4 frames for mc2; 2 frames for res3d.");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        mean_of_.size(),
        "Number of optical flow channels is wrong!");
    CAFFE_ENFORCE_EQ(
        channels_of_,
        inv_std_of_.size(),
        "Number of optical flow channels is wrong!");
  }

  if (clip_per_video_ > 1) {
    CAFFE_ENFORCE_EQ(
        decode_type_,
        DecodeType::DO_UNIFORM_SMP,
        "Only uniform sampling is supported when sampling multiple clips!");
  }

  if (do_multi_label_) {
    CAFFE_ENFORCE_GT(
        num_of_class_,
        0,
        "Number of classes must be set when using multiple labels.");
  }

  // print out the parameter settings
  LOG(INFO) << "Creating a clip input op with the following settings: ";
  LOG(INFO) << "    Using " << num_decode_threads_ << " CPU threads;";
  LOG(INFO) << "    Outputting in batches of " << batch_size_ << " videos;";
  LOG(INFO) << "    Each video has " << clip_per_video_ << " clips;";
  LOG(INFO) << "    Scaling image to " << scale_h_ << "x" << scale_w_;
  LOG(INFO) << "    (Height, Width) is at least (" << height_min_ << ", "
            << width_min_ << ")";
  LOG(INFO) << "    Cropping video frame to " << crop_height_ << "x"
            << crop_width_ << (random_mirror_ ? " with " : " without ")
            << "random mirroring;";
  LOG(INFO) << "    Using " << (random_crop_ ? "random" : "center") << " crop";
  LOG(INFO) << "    Is multi-cropping enabled: " << multi_crop_;

  if (get_rgb_) {
    LOG(INFO) << "    Using a clip of " << length_rgb_ << " rgb frames "
              << "with " << channels_rgb_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_rgb_;
    LOG(INFO) << "    RGB data augmentation. Color jittering: " << color_jitter_
              << ". Color lighting: " << color_lighting_;
    for (int i = 0; i < channels_rgb_; i++) {
      LOG(INFO) << "    RGB " << i << "-th channel mean: " << mean_rgb_[i]
                << " std: " << 1.f / inv_std_rgb_[i];
    }
  }

  if (get_optical_flow_) {
    LOG(INFO) << "    Using a clip of " << length_of_ << " optical flow frames "
              << "with " << channels_of_ << " channels "
              << "and a sampling rate of 1:" << sampling_rate_of_
              << " flow_data_type_: " << flow_data_type_
              << " flow_alg_type_: " << flow_alg_type_;
    for (int i = 0; i < channels_of_; i++) {
      LOG(INFO) << "    Optical flow " << i
                << "-th channel mean: " << mean_of_[i]
                << " std: " << 1.f / inv_std_of_[i];
    }
  }

  if (video_res_type_ == VideoResType::ORIGINAL_RES) {
    LOG(INFO) << "    Use original resolution";
  } else if (video_res_type_ == VideoResType::USE_MINIMAL_WIDTH_HEIGHT) {
    LOG(INFO) << "    Resize with minimal size and keep aspect ratio";
  } else if (video_res_type_ == VideoResType::USE_WIDTH_HEIGHT) {
    LOG(INFO) << "    Resize and ignore aspect ratio";
  } else {
    LOG(ERROR) << "    Unknown video resolution type";
  }

  if (decode_type_ == DecodeType::DO_TMP_JITTER) {
    LOG(INFO) << "    Do temporal jittering";
  } else if (decode_type_ == DecodeType::USE_START_FRM) {
    LOG(INFO) << "    Use start_frm for decoding";
  } else if (decode_type_ == DecodeType::DO_UNIFORM_SMP) {
    LOG(INFO) << "    Do uniform sampling";
  } else {
    LOG(ERROR) << "    Unknown video decoding type";
  }
}

template <class Context>
VideoInputOp<Context>::VideoInputOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : PrefetchOperator<Context>(operator_def, ws),
      reader_(nullptr),
      batch_size_(
          OperatorBase::template GetSingleArgument<int>("batch_size", 0)),
      clip_per_video_(
          OperatorBase::template GetSingleArgument<int>("clip_per_video", 1)),
      mean_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_rgb", 128.)})),
      inv_std_rgb_(OperatorBase::template GetRepeatedArgument<float>(
          "std_rgb_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_rgb", 1.)})),
      mean_of_(OperatorBase::template GetRepeatedArgument<float>(
          "mean_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("mean_of", 0.)})),
      inv_std_of_(OperatorBase::template GetRepeatedArgument<float>(
          "std_of_per_channel",
          {OperatorBase::template GetSingleArgument<float>("std_of", 1.)})),
      channels_rgb_(
          OperatorBase::template GetSingleArgument<int>("channels_rgb", 3)),
      channels_of_(
          OperatorBase::template GetSingleArgument<int>("channels_of", 2)),
      crop_height_(OperatorBase::template GetSingleArgument<int>(
          "crop_height",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      crop_width_(OperatorBase::template GetSingleArgument<int>(
          "crop_width",
          {OperatorBase::template GetSingleArgument<int>("crop_size", 0)})),
      scale_h_(OperatorBase::template GetSingleArgument<int>("scale_h", 0)),
      scale_w_(OperatorBase::template GetSingleArgument<int>("scale_w", 0)),
      height_min_(OperatorBase::template GetSingleArgument<int>(
          "height_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      width_min_(OperatorBase::template GetSingleArgument<int>(
          "width_min",
          {OperatorBase::template GetSingleArgument<int>("short_edge", 0)})),
      length_rgb_(
          OperatorBase::template GetSingleArgument<int>("length_rgb", 0)),
      sampling_rate_rgb_(OperatorBase::template GetSingleArgument<int>(
          "sampling_rate_rgb",
          1)),
      color_jitter_(OperatorBase::template GetSingleArgument<bool>(
          "color_jitter",
          false)),
      img_saturation_(OperatorBase::template GetSingleArgument<float>(
          "img_saturation",
          0.4)),
      img_brightness_(OperatorBase::template GetSingleArgument<float>(
          "img_brightness",
          0.4)),
      img_contrast_(
          OperatorBase::template GetSingleArgument<float>("img_contrast", 0.4)),
      color_lighting_(OperatorBase::template GetSingleArgument<bool>(
          "color_lighting",
          false)),
      color_lighting_std_(OperatorBase::template GetSingleArgument<float>(
          "color_lighting_std",
          0.1)),
      length_of_(OperatorBase::template GetSingleArgument<int>("length_of", 0)),
      sampling_rate_of_(
          OperatorBase::template GetSingleArgument<int>("sampling_rate_of", 1)),
      frame_gap_of_(
          OperatorBase::template GetSingleArgument<int>("frame_gap_of", 1)),
      random_mirror_(OperatorBase::template GetSingleArgument<bool>(
          "random_mirror",
          true)),
      num_of_class_(
          OperatorBase::template GetSingleArgument<int>("num_of_class", 0)),
      use_local_file_(OperatorBase::template GetSingleArgument<bool>(
          "use_local_file",
          false)),
      random_crop_(
          OperatorBase::template GetSingleArgument<bool>("random_crop", true)),
      multi_crop_(
          OperatorBase::template GetSingleArgument<bool>("multi_crop", false)),
      flow_data_type_(
          OperatorBase::template GetSingleArgument<int>("flow_data_type", 0)),
      flow_alg_type_(
          OperatorBase::template GetSingleArgument<int>("flow_alg_type", 0)),
      decode_type_(
          OperatorBase::template GetSingleArgument<int>("decode_type", 0)),
      video_res_type_(
          OperatorBase::template GetSingleArgument<int>("video_res_type", 0)),
      do_flow_aggregation_(OperatorBase::template GetSingleArgument<bool>(
          "do_flow_aggregation",
          true)),
      get_rgb_(OperatorBase::template GetSingleArgument<bool>("get_rgb", true)),
      get_optical_flow_(OperatorBase::template GetSingleArgument<bool>(
          "get_optical_flow",
          false)),
      get_video_id_(OperatorBase::template GetSingleArgument<bool>(
          "get_video_id",
          false)),
      do_multi_label_(OperatorBase::template GetSingleArgument<bool>(
          "do_multi_label",
          false)),
      num_decode_threads_(OperatorBase::template GetSingleArgument<int>(
          "num_decode_threads",
          4)),
      thread_pool_(std::make_shared<TaskThreadPool>(num_decode_threads_)) {
  // hard-coded PCA eigenvectors and eigenvalues, based on RGB channel order
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-144.7125, 183.396, 102.2295});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.104, -1.1475, -207.57});
  color_lighting_eigvecs_.push_back(
      std::vector<float>{-148.818, -177.174, 107.1765});

  color_lighting_eigvals_ = std::vector<float>{0.2175, 0.0188, 0.0045};

  // multi-cropping for testing
  multi_crop_count_ = 1;
  if (multi_crop_) {
    // we take left-top, central-top, right-top, left-bottom, central-bottom,
    // right-bottom and central-central croppings as well as their mirrorings;
    // in total, 14 croppings
    multi_crop_count_ = 14;
  }
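  // (7 spatial positions x 2 mirror states = 14; DecodeAndTransform below
  // shows how the offsets and the mirror flag are applied per crop)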

  num_of_required_frame_ = 0;

  // mean and std for normalizing different optical flow data types;
  // example statistics generated from SOA are shown below, and you may
  // want to change them if you are running on a different dataset

  // 7 channels: (flow_x, flow_y, flow_magnitude, gray, Red, Green, Blue)
  // const std::vector<float> InputDataMean =
  //     {0.0046635, 0.0046261, 0.963986, 102.976, 110.201, 100.64, 95.9966};
  // const std::vector<float> InputDataStd =
  //     {0.972347, 0.755146, 1.43588, 55.3691, 58.1489, 56.4701, 55.3324};

  // if we need RGB as an input
  if (get_rgb_) {
    // how many frames we need for RGB
    num_of_required_frame_ = std::max(
        num_of_required_frame_, (length_rgb_ - 1) * sampling_rate_rgb_ + 1);
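    // e.g., an 8-frame clip at a 1:2 sampling rate spans
    // (8 - 1) * 2 + 1 = 15 consecutive decoded frames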

    channels_rgb_ = 3;

    CAFFE_ENFORCE_EQ(
        mean_rgb_.size(),
        inv_std_rgb_.size(),
        "The mean and std. vectors for RGB must be of the same size.");
    if (mean_rgb_.size() == 1) {
      mean_rgb_.resize(3, mean_rgb_[0]);
      inv_std_rgb_.resize(3, inv_std_rgb_[0]);
    }
    CAFFE_ENFORCE_EQ(mean_rgb_.size(), 3, "RGB should have 3 channels");
    for (int i = 0; i < 3; ++i) {
      inv_std_rgb_[i] = 1.f / inv_std_rgb_[i];
    }
  }
  // if we need optical flow as an input
  if (get_optical_flow_) {
    // how many frames we need for optical flow
    num_of_required_frame_ = std::max(
        num_of_required_frame_,
        (length_of_ - 1) * sampling_rate_of_ + frame_gap_of_ + 1);
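    // e.g., length_of_ = 8, sampling_rate_of_ = 1, frame_gap_of_ = 1 needs
    // (8 - 1) * 1 + 1 + 1 = 9 decoded frames; the extra frame_gap_of_ frames
    // account for each flow frame being computed from a pair of frames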

    CAFFE_ENFORCE_EQ(
        mean_of_.size(),
        inv_std_of_.size(),
        "The mean and std. vectors for Optical Flow must be of the same size.");
    // set the parameters for different input data types
    switch (flow_data_type_) {
      // (flow_x, flow_y)
      case FlowDataType::Flow2C:
        channels_of_ = 2;
        break;
      // (flow_x, flow_y, flow_magnitude)
      case FlowDataType::Flow3C:
        channels_of_ = 3;
        break;
      // early fusion with gray
      // (flow_x, flow_y, gray)
      case FlowDataType::FlowWithGray:
        channels_of_ = 3;
        break;
      // early fusion with RGB
      // (flow_x, flow_y, Red, Green, Blue)
      case FlowDataType::FlowWithRGB:
        channels_of_ = 5;
        break;
      default:
        LOG(ERROR) << "Unknown optical flow type " << flow_data_type_;
        break;
    }
    LOG(INFO) << "channels_of_: " << channels_of_;
    if (mean_of_.size() == 1) {
      mean_of_.resize(channels_of_, mean_of_[0]);
      inv_std_of_.resize(channels_of_, inv_std_of_[0]);
    }
    for (int i = 0; i < channels_of_; ++i) {
      inv_std_of_[i] = 1.f / inv_std_of_[i];
    }
  }

  CheckParamsAndPrint();
  // always need a dbreader, even when using local video files
  CAFFE_ENFORCE_GT(
      operator_def.input_size(), 0, "Need to have a DBReader blob input");

  vector<TIndex> data_shape(5);
  vector<TIndex> label_shape(2);

  // for RGB data
  data_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
  data_shape[1] = channels_rgb_;
  data_shape[2] = length_rgb_;
  data_shape[3] = crop_height_;
  data_shape[4] = crop_width_;
  prefetched_clip_rgb_.Resize(data_shape);
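  // e.g., batch_size_ = 4, clip_per_video_ = 2, multi_crop_count_ = 1 gives
  // an RGB blob of shape
  // (8, channels_rgb_, length_rgb_, crop_height_, crop_width_)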

  // for optical flow data
  data_shape[1] = channels_of_;
  data_shape[2] = length_of_;
  prefetched_clip_of_.Resize(data_shape);

  // if do_multi_label is used, the output label is a binary vector
  // of length num_of_class indicating which labels are present
  if (do_multi_label_) {
    label_shape[0] = batch_size_ * clip_per_video_ * multi_crop_count_;
    label_shape[1] = num_of_class_;
    prefetched_label_.Resize(label_shape);
  } else {
    prefetched_label_.Resize(
        vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
  }

  prefetched_video_id_.Resize(
      vector<TIndex>(1, batch_size_ * clip_per_video_ * multi_crop_count_));
}

template <class Context>
bool VideoInputOp<Context>::GetClipsAndLabelsFromDBValue(
    const std::string& value,
    int& height,
    int& width,
    std::vector<unsigned char*>& buffer_rgb,
    int* label_data,
    int* video_id_data) {
  TensorProtos protos;
  int curr_proto_idx = 0;
  CAFFE_ENFORCE(protos.ParseFromString(value));
  const TensorProto& video_proto = protos.protos(curr_proto_idx++);
  const TensorProto& label_proto = protos.protos(curr_proto_idx++);

  int start_frm = 0;
  // start_frm is only valid when sampling 1 clip per video without
  // temporal jittering
  if (decode_type_ == DecodeType::USE_START_FRM) {
    CAFFE_ENFORCE_LT(
        curr_proto_idx,
        protos.protos_size(),
        "No proto is found for starting frame");
    const TensorProto& start_frm_proto = protos.protos(curr_proto_idx++);
    start_frm = start_frm_proto.int32_data(0);
  }
  if (get_video_id_) {
    CAFFE_ENFORCE_LT(
        curr_proto_idx, protos.protos_size(), "No proto is found for video id");
    const TensorProto& video_id_proto = protos.protos(curr_proto_idx);
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      video_id_data[i] = video_id_proto.int64_data(0);
    }
  }
  // assign labels
  if (!do_multi_label_) {
    for (int i = 0; i < clip_per_video_ * multi_crop_count_; i++) {
      label_data[i] = label_proto.int32_data(0);
    }
  } else {
    // for the multiple-label case, the output label is a binary vector
    // where the present concepts are marked 1
    memset(
        label_data,
        0,
        sizeof(int) * num_of_class_ * multi_crop_count_ * clip_per_video_);
    for (int i = 0; i < clip_per_video_; i++) {
      for (int j = 0; j < multi_crop_count_; ++j) {
        for (int k = 0; k < label_proto.int32_data_size(); k++) {
          label_data
              [(i * multi_crop_count_ + j) * num_of_class_ +
               label_proto.int32_data(k)] = 1;
        }
      }
    }
  }
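  // e.g., with num_of_class_ = 4 and stored labels {0, 2}, each clip/crop
  // row of label_data becomes [1, 0, 1, 0]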

  if (use_local_file_) {
    CAFFE_ENFORCE_EQ(
        video_proto.data_type(),
        TensorProto::STRING,
        "Database with a file_list is expected to be string data");
  }

  // initializing the decoding params
  Params params;
  params.maximumOutputFrames_ = MAX_DECODING_FRAMES;
  params.video_res_type_ = video_res_type_;
  params.crop_height_ = crop_height_;
  params.crop_width_ = crop_width_;
  params.height_min_ = height_min_;
  params.width_min_ = width_min_;
  params.scale_w_ = scale_w_;
  params.scale_h_ = scale_h_;
  params.decode_type_ = decode_type_;
  params.num_of_required_frame_ = num_of_required_frame_;

  char* video_buffer = nullptr; // for decoding from buffer
  std::string video_filename; // for decoding from file
  int encoded_size = 0;
  if (video_proto.data_type() == TensorProto::STRING) {
    const string& encoded_video_str = video_proto.string_data(0);
    if (!use_local_file_) {
      encoded_size = encoded_video_str.size();
      video_buffer = const_cast<char*>(encoded_video_str.data());
    } else {
      video_filename = encoded_video_str;
    }
  } else if (video_proto.data_type() == TensorProto::BYTE) {
    if (!use_local_file_) {
      encoded_size = video_proto.byte_data().size();
      video_buffer = const_cast<char*>(video_proto.byte_data().data());
    } else {
      // TODO: does this work?
      video_filename = video_proto.string_data(0);
    }
  } else {
    LOG(FATAL) << "Unknown video data type.";
  }

  DecodeMultipleClipsFromVideo(
      video_buffer,
      video_filename,
      encoded_size,
      params,
      start_frm,
      clip_per_video_,
      use_local_file_,
      height,
      width,
      buffer_rgb);

  return true;
}

template <class Context>
void VideoInputOp<Context>::DecodeAndTransform(
    const std::string& value,
    float* clip_rgb_data,
    float* clip_of_data,
    int* label_data,
    int* video_id_data,
    std::mt19937* randgen,
    std::bernoulli_distribution* mirror_this_clip) {
  std::vector<unsigned char*> buffer_rgb;
  // get the video resolution after decoding
  int height = 0;
  int width = 0;
  // decode the video from memory or read from a local file
  CHECK(GetClipsAndLabelsFromDBValue(
      value, height, width, buffer_rgb, label_data, video_id_data));
  int clip_offset_rgb = multi_crop_count_ * channels_rgb_ * length_rgb_ *
      crop_height_ * crop_width_;
  int clip_crop_offset_of =
      channels_of_ * length_of_ * crop_height_ * crop_width_;
  int clip_offset_of = multi_crop_count_ * clip_crop_offset_of;
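  // each decoded clip writes multi_crop_count_ contiguous crop volumes of
  // size channels * length * crop_height_ * crop_width_ into the output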
  for (int i = 0; i < std::min(clip_per_video_, int(buffer_rgb.size())); i++) {
    // get the rectangle for cropping
    int h_off = 0;
    int w_off = 0;
    if (random_crop_) {
      // using random crop for training
      h_off =
          std::uniform_int_distribution<>(0, height - crop_height_)(*randgen);
      w_off = std::uniform_int_distribution<>(0, width - crop_width_)(*randgen);
    } else {
      // using center crop for testing
      h_off = (height - crop_height_) / 2;
      w_off = (width - crop_width_) / 2;
    }
    // cv::Rect rect(w_off, h_off, crop_width_, crop_height_);

    // multi-cropping: we take left-top, central-top, right-top, left-bottom,
    // central-bottom, right-bottom and central-central croppings as well as
    // their mirrorings; in total, 14 croppings
    int multi_crop_w_off[7] = {0,
                               (width - crop_width_) / 2,
                               width - crop_width_,
                               (width - crop_width_) / 2,
                               0,
                               (width - crop_width_) / 2,
                               width - crop_width_};
    int multi_crop_h_off[7] = {0,
                               0,
                               0,
                               (height - crop_height_) / 2,
                               height - crop_height_,
                               height - crop_height_,
                               height - crop_height_};

    // randomly mirror the image or not
    bool mirror_me = random_mirror_ && (*mirror_this_clip)(*randgen);
    if (get_rgb_ && clip_rgb_data) {
      ClipTransformRGB(
          buffer_rgb[i],
          multi_crop_count_,
          crop_height_,
          crop_width_,
          length_rgb_,
          channels_rgb_,
          sampling_rate_rgb_,
          height,
          width,
          h_off,
          w_off,
          multi_crop_h_off,
          multi_crop_w_off,
          mirror_me,
          color_jitter_,
          img_saturation_,
          img_brightness_,
          img_contrast_,
          color_lighting_,
          color_lighting_std_,
          color_lighting_eigvecs_,
          color_lighting_eigvals_,
          mean_rgb_,
          inv_std_rgb_,
          randgen,
          clip_rgb_data + (i * clip_offset_rgb));
    }
    if (get_optical_flow_ && clip_of_data) {
      cv::Rect rect;
      for (int j = 0; j < multi_crop_count_; ++j) {
        if (multi_crop_count_ == 1) {
          rect = cv::Rect(w_off, h_off, crop_width_, crop_height_);
        } else {
          mirror_me = j / (multi_crop_count_ / 2);
          int k = j % (multi_crop_count_ / 2);
          rect = cv::Rect(
              multi_crop_w_off[k],
              multi_crop_h_off[k],
              crop_width_,
              crop_height_);
        }
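        // with 14 crops, j / 7 selects the mirror state (crops 0-6
        // unmirrored, 7-13 mirrored) and k = j % 7 indexes the offsets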
        ClipTransformOpticalFlow(
            buffer_rgb[i],
            crop_height_,
            crop_width_,
            length_of_,
            channels_of_,
            sampling_rate_of_,
            height,
            width,
            rect,
            channels_rgb_,
            mirror_me,
            flow_alg_type_,
            flow_data_type_,
            frame_gap_of_,
            do_flow_aggregation_,
            mean_of_,
            inv_std_of_,
            clip_of_data + (i * clip_offset_of) + j * clip_crop_offset_of);
      }
    }
  }

  // release the decoded frame buffers
  for (size_t i = 0; i < buffer_rgb.size(); i++) {
    delete[] buffer_rgb[i];
  }
  buffer_rgb.clear();
}

template <class Context>
bool VideoInputOp<Context>::Prefetch() {
  // we will get the reader pointer from the input;
  // if we use local clips, the db will store the file list
  reader_ = &OperatorBase::Input<db::DBReader>(0);

  // call mutable_data() once to allocate the underlying memory
  prefetched_clip_rgb_.mutable_data<float>();
  prefetched_clip_of_.mutable_data<float>();
  prefetched_label_.mutable_data<int>();
  prefetched_video_id_.mutable_data<int>();

  // prefetching is handled by a thread pool of "num_decode_threads" threads
  std::mt19937 meta_randgen(time(nullptr));
  std::vector<std::mt19937> randgen_per_thread;
  for (int i = 0; i < num_decode_threads_; ++i) {
    randgen_per_thread.emplace_back(meta_randgen());
  }

  std::bernoulli_distribution mirror_this_clip(0.5);
  for (int item_id = 0; item_id < batch_size_; ++item_id) {
    std::mt19937* randgen = &randgen_per_thread[item_id % num_decode_threads_];

    int frame_size = crop_height_ * crop_width_;
    // get the clip data pointer for the item_id-th example
    float* clip_rgb_data = prefetched_clip_rgb_.mutable_data<float>() +
        frame_size * length_rgb_ * channels_rgb_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the optical flow data for the current clip
    float* clip_of_data = prefetched_clip_of_.mutable_data<float>() +
        frame_size * length_of_ * channels_of_ * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the label data pointer for the item_id-th example
    int* label_data = prefetched_label_.mutable_data<int>() +
        (do_multi_label_ ? num_of_class_ : 1) * item_id * clip_per_video_ *
            multi_crop_count_;

    // get the video id data pointer for the item_id-th example
    int* video_id_data = prefetched_video_id_.mutable_data<int>() +
        item_id * clip_per_video_ * multi_crop_count_;
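    // each pointer advances by item_id * clip_per_video_ * multi_crop_count_
    // units, where a unit is one crop volume, one label row, or one video id,
    // matching the blob layout set up in the constructor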

    std::string key, value;
    // read data
    reader_->Read(&key, &value);

    thread_pool_->runTask(std::bind(
        &VideoInputOp<Context>::DecodeAndTransform,
        this,
        std::string(value),
        clip_rgb_data,
        clip_of_data,
        label_data,
        video_id_data,
        randgen,
        &mirror_this_clip));
  } // for over the batch
  thread_pool_->waitWorkComplete();

  // if the context is not CPUContext, we will need to do a copy in the
  // prefetch function as well
  if (!std::is_same<Context, CPUContext>::value) {
    if (get_rgb_) {
      prefetched_clip_rgb_on_device_.CopyFrom(prefetched_clip_rgb_, &context_);
    }
    if (get_optical_flow_) {
      prefetched_clip_of_on_device_.CopyFrom(prefetched_clip_of_, &context_);
    }
    prefetched_label_on_device_.CopyFrom(prefetched_label_, &context_);
    if (get_video_id_) {
      prefetched_video_id_on_device_.CopyFrom(prefetched_video_id_, &context_);
    }
  }
  return true;
}

template <class Context>
bool VideoInputOp<Context>::CopyPrefetched() {
  int index = 0;
  if (get_rgb_) {
    auto* clip_rgb_output = OperatorBase::Output<Tensor<Context>>(index++);
    if (std::is_same<Context, CPUContext>::value) {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_, &context_);
    } else {
      clip_rgb_output->CopyFrom(prefetched_clip_rgb_on_device_, &context_);
    }
  }
  if (get_optical_flow_) {
    auto* clip_of_output = OperatorBase::Output<Tensor<Context>>(index++);
    if (std::is_same<Context, CPUContext>::value) {
      clip_of_output->CopyFrom(prefetched_clip_of_, &context_);
    } else {
      clip_of_output->CopyFrom(prefetched_clip_of_on_device_, &context_);
    }
  }
  auto* label_output = OperatorBase::Output<Tensor<Context>>(index++);
  if (std::is_same<Context, CPUContext>::value) {
    label_output->CopyFrom(prefetched_label_, &context_);
  } else {
    label_output->CopyFrom(prefetched_label_on_device_, &context_);
  }
  if (get_video_id_) {
    auto* video_id_output = OperatorBase::Output<Tensor<Context>>(index);
    if (std::is_same<Context, CPUContext>::value) {
      video_id_output->CopyFrom(prefetched_video_id_, &context_);
    } else {
      video_id_output->CopyFrom(prefetched_video_id_on_device_, &context_);
    }
  }
  return true;
}

} // namespace caffe2

#endif // CAFFE2_VIDEO_VIDEO_INPUT_OP_H_
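
A minimal usage sketch follows, for orientation only; it is not part of the header above. It assumes the operator is registered under the type name "VideoInput" (as done in the accompanying .cc file) and that the workspace already holds a DBReader blob; the blob names and argument values are illustrative, while the argument keys are the ones parsed by the constructor.

// usage_sketch.cc (hypothetical): build a VideoInput OperatorDef and run it
#include <memory>

#include <caffe2/core/operator.h>
#include <caffe2/core/workspace.h>
#include <caffe2/utils/proto_utils.h>

void RunVideoInputOnce(caffe2::Workspace* ws) {
  caffe2::OperatorDef def;
  def.set_type("VideoInput");
  def.add_input("reader");     // DBReader blob; required by the input check
  def.add_output("rgb_clips"); // emitted because get_rgb defaults to true
  def.add_output("labels");    // labels are always output
  // argument keys mirror the GetSingleArgument calls in the constructor
  caffe2::AddArgument<int>("batch_size", 4, &def);
  caffe2::AddArgument<int>("length_rgb", 8, &def);
  caffe2::AddArgument<int>("crop_size", 112, &def); // sets crop height/width
  std::unique_ptr<caffe2::OperatorBase> op = caffe2::CreateOperator(def, ws);
  CAFFE_ENFORCE(op->Run());
}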