Caffe2 - C++ API
A deep-learning, cross-platform ML framework
segment_reduction_op.cc
1 #include "caffe2/operators/segment_reduction_op.h"
2 
3 namespace caffe2 {
4 
5 // registering 5 input gradient with main output
6 // gradient of SparseLengthsWeightedSum
7 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
8  .NumInputs(5)
9  .NumOutputs(2);
10 REGISTER_CPU_OPERATOR(
11  SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient,
12  AbstractLengthsWithMainInputGradientOp<
13  float,
14  int,
15  CPUContext,
16  WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
17  true /*SparseFused*/,
18  true /*GradientNeedIndices*/>);
19 
20 // registering 4 input version
21 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumGradient)
22  .NumInputs(4)
23  .NumOutputs(1);
24 REGISTER_CPU_OPERATOR(
25  SparseLengthsIndicesInGradientWeightedSumGradient,
26  AbstractLengthsGradientOp<
27  float,
28  int,
29  CPUContext,
30  WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
31  true /*GradientNeedIndices*/>);
32 
33 // registering 3 input version
34 // gradient of SparseLengthsSum
35 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientSumGradient)
36  .NumInputs(3)
37  .NumOutputs(1);
38 REGISTER_CPU_OPERATOR(
39  SparseLengthsIndicesInGradientSumGradient,
40  AbstractLengthsGradientOp<
41  float,
42  int,
43  CPUContext,
44  SumReducerDef::template ReducerGradient<float, CPUContext>,
45  true /*GradientNeedIndices*/>);
46 // gradient of LengthsSum
47 OPERATOR_SCHEMA(LengthsIndicesInGradientSumGradient).NumInputs(3).NumOutputs(1);
48 REGISTER_CPU_OPERATOR(
49  LengthsIndicesInGradientSumGradient,
50  AbstractLengthsGradientOp<
51  float,
52  int,
53  CPUContext,
54  SumReducerDef::template ReducerGradient<float, CPUContext>,
55  true /*GradientNeedIndices*/>);
56 
57 namespace {
58 
59 template <typename Def>
60 string FormatDoc() {
61  string doc = Def::doc;
62  ReplaceAll(doc, "{op}", Def::OpDef::name);
63  ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
64  return doc;
65 }
66 
// Use when the forward operator is implemented and registered elsewhere.
// The definition type is passed through __VA_ARGS__ so that template
// arguments containing commas survive macro argument splitting.
// For the op name N = string(Def::basename) + Def::OpDef::name, this:
//   1. declares the schema for N: Def::ForwardOp::kNumInputs inputs, one
//      output, the doc from FormatDoc<Def>(), then Def::PopulateSchema;
//   2. registers Def::BackwardOp as CPU operator N + "Gradient" and declares
//      its schema (Def::BackwardOp::kNumInputs inputs, one output);
//   3. registers Def::GetGradient as the gradient maker for N.
// The schema builder calls keep their order; FillUsing runs last.
#define REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(...)                        \
  OPERATOR_SCHEMA_STR(                                                        \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name))            \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs)                          \
      .NumOutputs(1)                                                          \
      .SetDoc(FormatDoc<__VA_ARGS__>())                                       \
      .Output(0, "OUTPUT", "Aggregated tensor")                               \
      .FillUsing(__VA_ARGS__::PopulateSchema);                                \
  REGISTER_CPU_OPERATOR_STR(                                                  \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) +           \
          "Gradient",                                                         \
      __VA_ARGS__::BackwardOp);                                               \
  OPERATOR_SCHEMA_STR(                                                        \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) +           \
      "Gradient")                                                             \
      .NumInputs(__VA_ARGS__::BackwardOp::kNumInputs)                         \
      .NumOutputs(1);                                                         \
  REGISTER_GRADIENT_STR(                                                      \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name),            \
      __VA_ARGS__::GetGradient)
87 
// Full registration for a segment-op definition: registers Def::ForwardOp as
// the CPU operator string(Def::basename) + Def::OpDef::name, then hands the
// schemas, gradient operator and gradient maker to
// REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY.
#define REGISTER_SEGMENT_DEF(...)                                  \
  REGISTER_CPU_OPERATOR_STR(                                       \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \
      __VA_ARGS__::ForwardOp);                                     \
  REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(__VA_ARGS__)
93 
94 REGISTER_SEGMENT_DEF(
95  AbstractSortedSegmentRangeDef<float, int, CPUContext, SumRangeReducerDef>);
96 REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
97  float,
98  int,
99  CPUContext,
100  LogSumExpRangeReducerDef>);
101 REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
102  float,
103  int,
104  CPUContext,
105  LogMeanExpRangeReducerDef>);
106 REGISTER_SEGMENT_DEF(
107  AbstractSortedSegmentRangeDef<float, int, CPUContext, MeanRangeReducerDef>);
108 REGISTER_SEGMENT_DEF(
109  AbstractSortedSegmentRangeDef<float, int, CPUContext, MaxRangeReducerDef>);
110 
// Registers, via REGISTER_SEGMENT_DEF, the sorted / sparse-sorted / unsorted /
// sparse-unsorted segment definitions of `reducer`, each instantiated with
// float, int and CPUContext. The last expansion has no trailing ';' -- the
// call site supplies it.
#define REGISTER_REDUCER_WITH_OPS(reducer)                                \
  REGISTER_SEGMENT_DEF(                                                   \
      AbstractSortedSegmentDef<float, int, CPUContext, reducer>);         \
  REGISTER_SEGMENT_DEF(                                                   \
      AbstractSparseSortedSegmentDef<float, int, CPUContext, reducer>);   \
  REGISTER_SEGMENT_DEF(                                                   \
      AbstractUnsortedSegmentDef<float, int, CPUContext, reducer>);       \
  REGISTER_SEGMENT_DEF(                                                   \
      AbstractSparseUnsortedSegmentDef<float, int, CPUContext, reducer>)
