Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_decomposition.h
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
/*
 * Although FC_decomp is functionally just two small FC layers chained
 * together, it is better to keep it as a single op for future analysis.
 * Note that naively composing it from two FC ops would also apply the
 * bias term twice, which would be incorrect.
 * TODO(wyiming): decompose the layer into 2 matrices:
 * W(N * K) = U(N * middle) * trans(V(K * middle))
 * */
16 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
17 template <typename T, class Context, class Engine=DefaultEngine>
18 class FullyConnectedOpDecomp final : public Operator<Context> {
19  public:
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21  FullyConnectedOpDecomp(const OperatorDef& operator_def, Workspace* ws)
22  : Operator<Context>(operator_def, ws) {}
24 
25  bool RunOnDevice() override {
26  const auto& X = Input(0);
27  const auto& U = Input(1);
28  const auto& V = Input(2);
29  const auto& b = Input(3);
30  auto* Y = Output(0);
31  //auto* buffer_ptr = Output(1);
32  // Size M * middle;
33  //auto& multi_buffer_ = *buffer_ptr;
34  CAFFE_ENFORCE_GE(X.ndim(), 1);
35  CAFFE_ENFORCE_GE(U.ndim(), 2);
36  CAFFE_ENFORCE_GE(V.ndim(), 2);
37  if (X.ndim() > 2 || U.ndim() > 2 || V.ndim() > 2) {
38  VLOG(1) << "Using legacy support for arbitrary input and weight "
39  "dimensions.";
40  }
41  CAFFE_ENFORCE_EQ(b.ndim(), 1);
42  // batch size
43  int M = X.ndim() > 1 ? X.dim32(0) : 1;
44  // Feature dimension
45  int K = X.size() / M;
46  // number of outputs.
47  int N = U.dim32(0);
48  int middle = U.dim32(0);
49  CAFFE_ENFORCE_EQ(K, V.dim32(0));
50  CAFFE_ENFORCE_EQ(N, b.dim32(0));
51  if (X.ndim() > 1) {
52  Y->Resize(M, N);
53  multi_buffer_.Resize(M, middle);
54  } else {
55  Y->Resize(N);
56  multi_buffer_.Resize(middle);
57  }
58  // The col buffer is stored in CHW order as well - kernel_dim, and the height
59  // and width.
60  // multi_buffer_.Resize(M, middle);
61  T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
62  // X * V * tans(U)
63  math::Gemm<T, Context, Engine>(
64  CblasNoTrans, CblasNoTrans, M, middle, K, 1, X.template data<T>(),
65  V.template data<T>(), 0, multi_buffer_data,
66  &context_);
67  math::Gemm<T, Context, Engine>(
68  CblasNoTrans, CblasTrans, M, N, middle, 1, multi_buffer_data,
69  U.template data<T>(), 0, Y->template mutable_data<T>(),
70  &context_);
71  // Add bias term
72  if (bias_multiplier_.size() != M) {
73  // If the helper bias multiplier is not M, reshape and fill it with one.
74  bias_multiplier_.Resize(M);
75  math::Set<T, Context>(
76  M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
77  &context_);
78  }
79  math::Gemm<T, Context, Engine>(
80  CblasNoTrans, CblasNoTrans, M, N, 1, 1,
81  bias_multiplier_.template data<T>(), b.template data<T>(), 1,
82  Y->template mutable_data<T>(), &context_);
83  return true;
84  }
85 
86  protected:
87  Tensor<Context> bias_multiplier_;
88  Tensor<Context> multi_buffer_;
89 };
90 
91 template <typename T, class Context, class Engine=DefaultEngine>
92 class FullyConnectedDecompGradientOp : public Operator<Context> {
93  public:
94  USE_OPERATOR_CONTEXT_FUNCTIONS;
95  FullyConnectedDecompGradientOp(const OperatorDef& operator_def, Workspace* ws)
96  : Operator<Context>(operator_def, ws) {}
98 
99  bool RunOnDevice() override {
100  const auto& X = Input(0);
101  const auto& U = Input(1);
102  const auto& V = Input(2);
103  const auto& dY = Input(3);
104  DCHECK_GE(X.ndim(), 1);
105  DCHECK_GE(U.ndim(), 2);
106  DCHECK_GE(V.ndim(), 2);
107  DCHECK_LE(dY.ndim(), 2);
108  // batch size
109  int M = X.ndim() > 1 ? X.dim32(0) : 1;
110  // Feature dimension
111  int K = X.size() / M;
112  // number of outputs.
113  int N = U.dim32(0);
114  int middle = U.dim32(1);
115  DCHECK_EQ(K, V.dim32(0));
116  if (dY.ndim() > 1) {
117  DCHECK_EQ(M, dY.dim32(0));
118  DCHECK_EQ(N, dY.dim32(1));
119  } else {
120  DCHECK_EQ(X.ndim(), 1);
121  DCHECK_EQ(N, dY.size());
122  }
123  auto* dU = Output(0);
124  auto* dV = Output(1);
125  auto* db = Output(2);
126  dU->ResizeLike(U);
127  dV->ResizeLike(V);
128  db->Resize(N);
129 
130  // Compute dU
131  // first compute X * V
132  du_buffer_.Resize(N, middle);
133  T* du_buffer_data = du_buffer_.template mutable_data<T>();
134  math::Gemm<T, Context, Engine>(
135  CblasNoTrans, CblasNoTrans, M, middle, K, 1,
136  X.template data<T>(), V.template data<T>(),
137  0, du_buffer_data,
138  &context_);
139  math::Gemm<T, Context, Engine>(
140  CblasTrans, CblasNoTrans, N, middle, M, 1,
141  dY.template data<T>(), du_buffer_data,
142  0, dU->template mutable_data<T>(),
143  &context_);
144  // Compute dV
145  // first compute dY * U
146  dv_buffer_.Resize(M, middle);
147  T* dv_buffer_data = dv_buffer_.template mutable_data<T>();
148  math::Gemm<T, Context, Engine>(
149  CblasNoTrans, CblasNoTrans, M, middle, N, 1,
150  dY.template data<T>(), U.template data<T>(),
151  0, dv_buffer_data,
152  &context_);
153  math::Gemm<T, Context, Engine>(
154  CblasTrans, CblasNoTrans, K, middle, M, 1,
155  dY.template data<T>(), du_buffer_data,
156  0, dV->template mutable_data<T>(),
157  &context_);
158  if (bias_multiplier_.size() != M) {
159  // If the helper bias multiplier is not M, reshape and fill it with one.
160  bias_multiplier_.Resize(M);
161  math::Set<T, Context>(
162  M, static_cast<T>(1),
163  bias_multiplier_.template mutable_data<T>(),
164  &context_);
165  }
166  // Compute dB
167  math::Gemv<T, Context>(
168  CblasTrans, M, N, 1, dY.template data<T>(),
169  bias_multiplier_.template data<T>(), 0,
170  db->template mutable_data<T>(),
171  &context_);
172  // Compute dX if necessary.
173  if (OutputSize() == 4) {
174  auto* dX = Output(3);
175  dX->ResizeLike(X);
176  dx_buffer_.Resize(M, middle);
177  T* dx_buffer_data = dx_buffer_.template mutable_data<T>();
178  math::Gemm<T, Context, Engine>(
179  CblasNoTrans, CblasNoTrans, M, middle, N, 1,
180  dY.template data<T>(), U.template data<T>(),
181  0, dx_buffer_data,
182  &context_);
183  math::Gemm<T, Context, Engine>(
184  CblasNoTrans, CblasTrans, M, K, middle, 1,
185  dx_buffer_data, V.template data<T>(),
186  0, dX->template mutable_data<T>(),
187  &context_);
188  }
189 
190  return true;
191  }
192 
193  protected:
194  Tensor<Context> bias_multiplier_;
195  Tensor<Context> du_buffer_;
196  Tensor<Context> dv_buffer_;
197  Tensor<Context> dx_buffer_;
198 };
199 
200 } // namespace caffe2
201 
#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_DECOMPOSITION_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...