Caffe2 - C++ API
A deep-learning, cross-platform ML framework
conv_gradient_op.cc
1 #include "caffe2/operators/conv_op.h"
2 #include "caffe2/operators/conv_op_impl.h"
3 #include "caffe2/operators/conv_pool_op_base.h"
4 
5 namespace caffe2 {
6 
7 REGISTER_CPU_OPERATOR(ConvGradient, ConvGradientOp<float, CPUContext>);
8 OPERATOR_SCHEMA(ConvGradient).NumInputs(2, 3).NumOutputs(1, 3);
9 
10 REGISTER_CPU_OPERATOR(Conv1DGradient, ConvGradientOp<float, CPUContext>);
11 OPERATOR_SCHEMA(Conv1DGradient).NumInputs(2, 3).NumOutputs(1, 3);
12 
13 REGISTER_CPU_OPERATOR(Conv2DGradient, ConvGradientOp<float, CPUContext>);
14 OPERATOR_SCHEMA(Conv2DGradient).NumInputs(2, 3).NumOutputs(1, 3);
15 
16 REGISTER_CPU_OPERATOR(Conv3DGradient, ConvGradientOp<float, CPUContext>);
17 OPERATOR_SCHEMA(Conv3DGradient).NumInputs(2, 3).NumOutputs(1, 3);
18 
20  using GradientMakerBase::GradientMakerBase;
21  vector<OperatorDef> GetGradientDefs() override {
22  CAFFE_ENFORCE(def_.input_size() == 3 || def_.input_size() == 2);
23 
24  ArgumentHelper argsHelper(def_);
25 
26  auto compute_dX = !argsHelper.GetSingleArgument<bool>("no_gradient_to_input", 0);
27 
28  if (def_.input_size() == 3) {
29  if (compute_dX) {
30  return SingleGradientDef(
31  def_.type() + "Gradient",
32  "",
33  vector<string>{I(0), I(1), GO(0)},
34  vector<string>{GI(1), GI(2), GI(0)});
35  } else {
36  return SingleGradientDef(
37  def_.type() + "Gradient",
38  "",
39  vector<string>{I(0), I(1), GO(0)},
40  vector<string>{GI(1), GI(2)});
41  }
42  } else {
43  if (compute_dX) {
44  return SingleGradientDef(
45  def_.type() + "Gradient",
46  "",
47  vector<string>{I(0), I(1), GO(0)},
48  vector<string>{GI(1), GI(0)},
49  vector<Argument>{MakeArgument<int>("no_bias", 1)});
50  } else {
51  return SingleGradientDef(
52  def_.type() + "Gradient",
53  "",
54  vector<string>{I(0), I(1), GO(0)},
55  vector<string>{GI(1)},
56  vector<Argument>{MakeArgument<int>("no_bias", 1)});
57  }
58  }
59  }
60 };
61 REGISTER_GRADIENT(Conv, GetConvGradient);
62 REGISTER_GRADIENT(Conv1D, GetConvGradient);
63 REGISTER_GRADIENT(Conv2D, GetConvGradient);
64 REGISTER_GRADIENT(Conv3D, GetConvGradient);
65 
66 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually the case for many ...