// ---------------------------------------------------------------------------
// Fragment of caffe2/operators/fully_connected_op.h: forward fully-connected
// (FC) operator.  Computes Y = X * W' + b (W used transposed when the
// TransposeWeight template flag is true, as the Gemm call below shows).
//
// NOTE(review): this extraction is incomplete.  The original line numbers
// fused into each line jump (7 -> 14, 49 -> 67, 103 -> 107, ...), so the
// template header / class declaration, several Gemm argument lines, and the
// closing braces are missing from this view.  Comments below describe only
// what the visible lines establish; anything else is hedged.
// ---------------------------------------------------------------------------
// Include guard and core Caffe2 headers, followed by the tail of the
// operator's template parameter list.  TransposeWeight defaults to true,
// meaning W is laid out (N x K) and multiplied transposed.
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ 2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/conversions.h" 7 #include "caffe2/utils/math.h" 14 class Engine = DefaultEngine,
15 bool TransposeWeight =
true>
18 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer list (ctor signature missing from this view):
// reads operator arguments "axis" (default 1, where X is flattened to 2-D),
// "axis_w" (default 1, where W is flattened to 2-D) and "float16_compute"
// (default false, opts into fp16 math on supporting backends).
21 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 1)),
22 axis_w_(OperatorBase::GetSingleArgument<int32_t>(
"axis_w", 1)),
24 OperatorBase::GetSingleArgument<bool>(
"float16_compute",
false)) {}
// Typed run body.  Inputs: X (data, input 0), W (weights, input 1),
// b (bias, input 2).  Template parameters T_X/T_W/T_B/T_Y/MATH are declared
// on lines missing from this extraction.
33 bool DoRunWithType() {
34 const auto& X = Input(0);
35 const auto& W = Input(1);
36 const auto& b = Input(2);
// Bias must be a 1-D tensor.
38 CAFFE_ENFORCE(b.ndim() == 1, b.ndim());
// Flatten X to an (M x K) matrix around axis_; derive N from W.  When
// TransposeWeight, W's leading dims (up to axis_w_) give N; otherwise its
// trailing dims do.
40 const auto canonical_axis = X.canonical_axis_index(axis_);
41 const auto M = X.size_to_dim(canonical_axis);
42 const auto K = X.size_from_dim(canonical_axis);
43 const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
44 const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
45 : W.size_from_dim(canonical_axis_w);
// Lazy builder for a detailed dimension-mismatch message (most of the
// lambda body, original lines 48-66, is missing from this extraction).
47 auto dimErrorString = [&]() {
49 "Dimension mismatch: ",
// Consistency checks: X is exactly M*K, W is exactly K*N, and the bias
// length matches the output width N.
67 CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
68 CAFFE_ENFORCE(K == W.size() / N, dimErrorString());
69 CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
70 CAFFE_ENFORCE(N == b.size(), dimErrorString());
// Output shape: X's dims up to the canonical axis, with the last dim
// replaced by N.  Cached in Y_shape_cache_ to avoid re-allocating the
// dims vector every run.
72 Y_shape_cache_ = X.dims();
74 DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
75 Y_shape_cache_.resize(canonical_axis + 1);
76 Y_shape_cache_[canonical_axis] = N;
77 Y->Resize(Y_shape_cache_);
78 CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
// Force output allocation (early-return path for M == 0, presumably --
// the surrounding lines 79-81 are missing; confirm against upstream).
82 Y->template mutable_data<T_Y>();
// Select fp16 math when the MATH accumulator type is a half-float type.
87 TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
88 if (fp16_type<MATH>()) {
89 math_type = TensorProto_DataType_FLOAT16;
// Main GEMM: Y(M x N) = X(M x K) * op(W), with W transposed iff
// TransposeWeight.  Several argument lines (dims, alpha/beta, math_type)
// fall in the missing ranges 94-99 and 102.
93 math::Gemm<T_X, Context, Engine>(
95 TransposeWeight ? CblasTrans : CblasNoTrans,
100 X.template data<T_X>(),
101 W.template data<T_W>(),
103 Y->template mutable_data<T_Y>(),
// Lazily (re)build a length-M vector of ones used to broadcast the bias
// across the M rows of Y.
107 if (bias_multiplier_.size() != M) {
109 bias_multiplier_.Resize(M);
110 math::Set<T_B, Context>(
112 convert::To<float, T_B>(1),
113 bias_multiplier_.template mutable_data<T_B>(),
// Bias GEMM: rank-1 update Y += ones(M x 1) * b(1 x N).  The alpha/beta
// arguments (presumably beta = 1 so Y is accumulated, not overwritten)
// are on missing lines 117-122 -- confirm against upstream.
116 math::Gemm<T_B, Context, Engine>(
123 bias_multiplier_.template data<T_B>(),
124 b.template data<T_B>(),
126 Y->template mutable_data<T_Y>(),
// Untyped entry point: dispatches to DoRunWithType (the concrete type
// list, original lines 134-145, is missing from this extraction).
132 bool RunOnDevice()
override {
133 return DoRunWithType<
// Members: cached output shape and the fp16-compute flag read in the ctor
// (axis_/axis_w_ declarations fall in the missing line ranges).
146 vector<TIndex> Y_shape_cache_;
149 bool float16_compute_;
// ---------------------------------------------------------------------------
// Fragment of the FC gradient operator.  Given X, W and dY it produces
// dW (output 0), db (output 1) and optionally dX (output 2).
//
// NOTE(review): this extraction is incomplete -- the class/template header
// (original lines 151-157) and many argument lines inside the Gemm/Gemv
// calls are missing.  Comments describe only what the visible lines show.
// ---------------------------------------------------------------------------
// Tail of the template parameter list; TransposeWeight defaults to true and
// must match the forward op's setting for the gradients to line up.
155 bool TransposeWeight =
true>
158 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer list: same arguments as the forward op --
// "axis" (default 1), "axis_w" (default 1), "float16_compute" (default
// false).
161 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 1)),
162 axis_w_(OperatorBase::GetSingleArgument<int32_t>(
"axis_w", 1)),
164 OperatorBase::GetSingleArgument<bool>(
"float16_compute",
false)) {}
// Typed run body.  Inputs: X (input 0), W (input 1), dY (input 2).
176 bool DoRunWithType() {
177 const auto& X = Input(0);
178 const auto& W = Input(1);
179 const auto& dY = Input(2);
// Recover the forward op's M/K/N from the same axis arguments, and check
// the inputs factor exactly.
181 const auto canonical_axis = X.canonical_axis_index(axis_);
182 const int M = X.size_to_dim(canonical_axis);
183 const int K = X.size_from_dim(canonical_axis);
184 const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
185 const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
186 : W.size_from_dim(canonical_axis_w);
187 CAFFE_ENFORCE(M * K == X.size());
188 CAFFE_ENFORCE(K * N == W.size());
190 auto* dW = Output(0);
191 auto* db = Output(1);
// Zero-fill db and dW (convert::To<...>(0)); presumably this is the
// empty-batch (M == 0) early-exit path -- the guard and Resize calls fall
// on missing lines 192-196/198/201/203, confirm against upstream.
197 math::Set<T_DB, Context>(
199 convert::To<float, T_DB>(0),
200 db->template mutable_data<T_DB>(),
202 math::Set<T_DW, Context>(
204 convert::To<float, T_DW>(0),
205 dW->template mutable_data<T_DW>(),
// dX is only produced when the op is configured with a third output;
// allocate it here.
208 if (OutputSize() == 3) {
209 auto* dX = Output(2);
211 dX->template mutable_data<T_DX>();
// Select fp16 math when the MATH accumulator type is a half-float type.
218 TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
219 if (fp16_type<MATH>()) {
220 math_type = TensorProto_DataType_FLOAT16;
// Weight gradient: dW = dY' * X when TransposeWeight (result N x K),
// otherwise X' * dY (result K x N).  Operand order and output dims swap
// together; transpose flags and alpha/beta are on missing lines 225-226,
// 229-230, 233.
224 math::Gemm<T_DY, Context, Engine>(
227 TransposeWeight ? N : K,
228 TransposeWeight ? K : N,
231 TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
232 TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
234 dW->template mutable_data<T_DW>(),
// Lazily (re)build a length-M vector of ones, used to sum dY over rows.
237 if (bias_multiplier_.size() != M) {
240 bias_multiplier_.Resize(M);
241 math::Set<T_B, Context>(
243 convert::To<float, T_B>(1),
244 bias_multiplier_.template mutable_data<T_B>(),
// Bias gradient: db = dY' * ones(M), i.e. dY summed over the batch rows,
// via a single Gemv (transpose/dims on missing lines 249-252).
248 math::Gemv<T_DY, Context>(
253 dY.template data<T_DY>(),
254 bias_multiplier_.template data<T_B>(),
256 db->template mutable_data<T_DB>(),
// Input gradient (optional): dX(M x K) = dY(M x N) * op(W), with W NOT
// transposed when TransposeWeight (W is N x K) and transposed otherwise.
260 if (OutputSize() == 3) {
261 auto* dX = Output(2);
263 math::Gemm<T_DX, Context, Engine>(
265 TransposeWeight ? CblasNoTrans : CblasTrans,
270 dY.template data<T_DY>(),
271 W.template data<T_W>(),
273 dX->template mutable_data<T_DX>(),
// Untyped entry point: dispatches to DoRunWithType (the concrete type
// list, original lines 282-295, is missing from this extraction).
280 bool RunOnDevice()
override {
281 return DoRunWithType<
// fp16-compute flag read in the ctor (other members fall in missing
// line ranges).
296 bool float16_compute_;
301 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...