Caffe2 - C++ API
A deep learning, cross-platform ML framework
locally_connected_op_impl.h
// locally_connected_op_impl.h is the templated implementation of the
// locally_connected_op.h file.

#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
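
// Implementation note: both the forward and gradient paths unpack each
// image into a column buffer with Im2col, transpose that buffer so the
// output location becomes the leading (batch) dimension, and then run a
// batched GEMM that applies a distinct filter at every output location.
// This per-location filtering is what distinguishes a locally connected
// layer from an ordinary convolution, where one filter is shared.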
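
// Forward pass, NCHW layout. Validates the input/filter/bias shapes,
// computes the output size and the column/Y buffer shapes, then delegates
// the math to RunOnDeviceWithOrderNCHWImpl.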
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels does not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group: ",
      filter.dim32(image_ndim + 1),
      " * ",
      group_);
  CAFFE_ENFORCE(
      shape.M % group_ == 0,
      "The number of output channels is not divisible by group.");

  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.input_image_dims = GetDims(X);
  const std::vector<int> X_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(X_dims, &X_dims_device_);
  shape.kernel_size = shape.C / group_ * kernel_dims_size;

  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}
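
// Forward pass, NHWC layout. Only the 2-D case is supported. The channel
// count is read from axis 3, and the filter is expected to be laid out as
// (output H, output W, M, kernel_h, kernel_w, C).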
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}
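
// NCHW forward kernel: Im2col (Im2colNd for more than two spatial dims)
// unpacks each image and group into the column buffer, a transpose makes
// the output location the batch dimension, a batched GEMM applies each
// location's filter, the bias is added as a rank-1 GEMM update, and a last
// transpose writes Y back in NCHW order.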
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            shape.C / group_,
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            X_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            column_stride,
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            kernel_.size(),
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmBatched(
      CblasNoTrans,
      CblasNoTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.N,
      shape.kernel_size,
      1.0f,
      filter_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      Y_transposed_buffer_data,
      &context_);
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1,
        1.0,
        bias_data,
        bias_multiplier_.template data<T>(),
        1.0,
        Y_transposed_buffer_data,
        &context_);
  }
  math::Transpose(
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
}
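
// NHWC forward kernel: one Im2col per image (this path has no group
// support), a transpose to put the output location first, a batched GEMM
// against the transposed filter, a transpose back to NHWC, and finally the
// bias added directly onto Y as a rank-1 GEMM update.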
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        shape.C,
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size,
      shape.N,
      shape.M,
      shape.kernel_size,
      1.0f,
      column_transposed_buffer->template data<T>(),
      filter_data,
      0.0f,
      Y_transposed_buffer_data,
      &context_);
  math::Transpose(
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1,
        1.0f,
        bias_multiplier_.template data<T>(),
        bias_data,
        1.0f,
        Y_data,
        &context_);
  }
}
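
// Computes the column buffer shape and its transposed counterpart for the
// current storage order, and mirrors the dims and transpose axes onto
// device tensors so the transpose kernels can read them.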
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetColumnBufferShape(
    const int N,
    const int C,
    const int kernel_size,
    const int output_image_size,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  lc_op_util::SetColumnBufferShapeImpl(
      N,
      kernel_size,
      output_image_size,
      order_,
      column_dims,
      column_transposed_dims,
      &column_axes,
      nullptr);
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
}
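
// Computes the transposed output buffer shape and mirrors the dims and
// transpose axes onto device tensors for the final transpose back to Y.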
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetYTranposedBufferShape(
    const int N,
    const int M,
    const int output_image_size,
    std::vector<int>* Y_transposed_dims) {
  std::vector<int> Y_dims;
  std::vector<int> Y_transposed_axes;
  lc_op_util::SetYBufferShapeImpl(
      N,
      M,
      output_image_size,
      order_,
      &Y_dims,
      Y_transposed_dims,
      nullptr,
      &Y_transposed_axes);
  SetDeviceTensor(Y_dims, &Y_dims_device_);
  SetDeviceTensor(*Y_transposed_dims, &Y_transposed_dims_device_);
  SetDeviceTensor(Y_transposed_axes, &Y_transposed_axes_device_);
}
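
// Gradient pass, NCHW layout. Validates shapes, sizes the filter, bias and
// input gradient outputs, and delegates to RunOnDeviceWithOrderNCHWImpl.
// When no_bias_ is set, the input gradient takes over the bias gradient's
// output slot.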
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());

  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) * group_ == shape.C);
  CAFFE_ENFORCE(shape.M % group_ == 0);

  shape.input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  const std::vector<int> X_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(X_dims, &X_dims_device_);
  shape.kernel_size = shape.C / group_ * kernel_dims_size;

  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}
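
// Gradient pass, NHWC layout. Repeats the 2-D-only NHWC shape checks from
// the forward pass, then sizes the gradient outputs and dispatches to
// RunOnDeviceWithOrderNHWCImpl.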
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}
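
// NCHW gradient kernel: recomputes the column buffer from X, transposes it
// and dY so the output location is the batch dimension, then forms dfilter
// with a batched GEMM, reduces dY into dbias with a GEMV, and, when dX is
// requested, runs a batched GEMM followed by Col2im to scatter the column
// gradient back onto the input.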
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            shape.C / group_,
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            X_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            column_stride,
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            kernel_.size(),
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);

  math::Transpose(
      shape.Y_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      dfilter_data,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1.0f,
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmBatched(
        CblasTrans,
        CblasNoTrans,
        shape.output_image_size * group_,
        shape.kernel_size,
        shape.N,
        shape.M / group_,
        1.0f,
        filter_data,
        dY_transposed_buffer_data,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        &context_);
    math::Transpose(
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2im<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              shape.C / group_,
              shape.input_image_dims[0],
              shape.input_image_dims[1],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2imNd<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              X_dims_device_.template data<int>(),
              column_dims_device_.template data<int>() + 1,
              shape.C * shape.input_image_size,
              column_stride,
              kernel_device_.template data<int>(),
              stride_device_.template data<int>(),
              dilation_device_.template data<int>(),
              pads_device_.template data<int>(),
              kernel_.size(),
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}
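
// NHWC gradient kernel: same structure as the NCHW version, but with the
// Im2col/Col2im calls in NHWC order and dbias reduced straight from dY by
// a transposed GEMV.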
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        shape.C,
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);

  math::Transpose(
      shape.Y_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmBatched(
      CblasTrans,
      CblasNoTrans,
      shape.output_image_size,
      shape.M,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      dfilter_data,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1.0f,
        dY_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmBatched(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size,
        shape.N,
        shape.kernel_size,
        shape.M,
        1.0f,
        dY_transposed_buffer_data,
        filter_data,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        &context_);
    math::Transpose(
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2im<T, Context, StorageOrder::NHWC>(
          const_column_buffer_data,
          shape.C,
          shape.input_image_dims[0],
          shape.input_image_dims[1],
          kernel_h(),
          kernel_w(),
          dilation_h(),
          dilation_w(),
          pad_t(),
          pad_l(),
          pad_b(),
          pad_r(),
          stride_h(),
          stride_w(),
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}
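
// Like the forward op's SetColumnBufferShape, but additionally records the
// transposed axes needed to undo the column transpose when accumulating dX.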
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetColumnBufferShape(
    const int N,
    const int C,
    const int kernel_size,
    const int output_image_size,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  std::vector<int> column_transposed_axes;
  lc_op_util::SetColumnBufferShapeImpl(
      N,
      kernel_size,
      output_image_size,
      order_,
      column_dims,
      column_transposed_dims,
      &column_axes,
      &column_transposed_axes);
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
  SetDeviceTensor(column_transposed_axes, &column_transposed_axes_device_);
}
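
// Computes dY's transposed buffer shape and the axes used to rearrange dY
// ahead of the gradient GEMMs, mirroring them onto device tensors.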
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetDYTranposedBufferShape(
    const int N,
    const int M,
    const int output_image_size,
    std::vector<int>* dY_transposed_dims) {
  std::vector<int> dY_dims;
  std::vector<int> dY_axes;
  lc_op_util::SetYBufferShapeImpl(
      N,
      M,
      output_image_size,
      order_,
      &dY_dims,
      dY_transposed_dims,
      &dY_axes,
      nullptr);
  SetDeviceTensor(dY_dims, &dY_dims_device_);
  SetDeviceTensor(*dY_transposed_dims, &dY_transposed_dims_device_);
  SetDeviceTensor(dY_axes, &dY_axes_device_);
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_