Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_op.h
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/conversions.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
12 template <
13  class Context,
14  class Engine = DefaultEngine,
15  bool TransposeWeight = true>
16 class FullyConnectedOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19  FullyConnectedOp(const OperatorDef& operator_def, Workspace* ws)
20  : Operator<Context>(operator_def, ws),
21  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
22  axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
23  float16_compute_(
24  OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
25  ~FullyConnectedOp() {}
26 
27  template <
28  typename T_X,
29  typename T_W,
30  typename T_B,
31  typename T_Y,
32  typename MATH>
33  bool DoRunWithType() {
34  const auto& X = Input(0);
35  const auto& W = Input(1);
36  const auto& b = Input(2);
37  auto* Y = Output(0);
38  CAFFE_ENFORCE(b.ndim() == 1, b.ndim());
39  // batch size
40  const auto canonical_axis = X.canonical_axis_index(axis_);
41  const auto M = X.size_to_dim(canonical_axis);
42  const auto K = X.size_from_dim(canonical_axis);
43  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
44  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
45  : W.size_from_dim(canonical_axis_w);
46 
47  auto dimErrorString = [&]() {
48  return MakeString(
49  "Dimension mismatch: ",
50  "X: ",
51  X.dims(),
52  ", W: ",
53  W.dims(),
54  ", b: ",
55  b.dims(),
56  ", axis: ",
57  axis_,
58  ", M: ",
59  M,
60  ", N: ",
61  N,
62  ", K: ",
63  K);
64  };
65 
66  // Error checking
67  CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
68  CAFFE_ENFORCE(K == W.size() / N, dimErrorString());
69  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
70  CAFFE_ENFORCE(N == b.size(), dimErrorString());
71 
72  Y_shape_cache_ = X.dims();
73  // This is an invariant of canonical_axis, so we can DCHECK.
74  DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
75  Y_shape_cache_.resize(canonical_axis + 1);
76  Y_shape_cache_[canonical_axis] = N;
77  Y->Resize(Y_shape_cache_);
78  CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
79 
80  if (X.size() == 0) {
81  // skip the rest of the computation if X is empty
82  Y->template mutable_data<T_Y>();
83  return true;
84  }
85 
86  // default to FLOAT as math.h does.
87  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
88  if (fp16_type<MATH>()) {
89  math_type = TensorProto_DataType_FLOAT16;
90  }
91 
92  // W * x
93  math::Gemm<T_X, Context, Engine>(
94  CblasNoTrans,
95  TransposeWeight ? CblasTrans : CblasNoTrans,
96  M,
97  N,
98  K,
99  1,
100  X.template data<T_X>(),
101  W.template data<T_W>(),
102  0,
103  Y->template mutable_data<T_Y>(),
104  &context_,
105  math_type);
106  // Add bias term
107  if (bias_multiplier_.size() != M) {
108  // If the helper bias multiplier is not M, reshape and fill it with one.
109  bias_multiplier_.Resize(M);
110  math::Set<T_B, Context>(
111  M,
112  convert::To<float, T_B>(1),
113  bias_multiplier_.template mutable_data<T_B>(),
114  &context_);
115  }
116  math::Gemm<T_B, Context, Engine>(
117  CblasNoTrans,
118  CblasNoTrans,
119  M,
120  N,
121  1,
122  1,
123  bias_multiplier_.template data<T_B>(),
124  b.template data<T_B>(),
125  1,
126  Y->template mutable_data<T_Y>(),
127  &context_,
128  math_type);
129  return true;
130  }
131 
132  bool RunOnDevice() override {
133  return DoRunWithType<
134  float, // X
135  float, // W
136  float, // B
137  float, // Y
138  float>(); // Math
139  }
140 
141  protected:
142  size_t axis_{1};
143  size_t axis_w_{1};
144  // A local vector to cache the output shape so we don't need to recreate
145  // a vector object every time we run Run().
146  vector<TIndex> Y_shape_cache_;
147  Tensor<Context> bias_multiplier_;
148 
149  bool float16_compute_;
150 };
151 
152 template <
153  class Context,
154  class Engine = DefaultEngine,
155  bool TransposeWeight = true>
156 class FullyConnectedGradientOp : public Operator<Context> {
157  public:
158  USE_OPERATOR_CONTEXT_FUNCTIONS;
159  FullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws)
160  : Operator<Context>(operator_def, ws),
161  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
162  axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
163  float16_compute_(
164  OperatorBase::GetSingleArgument<bool>("float16_compute", false)) {}
166 
167  template <
168  typename T_X,
169  typename T_W,
170  typename T_DY,
171  typename T_B,
172  typename T_DX,
173  typename T_DW,
174  typename T_DB,
175  typename MATH>
176  bool DoRunWithType() {
177  const auto& X = Input(0);
178  const auto& W = Input(1);
179  const auto& dY = Input(2);
180  // batch size
181  const auto canonical_axis = X.canonical_axis_index(axis_);
182  const int M = X.size_to_dim(canonical_axis);
183  const int K = X.size_from_dim(canonical_axis);
184  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
185  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
186  : W.size_from_dim(canonical_axis_w);
187  CAFFE_ENFORCE(M * K == X.size());
188  CAFFE_ENFORCE(K * N == W.size());
189 
190  auto* dW = Output(0);
191  auto* db = Output(1);
192  dW->ResizeLike(W);
193  db->Resize(N);
194 
195  if (X.size() == 0) {
196  // generate a zero blob for db and dW when X is empty
197  math::Set<T_DB, Context>(
198  db->size(),
199  convert::To<float, T_DB>(0),
200  db->template mutable_data<T_DB>(),
201  &context_);
202  math::Set<T_DW, Context>(
203  dW->size(),
204  convert::To<float, T_DW>(0),
205  dW->template mutable_data<T_DW>(),
206  &context_);
207 
208  if (OutputSize() == 3) {
209  auto* dX = Output(2);
210  dX->ResizeLike(X);
211  dX->template mutable_data<T_DX>();
212  }
213 
214  return true;
215  }
216 
217  // default to FLOAT as math.h does.
218  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
219  if (fp16_type<MATH>()) {
220  math_type = TensorProto_DataType_FLOAT16;
221  }
222 
223  // Compute dW
224  math::Gemm<T_DY, Context, Engine>(
225  CblasTrans,
226  CblasNoTrans,
227  TransposeWeight ? N : K,
228  TransposeWeight ? K : N,
229  M,
230  1,
231  TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
232  TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
233  0,
234  dW->template mutable_data<T_DW>(),
235  &context_,
236  math_type);
237  if (bias_multiplier_.size() != M) {
238  // If the helper bias multiplier is not M, reshape and fill it
239  // with one.
240  bias_multiplier_.Resize(M);
241  math::Set<T_B, Context>(
242  M,
243  convert::To<float, T_B>(1),
244  bias_multiplier_.template mutable_data<T_B>(),
245  &context_);
246  }
247  // Compute dB
248  math::Gemv<T_DY, Context>(
249  CblasTrans,
250  M,
251  N,
252  1,
253  dY.template data<T_DY>(),
254  bias_multiplier_.template data<T_B>(),
255  0,
256  db->template mutable_data<T_DB>(),
257  &context_);
258 
259  // Compute dX
260  if (OutputSize() == 3) {
261  auto* dX = Output(2);
262  dX->ResizeLike(X);
263  math::Gemm<T_DX, Context, Engine>(
264  CblasNoTrans,
265  TransposeWeight ? CblasNoTrans : CblasTrans,
266  M,
267  K,
268  N,
269  1,
270  dY.template data<T_DY>(),
271  W.template data<T_W>(),
272  0,
273  dX->template mutable_data<T_DX>(),
274  &context_,
275  math_type);
276  }
277  return true;
278  }
279 
280  bool RunOnDevice() override {
281  return DoRunWithType<
282  float, // X
283  float, // W
284  float, // dY
285  float, // B
286  float, // dX
287  float, // dW
288  float, // dB
289  float>(); // Math
290  }
291 
292  protected:
293  size_t axis_{1};
294  size_t axis_w_{1};
295  Tensor<Context> bias_multiplier_;
296  bool float16_compute_;
297 };
298 
299 } // namespace caffe2
300 
301 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information.
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.