// Documentation helpers shared by the element-wise binary math schemas.
//
// kBroadcastDoc is the text describing the limited broadcasting rules; the
// generators in this file splice it in wherever the {broadcast_doc}
// placeholder appears.
//
// MathDocGenerator(name) returns a callback that documents an OpSchema for the
// binary operation called `name` (for example addition): it replaces the
// {name} and {broadcast_doc} placeholders in the doc text, describes the
// `broadcast` argument and describes output 0 as C.
//
// NOTE(review): this listing carries stray leading line numbers and skips some
// lines of the file (the declaration of `doc` and the heads of several schema
// calls are not visible), so the description above covers only what is shown.
1 #include "caffe2/core/operator_gradient.h" 2 #include "caffe2/operators/elementwise_op.h" 3 #include "caffe2/utils/proto_utils.h" 7 const char* kBroadcastDoc = R
"DOC( 8 If necessary the right-hand-side argument will be broadcasted to match the 9 shape of left-hand-side argument. When broadcasting is specified, the second 10 tensor can either be of size 1 (a scalar value), or having its shape as a 11 contiguous subset of the first tensor's shape. The starting of the mutually 12 equal shape is specified by the argument "axis", and if it is not set, suffix 13 matching is assumed. 1-dim expansion doesn't work yet. 15 For example, the following tensor shapes are supported (with broadcast=1): 17 shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar 18 shape(A) = (2, 3, 4, 5), shape(B) = (5,) 19 shape(A) = (2, 3, 4, 5), shape(B) = (4, 5) 20 shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1 21 shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0 23 Argument `broadcast=1` needs to be passed to enable broadcasting. 26 std::function<void(OpSchema&)> MathDocGenerator(const char* name) {
// Capture by value so the callback holds its own copy of the name pointer.
27 return [=](OpSchema& schema) {
29 Performs element-wise binary {name} (with limited broadcast support). 31 ReplaceAll(doc, "{name}", name);
32 ReplaceAll(doc,
"{broadcast_doc}", kBroadcastDoc);
34 schema.Arg(
"broadcast",
"Pass 1 to enable broadcasting");
37 "If set, defines the broadcast dimensions. See doc for details.");
41 "First operand, should share the type with the second operand.");
45 "Second operand. With broadcasting can be of smaller size than A. " 46 "If broadcasting is disabled it should be of the same size.");
47 schema.Output(0,
"C",
"Result, has same dimensions and type as A");
// Tail of a schema declaration documented as addition and inheriting the ONNX
// Add schema; the OPERATOR_SCHEMA head is not visible in this listing
// (presumably OPERATOR_SCHEMA(Add) -- TODO confirm).
// AllowInplace lists (input, output) index pairs: inputs 0 and 1 may each
// alias output 0. The output takes its type and shape from input 0.
54 .AllowInplace({{0, 0}, {1, 0}})
55 .CostInferenceFunction(PointwiseCostInference<1>)
56 .IdenticalTypeAndShapeOfInput(0)
57 .FillUsing(MathDocGenerator(
"addition"))
58 .InheritOnnxSchema(
"Add");
// Tail of a schema declaration documented as subtraction and inheriting the
// ONNX Sub schema; the OPERATOR_SCHEMA head is not visible in this listing
// (presumably OPERATOR_SCHEMA(Sub) -- TODO confirm).
// Inputs 0 and 1 may each alias output 0; the output takes its type and shape
// from input 0.
62 .AllowInplace({{0, 0}, {1, 0}})
63 .CostInferenceFunction(PointwiseCostInference<1>)
64 .IdenticalTypeAndShapeOfInput(0)
65 .FillUsing(MathDocGenerator(
"subtraction"))
66 .InheritOnnxSchema(
"Sub");
// Tail of a schema declaration documented as multiplication and inheriting the
// ONNX Mul schema; the OPERATOR_SCHEMA head is not visible in this listing
// (presumably OPERATOR_SCHEMA(Mul) -- TODO confirm).
// Inputs 0 and 1 may each alias output 0 at the schema level; note that the
// Mul gradient maker in this file refuses forward ops that actually run
// in-place.
70 .AllowInplace({{0, 0}, {1, 0}})
71 .CostInferenceFunction(PointwiseCostInference<1>)
72 .IdenticalTypeAndShapeOfInput(0)
73 .FillUsing(MathDocGenerator(
"multiplication"))
74 .InheritOnnxSchema(
"Mul");
// Tail of a schema declaration documented as division and inheriting the ONNX
// Div schema; the OPERATOR_SCHEMA head is not visible in this listing
// (presumably OPERATOR_SCHEMA(Div) -- TODO confirm).
// Unlike the addition, subtraction and multiplication schemas, only input 0
// may alias output 0 here.
78 .AllowInplace({{0, 0}})
79 .CostInferenceFunction(PointwiseCostInference<1>)
80 .IdenticalTypeAndShapeOfInput(0)
81 .FillUsing(MathDocGenerator(
"division"))
82 .InheritOnnxSchema(
"Div");
83 OPERATOR_SCHEMA(DivGradient).NumInputs(3).NumOutputs(2).AllowInplace({{0, 0}});
// Schema for SumReduceLike, which (per its doc text) sum-reduces the first
// input so that the result has the shape of the second input.
// NOTE(review): parts of this declaration (input and output counts, the heads
// of the axis / axis_str arguments and of the two inputs) are on lines not
// visible in this listing.
85 OPERATOR_SCHEMA(SumReduceLike)
88 .IdenticalTypeAndShapeOfInput(0)
90 SumReduceLike operator takes 2 tensors as input. It performs reduce sum to the 91 first input so that the output looks like the second one. 92 It assumes that the first input 93 has more dimensions than the second, and the dimensions of the second input is 94 the contiguous subset of the dimensions of the first. 95 For example, the following tensor shapes are supported: 97 shape(A) = (2, 3, 4, 5), shape(B) = (4, 5) 98 shape(A) = (2, 3, 4, 5), shape(B) = (,), i.e. B is a scalar 99 shape(A) = (2, 3, 4, 5), shape(B) = (3, 4), with axis=1 100 shape(A) = (2, 3, 2, 5), shape(B) = (2), with axis=0 104 "If set, defines the starting dimension for reduction. Args `axis` and " 105 "`axis_str` cannot be used simultaneously.")
108 "If set, it could only be N or C or H or W. `order` arg should also be " 109 "provided. It defines the reduction dimensions on NCHW or NHWC. Args " 110 "`axis` and `axis_str` cannot be used simultaneously.")
// The storage order names are NHWC and NCHW; NCHW is also the default that the
// gradient makers in this file use for `order`.
111 .Arg(
"order",
"Either NHWC or NCHW")
115 "First operand, should share the type with the second operand.")
119 "Second operand. With broadcasting can be of smaller size than A. " 120 "If broadcasting is disabled it should be of the same size.")
121 .Output(0,
"C",
"Result, has same dimensions and type as B");
// Gradient maker registered for Add.
//
// Without a `broadcast` argument on the forward op, GetGradientDefs returns an
// empty operator list. Otherwise it returns a single gradient op that reads
// GO(0) and I(1) and writes GI(1).
// NOTE(review): the statements around the early return, the op type passed to
// SingleGradientDef and the end of the class are on lines not visible in this
// listing.
123 class GetAddGradient :
public GradientMakerBase {
124 using GradientMakerBase::GradientMakerBase;
125 vector<OperatorDef> GetGradientDefs()
override {
126 if (!ArgumentHelper::HasArgument(Def(),
"broadcast")) {
// No broadcasting: no gradient operator needs to be emitted.
129 return vector<OperatorDef>();
// Broadcasting: one op derives the gradient of the second input.
133 return SingleGradientDef(
136 vector<string>{GO(0), I(1)},
137 vector<string>{GI(1)});
140 REGISTER_GRADIENT(Add, GetAddGradient);
// Gradient maker registered for Sub.
//
// Without a `broadcast` argument on the forward op, the gradient of the second
// input is produced by a single Negative op applied to GO(0).
// With `broadcast`, two ops are emitted: the first reads GO(0) and writes a
// temporary blob named GI(1) + _autogen_pre_red; the second reads that blob
// together with I(1) and writes GI(1), receiving the axis, axis_str and order
// arguments.
// NOTE(review): the op type names of the two broadcast-path ops, any handling
// of GI(0), the else branches of the argument lookups and the body of
// CopyArguments are on lines not visible in this listing.
144 class GetSubGradient :
public GradientMakerBase {
145 using GradientMakerBase::GradientMakerBase;
146 vector<OperatorDef> GetGradientDefs()
override {
147 if (!ArgumentHelper::HasArgument(Def(),
"broadcast")) {
149 return SingleGradientDef(
150 "Negative",
"", vector<string>{GO(0)}, vector<string>{GI(1)});
153 vector<OperatorDef> grad_ops;
// First broadcast-path op: GO(0) into the pre-reduction temporary.
154 grad_ops.push_back(CreateOperatorDef(
157 vector<string>{GO(0)},
158 vector<string>{GI(1) +
"_autogen_pre_red"}));
// Reuse axis / axis_str / order from the forward op when it has them; the
// MakeArgument calls supply -1, an empty string and NCHW (presumably as the
// fallback when the argument is absent -- TODO confirm against the full file).
160 Argument axis, axis_str, order;
161 if (ArgumentHelper::HasArgument(Def(),
"axis")) {
162 axis = GetArgument(Def(),
"axis");
164 axis = MakeArgument<int>(
"axis", -1);
166 if (ArgumentHelper::HasArgument(Def(),
"axis_str")) {
167 axis_str = GetArgument(Def(),
"axis_str");
169 axis_str = MakeArgument<string>(
"axis_str",
"");
171 if (ArgumentHelper::HasArgument(Def(),
"order")) {
172 order = GetArgument(Def(),
"order");
174 order = MakeArgument<string>(
"order",
"NCHW");
// Second broadcast-path op: temporary plus I(1) into GI(1).
176 grad_ops.push_back(CreateOperatorDef(
179 vector<string>{GI(1) +
"_autogen_pre_red", I(1)},
180 vector<string>{GI(1)},
181 vector<Argument>{axis, axis_str, order}));
187 bool CopyArguments()
const override {
191 REGISTER_GRADIENT(Sub, GetSubGradient);
// Gradient maker registered for Mul.
//
// A check first requires that the forward output blob differs from both
// input blobs; its message says the gradient cannot be computed when Mul runs
// in-place.
// Without a `broadcast` argument, two Mul ops are returned: GO(0) with I(1)
// gives GI(0), and GO(0) with I(0) gives GI(1).
// With `broadcast`, three ops are emitted: GO(0) and I(1) give GI(0) (with the
// broadcast, axis, axis_str and order arguments); GO(0) and I(0) give a
// temporary blob named GI(1) + _autogen_pre_red; that temporary and I(1) give
// GI(1) (with the axis, axis_str and order arguments).
// NOTE(review): the check macro name, the op type names on the broadcast path,
// the else branches of the argument lookups and the body of CopyArguments are
// on lines not visible in this listing.
193 class GetMulGradient :
public GradientMakerBase {
194 using GradientMakerBase::GradientMakerBase;
195 vector<OperatorDef> GetGradientDefs()
override {
197 Def().input(0) != Def().output(0) && Def().input(1) != Def().output(0),
198 "Gradient computation cannot be carried out if Mul uses in-place " 200 ProtoDebugString(Def()));
201 if (!ArgumentHelper::HasArgument(Def(),
"broadcast")) {
202 return vector<OperatorDef>{
204 "Mul",
"", vector<string>{GO(0), I(1)}, vector<string>{GI(0)}),
206 "Mul",
"", vector<string>{GO(0), I(0)}, vector<string>{GI(1)})};
// NOTE(review): the no-broadcast case appears to have returned above, so the
// `broadcast` lookup below should always succeed and the MakeArgument fallback
// with value 0 looks unreachable -- confirm against the full file.
208 Argument broadcast, axis, axis_str, order;
209 if (ArgumentHelper::HasArgument(Def(),
"broadcast")) {
210 broadcast = GetArgument(Def(),
"broadcast");
212 broadcast = MakeArgument<int>(
"broadcast", 0);
214 if (ArgumentHelper::HasArgument(Def(),
"axis")) {
215 axis = GetArgument(Def(),
"axis");
217 axis = MakeArgument<int>(
"axis", -1);
219 if (ArgumentHelper::HasArgument(Def(),
"axis_str")) {
220 axis_str = GetArgument(Def(),
"axis_str");
222 axis_str = MakeArgument<string>(
"axis_str",
"");
224 if (ArgumentHelper::HasArgument(Def(),
"order")) {
225 order = GetArgument(Def(),
"order");
227 order = MakeArgument<string>(
"order",
"NCHW");
230 vector<OperatorDef> grad_ops;
// Op 1: gradient of the first input, computed with the broadcast arguments.
231 grad_ops.push_back(CreateOperatorDef(
234 vector<string>{GO(0), I(1)},
235 vector<string>{GI(0)},
236 vector<Argument>{broadcast, axis, axis_str, order}));
// Op 2: unreduced gradient of the second input, stored in a temporary blob.
237 grad_ops.push_back(CreateOperatorDef(
239 "mul_gradient_2nd_op",
240 vector<string>{GO(0), I(0)},
241 vector<string>{GI(1) +
"_autogen_pre_red"}));
// Op 3: turns the temporary into GI(1), using I(1) and the axis arguments.
243 grad_ops.push_back(CreateOperatorDef(
245 "mul_with_broadcast_grad_3",
246 vector<string>{GI(1) +
"_autogen_pre_red", I(1)},
247 vector<string>{GI(1)},
248 vector<Argument>{axis, axis_str, order}));
255 bool CopyArguments()
const override {
259 REGISTER_GRADIENT(Mul, GetMulGradient);
// Gradient maker registered for Div.
//
// A check requires that the forward op has no `broadcast` argument (its
// message says the gradient is not ready for Div with broadcasting). A single
// gradient op is then returned that reads I(1), O(0) and GO(0) and writes both
// GI(0) and GI(1).
// NOTE(review): the check macro name, the op type passed to SingleGradientDef
// (presumably DivGradient -- TODO confirm) and the end of the class are on
// lines not visible in this listing.
261 class GetDivGradient :
public GradientMakerBase {
262 using GradientMakerBase::GradientMakerBase;
263 vector<OperatorDef> GetGradientDefs()
override {
265 !ArgumentHelper::HasArgument(Def(),
"broadcast"),
266 "Gradient not ready yet for Div with broadcasting.");
267 return SingleGradientDef(
270 vector<string>{I(1), O(0), GO(0)},
271 vector<string>{GI(0), GI(1)});
274 REGISTER_GRADIENT(Div, GetDivGradient);
// ComparisonDocGenerator returns a callback that documents an OpSchema for an
// element-wise comparison: it replaces the {name}, {desc} and {broadcast_doc}
// placeholders in the doc text, describes the `broadcast` argument and
// describes output 0 as C, a bool result with the dimensions of A.
// NOTE(review): the parameter list and the heads of several schema calls are
// on lines not visible in this listing.
276 std::function<void(OpSchema&)> ComparisonDocGenerator(
279 return [=](OpSchema& schema) {
281 Performs element-wise {desc} comparison `{name}` (with limited broadcast support). 282 {broadcast_doc})DOC"; 283 ReplaceAll(doc, "{name}", name);
284 ReplaceAll(doc,
"{desc}", desc);
285 ReplaceAll(doc,
"{broadcast_doc}", kBroadcastDoc);
287 schema.Arg(
"broadcast",
"Pass 1 to enable broadcasting");
290 "If set, defines the broadcast dimensions. See doc for details.");
294 "First operand, should share the type with the second operand.");
298 "Second operand. With broadcasting can be of smaller size than A. " 299 "If broadcasting is disabled it should be of the same size.");
// Output description: same dimensions as A, element type bool.
300 schema.Output(0,
"C",
"Result, has same dimensions as A and type `bool`");
// CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(name, symbol, desc) declares a
// two-input, one-output schema documented by ComparisonDocGenerator and marks
// the op with SHOULD_NOT_DO_GRADIENT. It is instantiated below for LT, LE, GT,
// GE and EQ.
304 #define CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(name, symbol, desc) \ 305 OPERATOR_SCHEMA(name).NumInputs(2).NumOutputs(1).FillUsing( \ 306 ComparisonDocGenerator(symbol, desc)); \ 307 SHOULD_NOT_DO_GRADIENT(name) 309 CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(LT,
"<",
"less than");
310 CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(LE,
"<=",
"less or equal than");
311 CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(GT,
">",
"greater than");
312 CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(GE,
">=",
"greater or equal than");
313 CAFFE2_SCHEMA_FOR_BINARY_COMPARISON_OP(EQ,
"==",
"equality");
// LogicalDocGenerator(name) returns a callback that documents an OpSchema for
// an element-wise logical operation on bool operands: it replaces the {name}
// and {broadcast_doc} placeholders in the doc text, describes the `broadcast`
// argument, input 0 as A and output 0 as C, a bool result with the dimensions
// of A.
// NOTE(review): the declaration of `doc` and the heads of several schema calls
// are on lines not visible in this listing.
315 std::function<void(OpSchema&)> LogicalDocGenerator(
const char* name) {
316 return [=](OpSchema& schema) {
318 Performs element-wise logical operation `{name}` (with limited broadcast support). 319 Both input operands should be of type `bool`. 320 {broadcast_doc})DOC"; 321 ReplaceAll(doc, "{name}", name);
322 ReplaceAll(doc,
"{broadcast_doc}", kBroadcastDoc);
324 schema.Arg(
"broadcast",
"Pass 1 to enable broadcasting");
327 "If set, defines the broadcast dimensions. See doc for details.");
328 schema.Input(0,
"A",
"First operand.");
332 "Second operand. With broadcasting can be of smaller size than A. " 333 "If broadcasting is disabled it should be of the same size.");
// Output description: same dimensions as A, element type bool.
334 schema.Output(0,
"C",
"Result, has same dimensions as A and type `bool`");
// CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(name, symbol, onnx_schema) declares a
// schema that lets input 0 alias output 0, is documented by
// LogicalDocGenerator, inherits the given ONNX schema and is marked with
// SHOULD_NOT_DO_GRADIENT. It is instantiated below for Or, And and Xor.
// NOTE(review): two continuation lines of the macro are not visible in this
// listing (presumably the input and output counts -- TODO confirm).
338 #define CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(name, symbol, onnx_schema) \ 339 OPERATOR_SCHEMA(name) \ 342 .AllowInplace({{0, 0}}) \ 343 .FillUsing(LogicalDocGenerator(symbol)) \ 344 .InheritOnnxSchema(onnx_schema); \ 345 SHOULD_NOT_DO_GRADIENT(name) 347 CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Or,
"or",
"Or");
348 CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(And,
"and",
"And");
349 CAFFE2_SCHEMA_FOR_BINARY_LOGICAL_OP(Xor,
"xor",
"Xor");
// Tail of the schema for the element-wise negation operator, which inherits
// the ONNX Not schema; the OPERATOR_SCHEMA head is not visible in this
// listing. Input X and output Y are documented as bool tensors, and the op is
// marked with SHOULD_NOT_DO_GRADIENT.
354 .SetDoc(R
"DOC(Performs element-wise negation.)DOC") 355 .Input(0, "X",
"Input tensor of type `bool`.")
356 .Output(0,
"Y",
"Output tensor of type `bool`.")
357 .InheritOnnxSchema(
"Not");
358 SHOULD_NOT_DO_GRADIENT(Not);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...