1 #ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_ 2 #define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/operators/conv_op_shared.h" 9 #include "caffe2/operators/conv_pool_op_base.h" 10 #include "caffe2/operators/locally_connected_op_util.h" 14 template <
typename T,
class Context>
17 USE_CONV_POOL_BASE_FUNCTIONS(Context);
24 group_ == 1 || order_ == StorageOrder::NCHW,
25 "Group locally connected only supports NCHW order right now.");
30 bool RunOnDeviceWithOrderNCHW()
override;
31 bool RunOnDeviceWithOrderNHWC()
override;
34 void RunOnDeviceWithOrderNCHWImpl(
44 void RunOnDeviceWithOrderNHWCImpl(
54 void SetColumnBufferShape(
57 const int kernel_size,
58 const int output_image_size,
59 std::vector<int>* column_dims,
60 std::vector<int>* column_transposed_dims);
62 void SetYTranposedBufferShape(
65 const int output_image_size,
66 std::vector<int>* Y_transposed_dims);
86 INPUT_TAGS(INPUT, FILTER, BIAS);
89 template <
typename T,
class Context>
92 USE_CONV_POOL_BASE_FUNCTIONS(Context);
96 no_bias_(OperatorBase::GetSingleArgument<int>(
"no_bias", 0)) {
98 !(no_bias_ && OutputSize() == 3),
99 "If bias is not present, you should not have 3 grad output.");
101 group_ == 1 || order_ == StorageOrder::NCHW,
102 "Group locally connected only supports NCHW order right now.");
107 bool RunOnDeviceWithOrderNCHW()
override;
108 bool RunOnDeviceWithOrderNHWC()
override;
111 void RunOnDeviceWithOrderNCHWImpl(
114 const T* filter_data,
123 void RunOnDeviceWithOrderNHWCImpl(
126 const T* filter_data,
135 void SetColumnBufferShape(
138 const int kernel_size,
139 const int output_image_size,
140 std::vector<int>* column_dims,
141 std::vector<int>* column_transposed_dims);
143 void SetDYTranposedBufferShape(
146 const int output_image_size,
147 std::vector<int>* dY_transposed_dims);
170 INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
171 OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
176 #endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...