Caffe2 - C++ API
A deep learning, cross-platform ML framework
optical_flow.cc
1 #include <caffe2/video/optical_flow.h>
2 
3 namespace caffe2 {
4 
5 void OpticalFlowExtractor(
6  const cv::Mat& prev_gray,
7  const cv::Mat& curr_gray,
8  const int flow_alg_type,
9  cv::Mat& flow) {
10  cv::Ptr<cv::DualTVL1OpticalFlow> tvl1 = cv::DualTVL1OpticalFlow::create();
11  switch (flow_alg_type) {
12  case FLowAlgType::FarnebackOpticalFlow:
13  cv::calcOpticalFlowFarneback(
14  prev_gray,
15  curr_gray,
16  flow,
17  std::sqrt(2) / 2.0,
18  5,
19  10,
20  2,
21  7,
22  1.5,
23  cv::OPTFLOW_FARNEBACK_GAUSSIAN);
24  break;
25  case FLowAlgType::DensePyrLKOpticalFlow:
26  LOG(ERROR) << "DensePyrLKOpticalFlow only has sparse version on CPU";
27  break;
28  case FLowAlgType::BroxOpticalFlow:
29  LOG(ERROR) << "BroxOpticalFlow on CPU is not available";
30  break;
31  case FLowAlgType::OpticalFlowDual_TVL1:
32  tvl1->calc(prev_gray, curr_gray, flow);
33  break;
34  default:
35  LOG(ERROR) << "Unsupported optical flow type " << flow_alg_type;
36  break;
37  }
38 }
39 
40 void MergeOpticalFlow(cv::Mat& prev_flow, const cv::Mat& curr_flow) {
41  const int rows = prev_flow.rows;
42  const int cols = prev_flow.cols;
43 
44  // merge two optical flows into one
45  for (int y = 0; y < rows; y++) {
46  for (int x = 0; x < cols; x++) {
47  cv::Point2f u = prev_flow.at<cv::Point2f>(y, x);
48  // get the new location
49  int x_new = std::min(cols - 1, std::max(0, cvRound(u.x + x)));
50  int y_new = std::min(rows - 1, std::max(0, cvRound(u.y + y)));
51  cv::Point2f u_new = curr_flow.at<cv::Point2f>(y_new, x_new);
52 
53  // update the flow
54  prev_flow.at<cv::Point2f>(y, x) += u_new;
55  }
56  }
57 }
58 
59 void MultiFrameOpticalFlowExtractor(
60  const std::vector<cv::Mat>& grays,
61  const int optical_flow_alg_type,
62  cv::Mat& flow) {
63  int num_frames = grays.size();
64  CAFFE_ENFORCE_GE(num_frames, 2, "need at least 2 frames!");
65 
66  // compute optical flow for every two frames
67  std::vector<cv::Mat> flows;
68  for (int i = 0; i < num_frames - 1; i++) {
69  cv::Mat tmp;
70  OpticalFlowExtractor(grays[i], grays[i + 1], optical_flow_alg_type, tmp);
71  flows.push_back(tmp);
72  }
73 
74  flows[0].copyTo(flow);
75  // aggregate optical flow across multiple frame
76  for (int i = 1; i < num_frames - 1; i++) {
77  MergeOpticalFlow(flow, flows[i]);
78  }
79 }
80 
81 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...