Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_prune.cc
1 #include "caffe2/experiments/operators/fully_connected_op_prune.h"
2 
3 namespace caffe2 {
4 namespace {
5 
6 REGISTER_CPU_OPERATOR(FC_Prune, FullyConnectedOpPrune<float, CPUContext>);
7 REGISTER_CPU_OPERATOR(FCGradient_Prune,
8  FullyConnectedPruneGradientOp<float, CPUContext>);
9 /* 8 Inputs:
10  * X W Mask bias Ag_dw Mask_seq thres comp_lb
11  * */
12 OPERATOR_SCHEMA(FC_Prune).NumInputs(8).NumOutputs(1, 2);
13 OPERATOR_SCHEMA(FCGradient_Prune).NumInputs(8).NumOutputs(6, 7)
14  .AllowInplace({{1, 2}, {2, 3}, {4, 4}, {5, 5}});
15 
16 class GetFCPruneGradient : public GradientMakerBase {
17  using GradientMakerBase::GradientMakerBase;
18  vector<OperatorDef> GetGradientDefs() override {
19  CAFFE_ENFORCE_EQ(def_.input_size(), 8);
20  return SingleGradientDef(
21  "FCGradient_Prune", "",
22  vector<string>{I(0), I(1), I(2), GO(0), I(4), I(5), I(6), I(7)},
23  vector<string>{GI(1), GI(3), I(1), I(2), I(4), I(5), GI(0)});
24  }
25 };
26 REGISTER_GRADIENT(FC_Prune, GetFCPruneGradient);
27 } // namespace
28 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...