Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_prune.h
1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10  namespace {
11 
12  template<int N>
13  using Shape = std::array<int, N>;
14 
15  template<int N>
16  const std::vector<TIndex>& shape(Shape<N> vs) {
17  static thread_local std::vector<TIndex> cache;
18  cache.resize(vs.size());
19  for (auto i = 0; i < vs.size(); ++i) {
20  cache[i] = vs[i];
21  }
22  return cache;
23  }
24 
25  inline const std::vector<TIndex>& shape(int i) {
26  return shape<1>(Shape<1>({i}));
27  }
28 
29  inline const std::vector<TIndex>& shape(int i, int j) {
30  return shape<2>(Shape<2>({i, j}));
31  }
32 
// Zero every element of `mat` whose corresponding entry in `mask` is zero.
// `mask` and `mat` are row-major M x N buffers of the same shape.
template <typename T, class Context>
void MaskMatrix(const T* mask, T* mat,
                int M, int N);

// Set mat[mask_seq[k]] = target for k in [0, seq_len). `mask_seq` holds
// flattened indices (stored as T) into the M x N row-major matrix `mat`.
template <typename T, class Context>
void MaskMatrix_Inc(T* mask_seq, T* mat,
                    int M, int N, int seq_len, T target);

// Elementwise accumulate: ag_dw += dw over the N*K weight-gradient buffer.
template <typename T, class Context>
void AggrDW(T* ag_dw, const T* dw, int N, int K, Context* context);

// Record in `mask_seq` the flat index (as T) of every nonzero entry of `mat`
// whose magnitude is strictly below `thres`; returns the number recorded.
template <typename T>
int MatrixCompare_LT(const T* mat, float thres,
                     T* mask_seq, int M, int N);
47 
48  // TODO(wyiming): write an incremental Mask
49  // Incremental Mask: only give the new mask positions;
50  // Assuming that weights masked will not be mask again;
51  // The incremental mask can also be used to update mask matrix;
52  // But this will include template for bool and float;
53  template <>
54  void MaskMatrix<float, CPUContext>(
55  const float* mask, float* mat, int M, int N) {
56  int offset = 0;
57  for (int i = 0; i < M; ++i) {
58  for (int j = 0; j < N; ++j) {
59  mat[offset] = mask[offset]? mat[offset] : 0;
60  offset++;
61  }
62  }
63  }
64 
65  template <>
66  void MaskMatrix_Inc<float, CPUContext>(
67  float* mask_seq,
68  float* mat,
69  int /*M*/,
70  int /*N*/,
71  int seq_len,
72  float target) {
73  for (int i = 0; i < seq_len; ++i) {
74  // assume that the mask_seq is smaller than size
75  // Although it seems that random access gets bad performance,
76  // we make sure that seq is in order;
77  mat[static_cast<int>(mask_seq[i])] = target;
78  }
79  }
80 
// CPU specialization: ag_dw += dw, elementwise over the N*K buffer.
// Used to aggregate per-iteration weight gradients between pruning steps.
template <>
void AggrDW<float, CPUContext>(
    float* ag_dw, const float* dw,
    int N, int K, CPUContext* context) {
  math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
}
87 
88  template <>
89  int MatrixCompare_LT<float>(
90  const float* mat, float thres,
91  float* mask_seq, int M, int N) {
92  int seq_len = 0;
93  int offset = 0;
94  for (int i = 0 ; i < M; ++i) {
95  for (int j = 0; j < N; ++j) {
96  if (mat[offset] != 0 &&
97  (mat[offset] < thres && mat[offset] > -thres)) {
98  mask_seq[seq_len++] = static_cast<float>(offset);
99  }
100  offset++;
101  }
102  }
103  return seq_len;
104  }
105 
106  }
107 
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
// Pruned variant: takes an extra Mask input alongside X, W and b.
// Inputs:  0: X (M x K), 1: W (N x K), 2: Mask, 3: b (N).
// Outputs: 0: Y (M x N); optional 1: scalar compression rate (mean of Mask).
// NOTE(review): the forward pass multiplies by W directly without applying
// Mask here — it appears to rely on W being masked in place by the gradient
// op; confirm against the training setup.
template <typename T, class Context, class Engine=DefaultEngine>
class FullyConnectedOpPrune final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  FullyConnectedOpPrune(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  bool RunOnDevice() override {
    const auto& X = Input(0);
    const auto& W = Input(1);
    const auto& Mask = Input(2);
    const auto& b = Input(3);
    auto* Y = Output(0);
    CAFFE_ENFORCE_GE(X.ndim(), 1);
    CAFFE_ENFORCE_GE(W.ndim(), 2);
    if (X.ndim() > 2 || W.ndim() > 2) {
      VLOG(1) << "Using legacy support for arbitrary input and weight "
                 "dimensions.";
    }
    CAFFE_ENFORCE_EQ(b.ndim(), 1);
    // batch size (a 1-D X is treated as a single example)
    int M = X.ndim() > 1 ? X.dim32(0) : 1;
    // Feature dimension
    int K = X.size() / M;
    // number of outputs.
    int N = W.dim32(0);
    CAFFE_ENFORCE_EQ(K, W.size() / W.dim32(0));
    CAFFE_ENFORCE_EQ(N, b.dim32(0));
    if (X.ndim() > 1) {
      Y->Resize(M, N);
    } else {
      Y->Resize(N);
    }
    // W * x: Y(M x N) = X(M x K) * W^T(K x N); W is stored N x K, hence
    // the CblasTrans on the second operand.
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
        W.template data<T>(), 0, Y->template mutable_data<T>(),
        &context_);
    // Add bias term by Gemm-ing a length-M ones column against b, which
    // broadcasts the bias row to every example (beta = 1 accumulates into Y).
    if (bias_multiplier_.size() != M) {
      // If the helper bias multiplier is not M,
      // reshape and fill it with one.
      bias_multiplier_.Resize(M);
      math::Set<T, Context>(
          M, static_cast<T>(1),
          bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    math::Gemm<T, Context, Engine>(
        CblasNoTrans, CblasNoTrans, M, N, 1, 1,
        bias_multiplier_.template data<T>(), b.template data<T>(), 1,
        Y->template mutable_data<T>(), &context_);
    // Optional second output: compression rate = sum(Mask) / size(Mask)
    // (assuming 0/1 mask entries, this is the fraction still active —
    // TODO confirm mask values are binary).
    if (OutputSize() == 2){
      auto* Comp_rate = Output(1);
      Comp_rate->Resize(vector<TIndex>());
      T* comp_data = Comp_rate->template mutable_data<T>();
      math::Sum<T, Context>(
          Mask.size(), Mask.template data<T>(), comp_data, &context_);
      math::Scale<T, Context>(
          1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
          &context_);
    }
    return true;
  }

 protected:
  // Length-M ones vector used to broadcast the bias via Gemm; lazily
  // (re)sized when the batch size changes.
  Tensor<Context> bias_multiplier_;
};
178 
179  template <typename T, class Context, class Engine=DefaultEngine>
180  class FullyConnectedPruneGradientOp : public Operator<Context> {
181  public:
182  int iter_offset;
183  public:
184  USE_OPERATOR_CONTEXT_FUNCTIONS;
186  (const OperatorDef& operator_def, Workspace* ws)
187  : Operator<Context>(operator_def, ws) { iter_offset = 0; }
189 
190  bool RunOnDevice() override {
191  const auto& X = Input(0);
192  //const auto& W = Input(1);
193  auto* W_ptr = Output(2);
194  auto& W = *W_ptr;
195  //const auto& Mask = Input(2);
196  auto* Mask_ptr = Output(3);
197  auto& Mask = *Mask_ptr;
198  const auto& dY = Input(3);
199  //const auto& Ag_dW = Input(4);
200  auto* Ag_dW_ptr = Output(4);
201  auto& Ag_dW = *Ag_dW_ptr;
202  // it is also the Input(5)
203  auto* mask_seq_auto = Output(5);
204  // how about get threshold
205  auto& thres = Input(6);
206  //TODO(wyiming): check comp_lb is a float
207  auto& comp_lb = Input(7);
208  DCHECK_GE(X.ndim(), 1);
209  DCHECK_GE(W.ndim(), 2);
210  DCHECK_LE(dY.ndim(), 2);
211  // batch size
212  int M = X.ndim() > 1 ? X.dim32(0) : 1;
213  // Feature dimension
214  int K = X.size() / M;
215  // number of outputs.
216  int N = W.dim32(0);
217  // TODO(wyiming): add this window_size to workspace?
218  int window_size = 100;
219  // TODO(wyiming): this threshold should be
220  // based on distribution of the layer weight
221  float thr = 0.01;
222  DCHECK_EQ(Mask.dim32(0), W.dim32(0));
223  DCHECK_EQ(Mask.dim32(1), W.dim32(1));
224  DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
225  DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
226  DCHECK_EQ(K, W.size() / W.dim32(0));
227  if (dY.ndim() > 1) {
228  DCHECK_EQ(M, dY.dim32(0));
229  DCHECK_EQ(N, dY.dim32(1));
230  } else {
231  DCHECK_EQ(X.ndim(), 1);
232  DCHECK_EQ(N, dY.size());
233  }
234  auto* dW = Output(0);
235  auto* db = Output(1);
236  dW->ResizeLike(W);
237  db->Resize(N);
238 
239  // Compute dW
240  math::Gemm<T, Context, Engine>(
241  CblasTrans, CblasNoTrans, N, K, M, 1,
242  dY.template data<T>(), X.template data<T>(),
243  0, dW->template mutable_data<T>(),
244  &context_);
245 
246  comp_r_buf_.Resize(vector<TIndex>());
247  T* comp_data = comp_r_buf_.template mutable_data<T>();
248  math::Sum<T, Context>(
249  Mask.size(), Mask.template data<T>(), comp_data, &context_);
250  math::Scale<T, Context>(
251  1, static_cast<T>(1.) / Mask.size(), comp_data, comp_data,
252  &context_);
253  // update W size window
254  // Notice here we need to maintain state in OP.
255  // This is new in Caffe2.
256  // And this is something we might need to discuss in the future.
257  // at most mask half of the matrix at time
258  // 1. mask dw with previous mask
259  MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
260  dW->template mutable_data<T>(), N, K);
261  if(*comp_data > *(comp_lb.template data<T>())){
262  iter_offset++;
263  if (iter_offset % window_size == 0) {
264  // TODO(wyiming):do the prune here;
265  sum_buffer_.ResizeLike(W);
266  math::Add<T, Context>(W.size(),
267  W.template mutable_data<T>(),
268  Ag_dW.template mutable_data<T>(),
269  sum_buffer_.template mutable_data<T>(),
270  &context_);
271  mask_seq_auto->ResizeLike(W);
272  T* mask_seq = mask_seq_auto->template mutable_data<T>();
273  math::Set<T, Context>(N*K, static_cast<T>(0),
274  mask_seq_auto->template mutable_data<T>(), &context_);
275  // 2. find dw below thres but not eq 0
276  int seq_len = MatrixCompare_LT<T>(
277  Ag_dW_ptr->template mutable_data<T>(),
278  *thres.template data<T>(), mask_seq, N, K);
279  // 3. use the mask_seq to update W and dw
280  MaskMatrix_Inc<T, Context>(mask_seq,
281  dW->template mutable_data<T>(),
282  N, K, seq_len, 0);
283  MaskMatrix_Inc<T, Context>(mask_seq,
284  W.template mutable_data<T>(),
285  N, K, seq_len, 0);
286  MaskMatrix_Inc<T, Context>(mask_seq,
287  Mask.template mutable_data<T>(),
288  N, K, seq_len, 0);
289  math::Set<T, Context>(N*K, static_cast<T>(0),
290  Ag_dW.template mutable_data<T>(),
291  &context_);
292  } else {
293  // add dW to Aggregate dW.
294  AggrDW<T, Context>(
295  Ag_dW.template mutable_data<T>(),
296  dW->template mutable_data<T>(),
297  N, K, &context_);
298  }
299  }
300  if (bias_multiplier_.size() != M) {
301  // If the helper bias multiplier is not M,
302  // reshape and fill it with one.
303  bias_multiplier_.Resize(M);
304  math::Set<T, Context>(
305  M, static_cast<T>(1),
306  bias_multiplier_.template mutable_data<T>(),
307  &context_);
308  }
309  // Compute dB
310  math::Gemv<T, Context>(
311  CblasTrans, M, N, 1, dY.template data<T>(),
312  bias_multiplier_.template data<T>(), 0,
313  db->template mutable_data<T>(),
314  &context_);
315  // Compute dX if necessary.
316  if (OutputSize() == 7) {
317  auto* dX = Output(6);
318  dX->ResizeLike(X);
319  math::Gemm<T, Context, Engine>(
320  CblasNoTrans, CblasNoTrans, M, K, N, 1,
321  dY.template data<T>(), W.template data<T>(),
322  0, dX->template mutable_data<T>(),
323  &context_);
324  }
325 
326  return true;
327  }
328 
329  protected:
330  Tensor<Context> bias_multiplier_;
331  Tensor<Context> sum_buffer_;
332  Tensor<Context> comp_r_buf_;
333  };
334 
335 } // namespace caffe2
336 
#endif  // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.