#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

// Templated implementation of the locally connected (unshared convolution)
// operators declared in locally_connected_op.h.

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

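// Forward pass, NCHW order. Checks that the filter carries an independent
// kernel for every output location (its leading dimensions must match the
// output image dimensions), sizes the output, prepares the column and
// transposed scratch buffers, and dispatches to RunOnDeviceWithOrderNCHWImpl.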
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels does not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group:",
      filter.dim32(image_ndim + 1),
      "*",
      group_);
  CAFFE_ENFORCE(
      shape.M % group_ == 0,
      "The number of output channels is not divisible by group.");

  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.input_image_dims = GetDims(X);
  const std::vector<int> X_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(X_dims, &X_dims_device_);
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  SetColumnBufferShape(
      shape.N,
      /* ... */
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);
  return true;
}

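// Forward pass, NHWC order. Only 2-D kernels are supported for this storage
// order; otherwise the shape bookkeeping mirrors the NCHW entry point before
// dispatching to RunOnDeviceWithOrderNHWCImpl.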
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  SetColumnBufferShape(
      shape.N,
      /* ... */
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);
  return true;
}

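// NCHW forward body. Each image/group is unfolded with im2col into the column
// buffer, the columns are transposed so that the output-location axis leads,
// and the (largely elided) batched matrix multiplication then applies a
// separate filter slice per output location, which is what distinguishes a
// locally connected layer from a convolution. The bias is broadcast across
// the batch via the bias multiplier, and the result is transposed back into Y.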
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            /* ... */
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            /* ... */
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            X_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            /* ... */
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            /* ... */
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  /* ... */
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      /* ... */
  /* ... */
      shape.output_image_size * group_,
      /* ... */
      column_transposed_buffer->template data<T>(),
      /* ... */
      Y_transposed_buffer_data,
      /* ... */
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        /* ... */
        shape.output_image_size * shape.M,
        /* ... */
        bias_multiplier_.template data<T>(),
        /* ... */
        Y_transposed_buffer_data,
        &context_);
  }
  /* ... */
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      /* ... */
}

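// NHWC forward body: per-image im2col (no group decomposition), a transpose
// of the column buffer, per-location matrix products into the transposed
// output buffer, a transpose back into Y, and finally the bias broadcast
// applied directly to the NHWC output.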
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        /* ... */
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        /* ... */
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  /* ... */
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      /* ... */
  /* ... */
      shape.output_image_size,
      /* ... */
      column_transposed_buffer->template data<T>(),
      /* ... */
      Y_transposed_buffer_data,
      /* ... */
  /* ... */
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      /* ... */
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        /* ... */
        shape.output_image_size * shape.M,
        /* ... */
        bias_multiplier_.template data<T>(),
        /* ... */
        &context_);
  }
}

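// Computes the host-side column buffer dimensions and the axis permutation
// used to transpose it, then mirrors them into device tensors so the
// transpose kernels can read them on device.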
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetColumnBufferShape(
    /* ... */
    const int kernel_size,
    const int output_image_size,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  lc_op_util::SetColumnBufferShapeImpl(
      /* ... */
      column_transposed_dims,
      /* ... */
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
}

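// Same bookkeeping for the output buffer: host-side dims and transpose axes
// are computed once and copied into device tensors.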
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetYTranposedBufferShape(
    const int N,
    const int M,
    const int output_image_size,
    std::vector<int>* Y_transposed_dims) {
  std::vector<int> Y_dims;
  std::vector<int> Y_transposed_axes;
  lc_op_util::SetYBufferShapeImpl(
      /* ... */
  SetDeviceTensor(Y_dims, &Y_dims_device_);
  SetDeviceTensor(*Y_transposed_dims, &Y_transposed_dims_device_);
  SetDeviceTensor(Y_transposed_axes, &Y_transposed_axes_device_);
}

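// Gradient pass, NCHW order. The filter gradient is always produced; the
// input gradient and the bias gradient are optional, selected by OutputSize()
// and the no_bias_ argument. Shape validation mirrors the forward pass.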
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());

  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) * group_ == shape.C);
  CAFFE_ENFORCE(shape.M % group_ == 0);

  shape.input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  const std::vector<int> X_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(X_dims, &X_dims_device_);
  shape.kernel_size = shape.C / group_ * kernel_dims_size;

  SetColumnBufferShape(
      shape.N,
      /* ... */
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);
  return true;
}

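// Gradient pass, NHWC order; 2-D kernels only, otherwise the same bookkeeping
// as the NCHW gradient entry point.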
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  SetColumnBufferShape(
      shape.N,
      /* ... */
      shape.kernel_size,
      shape.output_image_size,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(
      shape.N, shape.M, shape.output_image_size, &shape.Y_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);
  return true;
}

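// NCHW gradient body. The column buffer is recomputed with im2col and
// transposed as in the forward pass, dY is transposed into the same
// per-location layout, and per-location matrix products (largely elided)
// accumulate the filter gradient. The bias gradient reduces the transposed dY
// over the batch with a Gemv against the bias multiplier. For the input
// gradient, dY is multiplied with the filters, the result is transposed back
// into column layout, and col2im scatters it into dX.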
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            /* ... */
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            /* ... */
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            X_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            /* ... */
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            /* ... */
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  /* ... */
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      /* ... */
  /* ... */
      shape.Y_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      /* ... */
  /* ... */
      shape.output_image_size * group_,
      /* ... */
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      /* ... */
  if (dbias_data != nullptr) {
    math::Gemv<T, Context>(
        /* ... */
        shape.output_image_size * shape.M,
        /* ... */
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        /* ... */
        dbias_data,
        &context_);
  }
  if (dX_data != nullptr) {
    /* ... */
        shape.output_image_size * group_,
        /* ... */
        dY_transposed_buffer_data,
        /* ... */
        column_transposed_buffer->template mutable_data<T>(),
        /* ... */
    /* ... */
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        /* ... */
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2im<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              /* ... */
              shape.input_image_dims[0],
              shape.input_image_dims[1],
              /* ... */
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2imNd<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              X_dims_device_.template data<int>(),
              column_dims_device_.template data<int>() + 1,
              shape.C * shape.input_image_size,
              /* ... */
              kernel_device_.template data<int>(),
              stride_device_.template data<int>(),
              dilation_device_.template data<int>(),
              pads_device_.template data<int>(),
              /* ... */
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}

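// NHWC gradient body: the same sequence as the NCHW version without the group
// decomposition, ending with a per-image col2im into dX.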
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        /* ... */
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        /* ... */
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  /* ... */
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      /* ... */
  /* ... */
      shape.Y_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      /* ... */
  /* ... */
      shape.output_image_size,
      /* ... */
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      /* ... */
  if (dbias_data != nullptr) {
    math::Gemv<T, Context>(
        /* ... */
        shape.output_image_size * shape.M,
        /* ... */
        bias_multiplier_.template data<T>(),
        /* ... */
        dbias_data,
        &context_);
  }
  if (dX_data != nullptr) {
    /* ... */
        shape.output_image_size,
        /* ... */
        dY_transposed_buffer_data,
        /* ... */
        column_transposed_buffer->template mutable_data<T>(),
        /* ... */
    /* ... */
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        /* ... */
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2im<T, Context, StorageOrder::NHWC>(
          const_column_buffer_data,
          /* ... */
          shape.input_image_dims[0],
          shape.input_image_dims[1],
          /* ... */
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}

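// Like the forward-op helper, but also records the inverse axis permutation
// (column_transposed_axes) needed to transpose the column gradient back
// before col2im.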
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetColumnBufferShape(
    /* ... */
    const int kernel_size,
    const int output_image_size,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  std::vector<int> column_transposed_axes;
  lc_op_util::SetColumnBufferShapeImpl(
      /* ... */
      column_transposed_dims,
      &column_axes,
      &column_transposed_axes);
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
  SetDeviceTensor(column_transposed_axes, &column_transposed_axes_device_);
}

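// Host-side dims and transpose axes for reshuffling dY into the per-location
// layout, mirrored into device tensors.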
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetDYTranposedBufferShape(
    const int N,
    const int M,
    const int output_image_size,
    std::vector<int>* dY_transposed_dims) {
  std::vector<int> dY_dims;
  std::vector<int> dY_axes;
  lc_op_util::SetYBufferShapeImpl(
      /* ... */
  SetDeviceTensor(dY_dims, &dY_dims_device_);
  SetDeviceTensor(*dY_transposed_dims, &dY_transposed_dims_device_);
  SetDeviceTensor(dY_axes, &dY_axes_device_);
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_