120 
// Registers the Lengths definition (AbstractLengthsDef) of `reducer` with
// float, int and CPUContext; `need_indices` is passed through unchanged as
// the GradientNeedIndices template argument.
#define REGISTER_REDUCER_WITH_LENGTH_OPS(reducer, need_indices) \
  REGISTER_SEGMENT_DEF(                                         \
      AbstractLengthsDef<float, int, CPUContext, reducer, need_indices>)
128 
// Registers every op family for `reducer_def`:
//   - the ReduceFront definition (float, CPUContext),
//   - the segment definitions from REGISTER_REDUCER_WITH_OPS,
//   - the Lengths definition with GradientNeedIndices = false.
// Each nested registration is terminated with an explicit ';' rather than
// relying on the previous macro's expansion happening to end in a complete
// declaration (an extra ';' at namespace scope is an empty declaration).
#define REGISTER_REDUCER_WITH_ALL_OPS(reducer_def)             \
  REGISTER_SEGMENT_DEF(                                        \
      AbstractReduceFrontDef<float, CPUContext, reducer_def>); \
  REGISTER_REDUCER_WITH_OPS(reducer_def);                      \
  REGISTER_REDUCER_WITH_LENGTH_OPS(reducer_def, false)
134 
135 REGISTER_REDUCER_WITH_OPS(SumReducerDef);
136 REGISTER_REDUCER_WITH_LENGTH_OPS(SumReducerDef, true);
137 
138 REGISTER_REDUCER_WITH_OPS(MeanReducerDef);
139 REGISTER_REDUCER_WITH_LENGTH_OPS(MeanReducerDef, false);
140 
141 REGISTER_REDUCER_WITH_ALL_OPS(WeightedSumReducerDef);
142 
143 // SparseLengths[Sum,WeightedSum,Mean] are now implemented separately,
144 // so we only rely to the historical implementation for the backward + schema.
145 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
146  float,
147  int,
148  CPUContext,
149  SumReducerDef,
150  true /*GradientNeedIndices*/>)
151 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
152  float,
153  int,
154  CPUContext,
155  WeightedSumReducerDef,
156  true /*GradientNeedIndices*/>)
157 
158 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(
159  AbstractSparseLengthsDef<float, int, CPUContext, MeanReducerDef>)
160 
// Registers Def::WithMainInputBackwardOp as the CPU operator
// string(Def::basename) + Def::OpDef::name + "WithMainInputGradient" and
// declares its schema: Def::WithMainInputBackwardOp::kNumInputs inputs and
// between 1 and INT_MAX outputs. These auxiliary gradients are currently
// implemented only for the Lengths versions of the ops.
// The expansion has no trailing ';' -- the call site supplies it.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT(...)                      \
  REGISTER_CPU_OPERATOR_STR(                                        \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
          "WithMainInputGradient",                                  \
      __VA_ARGS__::WithMainInputBackwardOp);                        \
  OPERATOR_SCHEMA_STR(                                              \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
      "WithMainInputGradient")                                      \
      .NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs)  \
      .NumOutputs(1, INT_MAX)
172 REGISTER_GRADIENT_WITH_MAIN_INPUT(
173  AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
174 REGISTER_GRADIENT_WITH_MAIN_INPUT(
175  AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
176 
// Registers Def::WithMainInputAndForwardOutputBackwardOp as the CPU operator
// string(Def::basename) + Def::OpDef::name +
// "WithMainInputAndForwardOutputGradient" and declares its schema: that op's
// kNumInputs inputs and between 1 and INT_MAX outputs.
// The expansion has no trailing ';' -- the call site supplies it.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(...)            \
  REGISTER_CPU_OPERATOR_STR(                                               \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) +        \
          "WithMainInputAndForwardOutputGradient",                         \
      __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp);               \
  OPERATOR_SCHEMA_STR(                                                     \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) +        \
      "WithMainInputAndForwardOutputGradient")                             \
      .NumInputs(                                                          \
          __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp::kNumInputs) \
      .NumOutputs(1, INT_MAX)
188 
// For ops whose gradient needs both the main input and the forward output.
// It declares the forward schema exactly as
// REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY does, then registers the
// "...WithMainInputAndForwardOutputGradient" operator and schema, and finally
// Def::GetGradient as the gradient maker. The forward CPU operator itself is
// not registered here.
#define REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(...) \
  OPERATOR_SCHEMA_STR(                                                 \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name))     \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs)                   \
      .NumOutputs(1)                                                   \
      .SetDoc(FormatDoc<__VA_ARGS__>())                                \
      .Output(0, "OUTPUT", "Aggregated tensor")                        \
      .FillUsing(__VA_ARGS__::PopulateSchema);                         \
  REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(__VA_ARGS__);   \
  REGISTER_GRADIENT_STR(                                               \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name),     \
      __VA_ARGS__::GetGradient)
201 
// Registers a Lengths op whose gradient requires the main input as well as
// the output of the forward op. The forward CPU operator
// string(Def::basename) + Def::OpDef::name is registered here; schemas and
// gradients come from
// REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT.
#define REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(...) \
  REGISTER_CPU_OPERATOR_STR(                                           \
      string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name),     \
      __VA_ARGS__::ForwardOp);                                         \
  REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(__VA_ARGS__)
209 
210 REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(
211  AbstractLengthsDef<float, int, CPUContext, MaxReducerDef>);
212 }
213 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...