1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ 2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 13 using Shape = std::array<int, N>;
// Converts a fixed-size Shape<N> (an std::array<int, N>) into a
// std::vector<TIndex>; declared to return a const reference to such a vector.
// A function-local `static thread_local` vector is resized to vs.size() on
// every call; presumably the returned reference refers to that cache, in which
// case it is overwritten by the next call on the same thread -- TODO confirm.
// NOTE(review): `auto i = 0` deduces int while vs.size() is unsigned, so the
// loop condition is a signed/unsigned comparison.
// NOTE(review): the loop body and the return statement are not visible in
// this listing (the embedded line numbers jump from 19 to 25).
16 const std::vector<TIndex>& shape(Shape<N> vs) {
17 static thread_local std::vector<TIndex> cache;
18 cache.resize(vs.size());
19 for (
auto i = 0; i < vs.size(); ++i) {
// Convenience overload for a single extent: wraps `i` in a one-element
// Shape<1> and forwards to shape<1>, returning whatever reference that yields.
25 inline const std::vector<TIndex>& shape(
int i) {
26 return shape<1>(Shape<1>({i}));
// Convenience overload for two extents: wraps (i, j) in a Shape<2> and
// forwards to shape<2>, returning whatever reference that yields.
29 inline const std::vector<TIndex>& shape(
int i,
int j) {
30 return shape<2>(Shape<2>({i, j}));
// Declaration of MaskMatrix, templated on the element type T and the Context.
//   mask - read-only mask buffer.
//   mat  - buffer modified in place.
// The float/CPUContext specialization in this file writes 0 into every mat
// entry whose mask entry is zero and leaves the others untouched.
// NOTE(review): the rest of the parameter list is not visible in this listing;
// the specialization takes (mask, mat, int M, int N).
33 template <
typename T,
class Context>
34 void MaskMatrix(
const T* mask, T* mat,
// Declaration of MaskMatrix_Inc, templated on the element type T and Context.
//   mask_seq - sequence of positions; the float/CPUContext specialization in
//              this file casts each entry to int and uses it as an index
//              into mat.
//   mat      - buffer written at those positions.
//   M, N     - presumably the matrix dimensions -- TODO confirm; the visible
//              lines of the specialization do not read them.
//   seq_len  - number of mask_seq entries to process.
//   target   - value stored at each indexed position.
37 template <
typename T,
class Context>
38 void MaskMatrix_Inc(T* mask_seq, T* mat,
39 int M,
int N,
int seq_len, T target);
// Declaration of AggrDW, templated on the element type T and the Context.
//   ag_dw   - buffer updated in place (the aggregate).
//   dw      - read-only buffer combined into ag_dw.
//   N, K    - dimensions; the float/CPUContext specialization in this file
//             processes N*K elements.
//   context - device context forwarded to the math routine.
41 template <
typename T,
class Context>
42 void AggrDW(T* ag_dw,
const T* dw,
int N,
int K, Context* context);
// Declaration of MatrixCompare_LT (its template header line is not visible in
// this listing; the function is specialized for float later in the file).
//   mat      - read-only matrix to scan.
//   thres    - magnitude threshold.
//   mask_seq - output buffer receiving the selected positions.
//   M, N     - loop bounds of the scan.
// Returns an int; in the float specialization this is presumably the number
// of positions written to mask_seq -- TODO confirm (the return statement is
// not visible).
45 int MatrixCompare_LT(
const T* mat,
float thres,
46 T* mask_seq,
int M,
int N);
// float/CPUContext specialization of MaskMatrix.
// Iterates over an M x N index space and, for each `offset`, keeps
// mat[offset] when mask[offset] is non-zero and overwrites it with 0 otherwise.
// NOTE(review): the declaration and increment of `offset` are not visible in
// this listing; presumably it is a running linear index over the M*N
// elements -- TODO confirm.
54 void MaskMatrix<float, CPUContext>(
55 const float* mask,
float* mat,
int M,
int N) {
57 for (
int i = 0; i < M; ++i) {
58 for (
int j = 0; j < N; ++j) {
59 mat[offset] = mask[offset]? mat[offset] : 0;
// float/CPUContext specialization of MaskMatrix_Inc.
// For each of the first seq_len entries of mask_seq, converts the entry to an
// int index and stores `target` at that position of mat.
// NOTE(review): the parameter list is not visible in this listing (embedded
// line numbers jump from 66 to 73).
// NOTE(review): indices travel through float storage; a float represents
// integers exactly only up to 2^24, so larger offsets could be rounded --
// confirm the matrices stay below that size.
// NOTE(review): no bounds check on the index is visible here.
66 void MaskMatrix_Inc<float, CPUContext>(
73 for (
int i = 0; i < seq_len; ++i) {
77 mat[
static_cast<int>(mask_seq[i])] = target;
// float/CPUContext specialization of AggrDW.
// Delegates to math::Add over N*K elements with inputs (dw, ag_dw) and
// ag_dw as the destination, i.e. presumably ag_dw += dw element-wise --
// TODO confirm the (n, a, b, y) argument order of math::Add.
82 void AggrDW<float, CPUContext>(
83 float* ag_dw,
const float* dw,
84 int N,
int K, CPUContext* context) {
85 math::Add<float, CPUContext>(N*K, dw, ag_dw, ag_dw, context);
// float specialization of MatrixCompare_LT.
// Scans an M x N index space; every entry that is non-zero and strictly
// inside the open interval (-thres, thres) has its `offset` appended to
// mask_seq (stored as a float), advancing seq_len.
// NOTE(review): the declarations of `seq_len` and `offset`, the offset
// increment and the return statement are not visible in this listing;
// presumably the function returns seq_len -- TODO confirm.
// NOTE(review): offsets are stored as float, which is exact only up to 2^24.
89 int MatrixCompare_LT<float>(
90 const float* mat,
float thres,
91 float* mask_seq,
int M,
int N) {
94 for (
int i = 0 ; i < M; ++i) {
95 for (
int j = 0; j < N; ++j) {
// Select entries that are non-zero but smaller in magnitude than thres.
96 if (mat[offset] != 0 &&
97 (mat[offset] < thres && mat[offset] > -thres)) {
98 mask_seq[seq_len++] =
static_cast<float>(offset);
109 template <
typename T,
class Context,
class Engine=DefaultEngine>
112 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Forward pass of the fully-connected operator declared in this header.
// Reads X, W, Mask and b, runs two GEMM calls that write Y, and, when a
// second output is requested, also writes the mean of the Mask entries.
// NOTE(review): several statements (including the definitions of N and Y,
// closing braces and the return) are not visible in this listing; the
// comments below describe only the visible lines.
117 bool RunOnDevice()
override {
// Inputs: 0 = X, 1 = W, 2 = Mask, 3 = b.
118 const auto& X = Input(0);
119 const auto& W = Input(1);
120 const auto& Mask = Input(2);
121 const auto& b = Input(3);
// X must have at least 1 dimension and W at least 2; higher-rank inputs are
// accepted but logged as legacy usage.
123 CAFFE_ENFORCE_GE(X.ndim(), 1);
124 CAFFE_ENFORCE_GE(W.ndim(), 2);
125 if (X.ndim() > 2 || W.ndim() > 2) {
126 VLOG(1) <<
"Using legacy support for arbitrary input and weight " 129 CAFFE_ENFORCE_EQ(b.ndim(), 1);
// M is the leading extent of X (1 for a 1-D X); K is the number of elements
// per row of X.
131 int M = X.ndim() > 1 ? X.dim32(0) : 1;
133 int K = X.size() / M;
// W must hold K elements per leading-dimension entry, and b must have N
// entries (the definition of N is not visible in this listing).
136 CAFFE_ENFORCE_EQ(K, W.size() / W.dim32(0));
137 CAFFE_ENFORCE_EQ(N, b.dim32(0));
// GEMM with X untransposed and W transposed, alpha = 1, beta = 0, writing Y:
// presumably Y = X * W^T -- TODO confirm against math::Gemm.
144 math::Gemm<T, Context, Engine>(
145 CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
146 W.template data<T>(), 0, Y->template mutable_data<T>(),
// (Re)build the length-M vector of ones only when its size does not match M.
149 if (bias_multiplier_.size() != M) {
152 bias_multiplier_.Resize(M);
153 math::Set<T, Context>(
154 M,
static_cast<T
>(1),
155 bias_multiplier_.template mutable_data<T>(),
// Second GEMM with inner dimension 1 and beta = 1: accumulates the product of
// the ones vector and b into Y, i.e. presumably adds b to every row of Y.
158 math::Gemm<T, Context, Engine>(
159 CblasNoTrans, CblasNoTrans, M, N, 1, 1,
160 bias_multiplier_.template data<T>(), b.template data<T>(), 1,
161 Y->template mutable_data<T>(), &context_);
// Optional second output: resized with an empty shape vector, then filled
// with sum(Mask) scaled by 1 / Mask.size().
// NOTE(review): an empty Mask would make this a division by zero -- confirm
// callers never pass one.
162 if (OutputSize() == 2){
163 auto* Comp_rate = Output(1);
164 Comp_rate->Resize(vector<TIndex>());
165 T* comp_data = Comp_rate->template mutable_data<T>();
166 math::Sum<T, Context>(
167 Mask.size(), Mask.template data<T>(), comp_data, &context_);
168 math::Scale<T, Context>(
169 1,
static_cast<T
>(1.) / Mask.size(), comp_data, comp_data,
179 template <
typename T,
class Context,
class Engine=DefaultEngine>
184 USE_OPERATOR_CONTEXT_FUNCTIONS;
186 (
const OperatorDef& operator_def,
Workspace* ws)
// Backward pass of the fully-connected operator declared in this header.
// Visible steps: compute dW from dY and X, mask dW with Mask, periodically
// select small aggregated-gradient entries and overwrite the matching
// positions in dW / W / Mask, compute db, and optionally compute dX.
// NOTE(review): several statements are not visible in this listing (among
// them the definitions of W, N and iter_offset, the tails of some calls,
// closing braces and the return); comments describe only the visible lines.
190 bool RunOnDevice()
override {
// Inputs used here: 0 = X, 3 = dY, 6 = thres, 7 = comp_lb.
// Outputs used here: 0 = dW, 1 = db, 2 = W, 3 = Mask, 4 = Ag_dW,
// 5 = mask_seq_auto, 6 = dX (only when 7 outputs are requested).
// W, Mask and Ag_dW are obtained through Output(), so they are mutable here.
191 const auto& X = Input(0);
193 auto* W_ptr = Output(2);
196 auto* Mask_ptr = Output(3);
197 auto& Mask = *Mask_ptr;
198 const auto& dY = Input(3);
200 auto* Ag_dW_ptr = Output(4);
201 auto& Ag_dW = *Ag_dW_ptr;
203 auto* mask_seq_auto = Output(5);
205 auto& thres = Input(6);
207 auto& comp_lb = Input(7);
// NOTE(review): the forward RunOnDevice validates shapes with CAFFE_ENFORCE_*,
// while this method uses DCHECK_*, which are presumably compiled out in
// release builds -- confirm that skipping these checks there is intended.
208 DCHECK_GE(X.ndim(), 1);
209 DCHECK_GE(W.ndim(), 2);
210 DCHECK_LE(dY.ndim(), 2);
// M is the leading extent of X (1 for a 1-D X); K is the number of elements
// per row of X.
212 int M = X.ndim() > 1 ? X.dim32(0) : 1;
214 int K = X.size() / M;
// The pruning step below runs when iter_offset is a multiple of this value.
218 int window_size = 100;
// Mask and Ag_dW must match W in their first two dimensions.
222 DCHECK_EQ(Mask.dim32(0), W.dim32(0));
223 DCHECK_EQ(Mask.dim32(1), W.dim32(1));
224 DCHECK_EQ(Ag_dW.dim32(0), W.dim32(0));
225 DCHECK_EQ(Ag_dW.dim32(1), W.dim32(1));
226 DCHECK_EQ(K, W.size() / W.dim32(0));
// The two pairs of checks below presumably sit in the two arms of a
// conditional on the rank of dY (the branch lines are not visible).
228 DCHECK_EQ(M, dY.dim32(0));
229 DCHECK_EQ(N, dY.dim32(1));
231 DCHECK_EQ(X.ndim(), 1);
232 DCHECK_EQ(N, dY.size());
234 auto* dW = Output(0);
235 auto* db = Output(1);
// GEMM with dY transposed and X untransposed, alpha = 1, beta = 0, writing
// dW: presumably dW = dY^T * X (N x K) -- TODO confirm against math::Gemm.
240 math::Gemm<T, Context, Engine>(
241 CblasTrans, CblasNoTrans, N, K, M, 1,
242 dY.template data<T>(), X.template data<T>(),
243 0, dW->template mutable_data<T>(),
// Scalar scratch buffer filled with sum(Mask) scaled by 1 / Mask.size().
// NOTE(review): an empty Mask would make this a division by zero.
246 comp_r_buf_.Resize(vector<TIndex>());
247 T* comp_data = comp_r_buf_.template mutable_data<T>();
248 math::Sum<T, Context>(
249 Mask.size(), Mask.template data<T>(), comp_data, &context_);
250 math::Scale<T, Context>(
251 1,
static_cast<T
>(1.) / Mask.size(), comp_data, comp_data,
// Apply the current mask to the freshly computed gradient.
259 MaskMatrix<T, Context>(Mask.template mutable_data<T>(),
260 dW->template mutable_data<T>(), N, K);
// Prune further only while the mean of Mask is above the comp_lb value.
// NOTE(review): comp_data is dereferenced directly on the host; confirm this
// is only instantiated for contexts where that is valid.
261 if(*comp_data > *(comp_lb.template data<T>())){
263 if (iter_offset % window_size == 0) {
// sum_buffer_ receives the math::Add result of W and Ag_dW.
// NOTE(review): sum_buffer_ is not read by any line visible after this
// point; the comparison below scans Ag_dW instead -- confirm intent.
265 sum_buffer_.ResizeLike(W);
266 math::Add<T, Context>(W.size(),
267 W.template mutable_data<T>(),
268 Ag_dW.template mutable_data<T>(),
269 sum_buffer_.template mutable_data<T>(),
// Size the position buffer like W and zero its N*K entries.
271 mask_seq_auto->ResizeLike(W);
272 T* mask_seq = mask_seq_auto->template mutable_data<T>();
273 math::Set<T, Context>(N*K,
static_cast<T
>(0),
274 mask_seq_auto->template mutable_data<T>(), &context_);
// Collect positions of Ag_dW selected by MatrixCompare_LT with the thres value.
276 int seq_len = MatrixCompare_LT<T>(
277 Ag_dW_ptr->template mutable_data<T>(),
278 *thres.template data<T>(), mask_seq, N, K);
// Overwrite the selected positions in dW, W and Mask; the remaining
// arguments, including the target value, are not visible in this listing.
280 MaskMatrix_Inc<T, Context>(mask_seq,
281 dW->template mutable_data<T>(),
283 MaskMatrix_Inc<T, Context>(mask_seq,
284 W.template mutable_data<T>(),
286 MaskMatrix_Inc<T, Context>(mask_seq,
287 Mask.template mutable_data<T>(),
// Reset the aggregated gradient to zero.
289 math::Set<T, Context>(N*K,
static_cast<T
>(0),
290 Ag_dW.template mutable_data<T>(),
// Arguments (Ag_dW, dW) of a call whose name line is not visible; presumably
// AggrDW, accumulating dW into Ag_dW -- TODO confirm.
295 Ag_dW.template mutable_data<T>(),
296 dW->template mutable_data<T>(),
// (Re)build the length-M vector of ones only when its size does not match M.
300 if (bias_multiplier_.size() != M) {
303 bias_multiplier_.Resize(M);
304 math::Set<T, Context>(
305 M,
static_cast<T
>(1),
306 bias_multiplier_.template mutable_data<T>(),
// GEMV with dY transposed times the ones vector, beta = 0, writing db:
// presumably db[j] = sum over rows of dY[:, j] -- TODO confirm.
310 math::Gemv<T, Context>(
311 CblasTrans, M, N, 1, dY.template data<T>(),
312 bias_multiplier_.template data<T>(), 0,
313 db->template mutable_data<T>(),
// dX is produced only when the operator is configured with 7 outputs:
// GEMM of dY and W, both untransposed, beta = 0 (presumably dX = dY * W).
316 if (OutputSize() == 7) {
317 auto* dX = Output(6);
319 math::Gemm<T, Context, Engine>(
320 CblasNoTrans, CblasNoTrans, M, K, N, 1,
321 dY.template data<T>(), W.template data<T>(),
322 0, dX->template mutable_data<T>(),
337 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_PRUNE_H_ Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...