Caffe2 - C++ API
A deep learning, cross-platform ML framework
im2col_op.h
1 #ifndef CAFFE2_OPERATORS_IM2COL_OP_H_
2 #define CAFFE2_OPERATORS_IM2COL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class Im2ColOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  Im2ColOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  pad_(OperatorBase::GetSingleArgument<int>("pad", 0)),
18  kernel_h_(OperatorBase::GetSingleArgument<int>(
19  "kernel_h",
20  OperatorBase::GetSingleArgument<int>("kernel", 0))),
21  kernel_w_(OperatorBase::GetSingleArgument<int>(
22  "kernel_w",
23  OperatorBase::GetSingleArgument<int>("kernel", 0))),
24  dilation_h_(OperatorBase::GetSingleArgument<int>(
25  "dilation_h",
26  OperatorBase::GetSingleArgument<int>("dilation", 1))),
27  dilation_w_(OperatorBase::GetSingleArgument<int>(
28  "dilation_w",
29  OperatorBase::GetSingleArgument<int>("dilation", 1))),
30  stride_h_(OperatorBase::GetSingleArgument<int>(
31  "stride_h",
32  OperatorBase::GetSingleArgument<int>("stride", 1))),
33  stride_w_(OperatorBase::GetSingleArgument<int>(
34  "stride_w",
35  OperatorBase::GetSingleArgument<int>("stride", 1))),
36  order_(StringToStorageOrder(
37  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
38  CAFFE_ENFORCE(kernel_h_ > 0);
39  CAFFE_ENFORCE(kernel_w_ > 0);
40  CAFFE_ENFORCE(dilation_h_ > 0);
41  CAFFE_ENFORCE(dilation_w_ > 0);
42  CAFFE_ENFORCE(stride_h_ > 0);
43  CAFFE_ENFORCE(stride_w_ > 0);
44  CAFFE_ENFORCE(pad_ >= 0);
45  }
46 
47  bool RunOnDevice() override {
48  auto& X = Input(0);
49  auto* Y = Output(0);
50  CAFFE_ENFORCE(4 == X.ndim());
51 
52  int N = 0, C = 0, H = 0, W = 0;
53  switch (order_) {
54  case StorageOrder::NCHW:
55  N = X.dim32(0);
56  C = X.dim32(1);
57  H = X.dim32(2);
58  W = X.dim32(3);
59  break;
60  case StorageOrder::NHWC:
61  N = X.dim32(0);
62  H = X.dim32(1);
63  W = X.dim32(2);
64  C = X.dim32(3);
65  break;
66  default:
67  CAFFE_THROW("Unknown storage order: ", order_);
68  }
69 
70  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
71  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
72  CAFFE_ENFORCE(H >= dkernel_h);
73  CAFFE_ENFORCE(W >= dkernel_w);
74  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
75  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
76 
77  switch (order_) {
78  case StorageOrder::NCHW: {
79  Y->Resize(
80  std::vector<TIndex>{N, C * kernel_h_ * kernel_w_, out_h, out_w});
81 
82  const size_t dx = X.size() / N;
83  const size_t dy = Y->size() / N;
84  for (int n = 0; n < N; ++n) {
85  const auto* xdata = X.template data<T>() + (n * dx);
86  auto* ydata = Y->template mutable_data<T>() + (n * dy);
87  math::Im2col<T, Context, StorageOrder::NCHW>(
88  xdata,
89  C,
90  H,
91  W,
92  kernel_h_,
93  kernel_w_,
94  dilation_h_,
95  dilation_w_,
96  pad_,
97  pad_,
98  pad_,
99  pad_,
100  stride_h_,
101  stride_w_,
102  ydata,
103  &context_);
104  }
105  }; break;
106  case StorageOrder::NHWC: {
107  Y->Resize(
108  std::vector<TIndex>{N, out_h, out_w, kernel_h_ * kernel_w_ * C});
109 
110  const size_t dx = X.size() / N;
111  const size_t dy = Y->size() / N;
112  for (int n = 0; n < N; ++n) {
113  const auto* xdata = X.template data<T>() + (n * dx);
114  auto* ydata = Y->template mutable_data<T>() + (n * dy);
115  math::Im2col<T, Context, StorageOrder::NHWC>(
116  xdata,
117  C,
118  H,
119  W,
120  kernel_h_,
121  kernel_w_,
122  dilation_h_,
123  dilation_w_,
124  pad_,
125  pad_,
126  pad_,
127  pad_,
128  stride_h_,
129  stride_w_,
130  ydata,
131  &context_);
132  }
133  }; break;
134  default:
135  CAFFE_THROW("Unknown storage order: ", order_);
136  }
137 
138  return true;
139  }
140 
141  private:
142  int pad_;
143  int kernel_h_;
144  int kernel_w_;
145  int dilation_h_;
146  int dilation_w_;
147  int stride_h_;
148  int stride_w_;
149  StorageOrder order_;
150 };
151 
152 template <typename T, class Context>
153 class Col2ImOp final : public Operator<Context> {
154  public:
155  USE_OPERATOR_CONTEXT_FUNCTIONS;
156  Col2ImOp(const OperatorDef& operator_def, Workspace* ws)
157  : Operator<Context>(operator_def, ws),
158  pad_(OperatorBase::GetSingleArgument<int>("pad", 0)),
159  kernel_h_(OperatorBase::GetSingleArgument<int>(
160  "kernel_h",
161  OperatorBase::GetSingleArgument<int>("kernel", 0))),
162  kernel_w_(OperatorBase::GetSingleArgument<int>(
163  "kernel_w",
164  OperatorBase::GetSingleArgument<int>("kernel", 0))),
165  dilation_h_(OperatorBase::GetSingleArgument<int>(
166  "dilation_h",
167  OperatorBase::GetSingleArgument<int>("dilation", 1))),
168  dilation_w_(OperatorBase::GetSingleArgument<int>(
169  "dilation_w",
170  OperatorBase::GetSingleArgument<int>("dilation", 1))),
171  stride_h_(OperatorBase::GetSingleArgument<int>(
172  "stride_h",
173  OperatorBase::GetSingleArgument<int>("stride", 1))),
174  stride_w_(OperatorBase::GetSingleArgument<int>(
175  "stride_w",
176  OperatorBase::GetSingleArgument<int>("stride", 1))),
177  order_(StringToStorageOrder(
178  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
179  CAFFE_ENFORCE(kernel_h_ > 0);
180  CAFFE_ENFORCE(kernel_w_ > 0);
181  CAFFE_ENFORCE(dilation_h_ > 0);
182  CAFFE_ENFORCE(dilation_w_ > 0);
183  CAFFE_ENFORCE(stride_h_ > 0);
184  CAFFE_ENFORCE(stride_w_ > 0);
185  CAFFE_ENFORCE(pad_ >= 0);
186  }
187 
188  bool RunOnDevice() override {
189  auto& X = Input(0);
190  auto& Z = Input(1);
191  auto* Y = Output(0);
192  Y->ResizeLike(Z);
193  CAFFE_ENFORCE(4 == Y->ndim());
194 
195  int N = 0, C = 0, H = 0, W = 0;
196  switch (order_) {
197  case StorageOrder::NCHW:
198  N = Y->dim32(0);
199  C = Y->dim32(1);
200  H = Y->dim32(2);
201  W = Y->dim32(3);
202  break;
203  case StorageOrder::NHWC:
204  N = Y->dim32(0);
205  H = Y->dim32(1);
206  W = Y->dim32(2);
207  C = Y->dim32(3);
208  break;
209  default:
210  CAFFE_THROW("Unknown storage order: ", order_);
211  }
212 
213  const int dkernel_h = dilation_h_ * (kernel_h_ - 1) + 1;
214  const int dkernel_w = dilation_w_ * (kernel_w_ - 1) + 1;
215  CAFFE_ENFORCE(H >= dkernel_h);
216  CAFFE_ENFORCE(W >= dkernel_w);
217  const int out_h = (H + 2 * pad_ - dkernel_h) / stride_h_ + 1;
218  const int out_w = (W + 2 * pad_ - dkernel_w) / stride_w_ + 1;
219  CAFFE_ENFORCE(X.size() == N * kernel_h_ * kernel_w_ * C * out_h * out_w);
220 
221  const size_t dx = X.size() / N;
222  const size_t dy = Y->size() / N;
223 
224  // could template-specialize this, but it's test code...
225  switch (order_) {
226  case StorageOrder::NCHW: {
227  for (int n = 0; n < N; ++n) {
228  const auto* xdata = X.template data<T>() + (n * dx);
229  auto* ydata = Y->template mutable_data<T>() + (n * dy);
230  math::Col2im<T, Context, StorageOrder::NCHW>(
231  xdata,
232  C,
233  H,
234  W,
235  kernel_h_,
236  kernel_w_,
237  dilation_h_,
238  dilation_w_,
239  pad_,
240  pad_,
241  pad_,
242  pad_,
243  stride_h_,
244  stride_w_,
245  ydata,
246  &context_);
247  }
248  }; break;
249  case StorageOrder::NHWC: {
250  for (int n = 0; n < N; ++n) {
251  const auto* xdata = X.template data<T>() + (n * dx);
252  auto* ydata = Y->template mutable_data<T>() + (n * dy);
253  math::Col2im<T, Context, StorageOrder::NHWC>(
254  xdata,
255  C,
256  H,
257  W,
258  kernel_h_,
259  kernel_w_,
260  dilation_h_,
261  dilation_w_,
262  pad_,
263  pad_,
264  pad_,
265  pad_,
266  stride_h_,
267  stride_w_,
268  ydata,
269  &context_);
270  }
271  }; break;
272  default:
273  CAFFE_THROW("Unknown storage order: ", order_);
274  }
275 
276  return true;
277  }
278 
279  private:
280  int pad_;
281  int kernel_h_;
282  int kernel_w_;
283  int dilation_h_;
284  int dilation_w_;
285  int stride_h_;
286  int stride_w_;
287  StorageOrder order_;
288 };
289 
290 } // namespace caffe2
291 
292 #endif // CAFFE2_OPERATORS_IM2COL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...