Caffe2 - C++ API
A deep learning, cross-platform ML framework
elementwise_op_schema.cc
#include "caffe2/core/operator_gradient.h"
#include "caffe2/operators/elementwise_op.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

const char* kBroadcastDoc = R"DOC(
If necessary, the right-hand-side argument will be broadcast to match the
shape of the left-hand-side argument. When broadcasting is specified, the
second tensor can either be of size 1 (a scalar value) or have a shape that
is a contiguous subset of the first tensor's shape. The start of the mutually
equal shape is specified by the argument "axis"; if it is not set, suffix
matching is assumed. 1-dim expansion doesn't work yet.

For example, the following tensor shapes are supported (with broadcast=1):

  shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar
  shape(A) = (2, 3, 4, 5), shape(B) = (5,)
  shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
  shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
  shape(A) = (2, 3, 4, 5), shape(B) = (2,), with axis=0

Argument `broadcast=1` needs to be passed to enable broadcasting.
)DOC";

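// A minimal sketch (added for illustration, not part of the original file)
// of how the broadcast/axis arguments documented above are supplied when an
// op is constructed. The blob names "A", "B", and "C" are hypothetical;
// CreateOperatorDef and MakeArgument come from caffe2/utils/proto_utils.h,
// included above. With shape(A) = (2, 3, 4, 5) and shape(B) = (3, 4), Add
// needs broadcast=1 and axis=1:
namespace {
inline OperatorDef ExampleBroadcastAdd() {
  return CreateOperatorDef(
      "Add",
      "",
      vector<string>{"A", "B"},
      vector<string>{"C"},
      vector<Argument>{
          MakeArgument<int>("broadcast", 1), MakeArgument<int>("axis", 1)});
}
} // namespace
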
std::function<void(OpSchema&)> MathDocGenerator(const char* name) {
  return [=](OpSchema& schema) {
    string doc = R"DOC(
Performs element-wise binary {name} (with limited broadcast support).
{broadcast_doc})DOC";
    ReplaceAll(doc, "{name}", name);
    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
    schema.SetDoc(doc);
    schema.Arg("broadcast", "Pass 1 to enable broadcasting");
    schema.Arg(
        "axis",
        "If set, defines the broadcast dimensions. See doc for details.");
    schema.Input(
        0,
        "A",
        "First operand, should share the type with the second operand.");
    schema.Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.");
    schema.Output(0, "C", "Result, has same dimensions and type as A");
  };
}

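// For reference: MathDocGenerator("addition"), applied through FillUsing in
// the Add schema below, renders the doc string as "Performs element-wise
// binary addition (with limited broadcast support)." followed by
// kBroadcastDoc, and attaches the shared broadcast/axis arguments and the
// A/B/C input/output descriptions.
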
OPERATOR_SCHEMA(Add)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}, {1, 0}})
    .CostInferenceFunction(PointwiseCostInference<1>)
    .IdenticalTypeAndShapeOfInput(0)
    .FillUsing(MathDocGenerator("addition"))
    .InheritOnnxSchema("Add");
OPERATOR_SCHEMA(Sub)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}, {1, 0}})
    .CostInferenceFunction(PointwiseCostInference<1>)
    .IdenticalTypeAndShapeOfInput(0)
    .FillUsing(MathDocGenerator("subtraction"))
    .InheritOnnxSchema("Sub");
OPERATOR_SCHEMA(Mul)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}, {1, 0}})
    .CostInferenceFunction(PointwiseCostInference<1>)
    .IdenticalTypeAndShapeOfInput(0)
    .FillUsing(MathDocGenerator("multiplication"))
    .InheritOnnxSchema("Mul");
OPERATOR_SCHEMA(Div)
    .NumInputs(2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .CostInferenceFunction(PointwiseCostInference<1>)
    .IdenticalTypeAndShapeOfInput(0)
    .FillUsing(MathDocGenerator("division"))
    .InheritOnnxSchema("Div");
OPERATOR_SCHEMA(DivGradient).NumInputs(3).NumOutputs(2).AllowInplace({{0, 0}});

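// A note on the schema builder (added for clarity): AllowInplace({{0, 0},
// {1, 0}}) declares that output 0 may reuse the buffer of input 0 or input 1,
// and CostInferenceFunction(PointwiseCostInference<1>) estimates cost at one
// FLOP per output element. Div presumably only allows in-place with input 0
// because its gradient (registered below) reads both B and the output C, so
// C must not alias B.
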
OPERATOR_SCHEMA(SumReduceLike)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInput(0)
    .SetDoc(R"DOC(
SumReduceLike operator takes 2 tensors as input. It performs a reduce sum on
the first input so that the output looks like the second one.
It assumes that the first input has more dimensions than the second, and that
the dimensions of the second input are the contiguous subset of the dimensions
of the first.
For example, the following tensor shapes are supported:

  shape(A) = (2, 3, 4, 5), shape(B) = (4, 5)
  shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar
  shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1
  shape(A) = (2, 3, 2, 5), shape(B) = (2,), with axis=0
  )DOC")
    .Arg(
        "axis",
        "If set, defines the starting dimension for reduction. Args `axis` and "
        "`axis_str` cannot be used simultaneously.")
    .Arg(
        "axis_str",
        "If set, it can only be N, C, H, or W. The `order` arg should also be "
        "provided. It defines the reduction dimensions on NCHW or NHWC. Args "
        "`axis` and `axis_str` cannot be used simultaneously.")
    .Arg("order", "Either NHWC or NCHW")
    .Input(
        0,
        "A",
        "First operand, should share the type with the second operand.")
    .Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.")
    .Output(0, "C", "Result, has same dimensions and type as B");

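// Worked example (added for clarity): with shape(A) = (2, 3, 4, 5) and
// shape(B) = (4, 5), SumReduceLike sums A over its leading (2, 3) dimensions,
// so C[k][l] = sum over i, j of A[i][j][k][l] and shape(C) = (4, 5). This is
// exactly the reduction the broadcast-aware gradient makers below use to
// shrink a gradient back to the shape of the smaller operand.
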
class GetAddGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    if (!ArgumentHelper::HasArgument(Def(), "broadcast")) {
      SetDense(0, GO(0));
      SetDense(1, GO(0));
      return vector<OperatorDef>();
    }
    SetDense(0, GO(0));

    return SingleGradientDef(
        "SumReduceLike",
        "",
        vector<string>{GO(0), I(1)},
        vector<string>{GI(1)});
  }
};
REGISTER_GRADIENT(Add, GetAddGradient);

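// Gradient recap for Add (added for clarity): with C = A + B, dL/dA = dL/dC
// and dL/dB = dL/dC. Without broadcasting, both gradients are the output
// gradient passed through densely; with broadcasting, dL/dB must additionally
// be reduced to B's shape, which is what the SumReduceLike op above does.
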
// TODO(jiayq): although the Sub gradient is implemented here, the Negative
// unary operator it relies on still needs to be implemented.
class GetSubGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    if (!ArgumentHelper::HasArgument(Def(), "broadcast")) {
      SetDense(0, GO(0));
      return SingleGradientDef(
          "Negative", "", vector<string>{GO(0)}, vector<string>{GI(1)});
    } else {
      SetDense(0, GO(0));
      vector<OperatorDef> grad_ops;
      grad_ops.push_back(CreateOperatorDef(
          "Negative",
          "",
          vector<string>{GO(0)},
          vector<string>{GI(1) + "_autogen_pre_red"}));

      Argument axis, axis_str, order;
      if (ArgumentHelper::HasArgument(Def(), "axis")) {
        axis = GetArgument(Def(), "axis");
      } else {
        axis = MakeArgument<int>("axis", -1);
      }
      if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
        axis_str = GetArgument(Def(), "axis_str");
      } else {
        axis_str = MakeArgument<string>("axis_str", "");
      }
      if (ArgumentHelper::HasArgument(Def(), "order")) {
        order = GetArgument(Def(), "order");
      } else {
        order = MakeArgument<string>("order", "NCHW");
      }
      grad_ops.push_back(CreateOperatorDef(
          "SumReduceLike",
          "",
          vector<string>{GI(1) + "_autogen_pre_red", I(1)},
          vector<string>{GI(1)},
          vector<Argument>{axis, axis_str, order}));

      return grad_ops;
    }
  }
  // Make sure the broadcast argument is not copied over.
  bool CopyArguments() const override {
    return false;
  }
};
REGISTER_GRADIENT(Sub, GetSubGradient);

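// Gradient recap for Sub (added for clarity): with C = A - B, dL/dA = dL/dC
// and dL/dB = -dL/dC. The maker above therefore emits a Negative op and, in
// the broadcast case, follows it with SumReduceLike to reduce the negated
// gradient (staged in the "_autogen_pre_red" blob) down to B's shape.
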
class GetMulGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    CAFFE_ENFORCE(
        Def().input(0) != Def().output(0) && Def().input(1) != Def().output(0),
        "Gradient computation cannot be carried out if Mul uses in-place "
        "computation: ",
        ProtoDebugString(Def()));
    if (!ArgumentHelper::HasArgument(Def(), "broadcast")) {
      return vector<OperatorDef>{
          CreateOperatorDef(
              "Mul", "", vector<string>{GO(0), I(1)}, vector<string>{GI(0)}),
          CreateOperatorDef(
              "Mul", "", vector<string>{GO(0), I(0)}, vector<string>{GI(1)})};
    } else {
      Argument broadcast, axis, axis_str, order;
      if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
        broadcast = GetArgument(Def(), "broadcast");
      } else {
        broadcast = MakeArgument<int>("broadcast", 0);
      }
      if (ArgumentHelper::HasArgument(Def(), "axis")) {
        axis = GetArgument(Def(), "axis");
      } else {
        axis = MakeArgument<int>("axis", -1);
      }
      if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
        axis_str = GetArgument(Def(), "axis_str");
      } else {
        axis_str = MakeArgument<string>("axis_str", "");
      }
      if (ArgumentHelper::HasArgument(Def(), "order")) {
        order = GetArgument(Def(), "order");
      } else {
        order = MakeArgument<string>("order", "NCHW");
      }

      vector<OperatorDef> grad_ops;
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "mul_grad_1st_op",
          vector<string>{GO(0), I(1)},
          vector<string>{GI(0)},
          vector<Argument>{broadcast, axis, axis_str, order}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "mul_gradient_2nd_op",
          vector<string>{GO(0), I(0)},
          vector<string>{GI(1) + "_autogen_pre_red"}));

      grad_ops.push_back(CreateOperatorDef(
          "SumReduceLike",
          "mul_with_broadcast_grad_3",
          vector<string>{GI(1) + "_autogen_pre_red", I(1)},
          vector<string>{GI(1)},
          vector<Argument>{axis, axis_str, order}));

      return grad_ops;
    }
  }

  // Make sure the broadcast argument is not copied over.
  bool CopyArguments() const override {
    return false;
  }
};
REGISTER_GRADIENT(Mul, GetMulGradient);

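// Gradient recap for Mul (added for clarity): by the product rule, with
// C = A * B, dL/dA = dL/dC * B and dL/dB = dL/dC * A. In the broadcast case,
// the second product is computed at A's full shape into the
// "_autogen_pre_red" blob and then reduced to B's shape with SumReduceLike.
// The CAFFE_ENFORCE above exists because an in-place Mul overwrites the
// input values these products need.
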
class GetDivGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    CAFFE_ENFORCE(
        !ArgumentHelper::HasArgument(Def(), "broadcast"),
        "Gradient not ready yet for Div with broadcasting.");
    return SingleGradientDef(
        "DivGradient",
        "",
        vector<string>{I(1), O(0), GO(0)},
        vector<string>{GI(0), GI(1)});
  }
};
REGISTER_GRADIENT(Div, GetDivGradient);

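// Gradient recap for Div (added for clarity): with C = A / B, the quotient
// rule gives dL/dA = dL/dC / B and dL/dB = -dL/dC * A / B^2 = -dL/dC * C / B,
// which is why DivGradient takes {B, C, dC} as inputs rather than A: reusing
// the forward output C avoids recomputing A / B.
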
std::function<void(OpSchema&)> ComparisonDocGenerator(
    const char* name,
    const char* desc) {
  return [=](OpSchema& schema) {
    string doc = R"DOC(
Performs element-wise {desc} comparison `{name}` (with limited broadcast support).
{broadcast_doc})DOC";
    ReplaceAll(doc, "{name}", name);
    ReplaceAll(doc, "{desc}", desc);
    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
    schema.SetDoc(doc);
    schema.Arg("broadcast", "Pass 1 to enable broadcasting");
    schema.Arg(
        "axis",
        "If set, defines the broadcast dimensions. See doc for details.");
    schema.Input(
        0,
        "A",
        "First operand, should share the type with the second operand.");
    schema.Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.");
    schema.Output(0, "C", "Result, has same dimensions as A and type `bool`");
  };
}

#define CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(name, symbol, desc) \
  OPERATOR_SCHEMA(name).NumInputs(2).NumOutputs(1).FillUsing(      \
      ComparisonDocGenerator(symbol, desc));                       \
  SHOULD_NOT_DO_GRADIENT(name)

CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(LT, "<", "less than");
CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(LE, "<=", "less than or equal to");
CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(GT, ">", "greater than");
CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(GE, ">=", "greater than or equal to");
CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(EQ, "==", "equality");

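// For reference (added for clarity), the first instantiation above expands
// to roughly:
//
//   OPERATOR_SCHEMA(LT).NumInputs(2).NumOutputs(1).FillUsing(
//       ComparisonDocGenerator("<", "less than"));
//   SHOULD_NOT_DO_GRADIENT(LT);
//
// Comparison outputs are piecewise constant in their inputs, so no gradient
// is registered; the same applies to the logical ops below.
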
std::function<void(OpSchema&)> LogicalDocGenerator(const char* name) {
  return [=](OpSchema& schema) {
    string doc = R"DOC(
Performs element-wise logical operation `{name}` (with limited broadcast support).
Both input operands should be of type `bool`.
{broadcast_doc})DOC";
    ReplaceAll(doc, "{name}", name);
    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
    schema.SetDoc(doc);
    schema.Arg("broadcast", "Pass 1 to enable broadcasting");
    schema.Arg(
        "axis",
        "If set, defines the broadcast dimensions. See doc for details.");
    schema.Input(0, "A", "First operand.");
    schema.Input(
        1,
        "B",
        "Second operand. With broadcasting can be of smaller size than A. "
        "If broadcasting is disabled it should be of the same size.");
    schema.Output(0, "C", "Result, has same dimensions as A and type `bool`");
  };
}

#define CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(name, symbol, onnx_schema) \
  OPERATOR_SCHEMA(name)                                                \
      .NumInputs(2)                                                    \
      .NumOutputs(1)                                                   \
      .AllowInplace({{0, 0}})                                          \
      .FillUsing(LogicalDocGenerator(symbol))                          \
      .InheritOnnxSchema(onnx_schema);                                 \
  SHOULD_NOT_DO_GRADIENT(name)

CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Or, "or", "Or");
CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(And, "and", "And");
CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Xor, "xor", "Xor");

OPERATOR_SCHEMA(Not)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc(R"DOC(Performs element-wise logical negation.)DOC")
    .Input(0, "X", "Input tensor of type `bool`.")
    .Output(0, "Y", "Output tensor of type `bool`.")
    .InheritOnnxSchema("Not");
SHOULD_NOT_DO_GRADIENT(Not);

} // namespace caffe2