1 #include "caffe2/operators/segment_reduction_op.h" 7 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
10 REGISTER_CPU_OPERATOR(
11 SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient,
12 AbstractLengthsWithMainInputGradientOp<
16 WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
21 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumGradient)
24 REGISTER_CPU_OPERATOR(
25 SparseLengthsIndicesInGradientWeightedSumGradient,
26 AbstractLengthsGradientOp<
30 WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
35 OPERATOR_SCHEMA(SparseLengthsIndicesInGradientSumGradient)
38 REGISTER_CPU_OPERATOR(
39 SparseLengthsIndicesInGradientSumGradient,
40 AbstractLengthsGradientOp<
44 SumReducerDef::template ReducerGradient<float, CPUContext>,
47 OPERATOR_SCHEMA(LengthsIndicesInGradientSumGradient).NumInputs(3).NumOutputs(1);
48 REGISTER_CPU_OPERATOR(
49 LengthsIndicesInGradientSumGradient,
50 AbstractLengthsGradientOp<
54 SumReducerDef::template ReducerGradient<float, CPUContext>,
59 template <
typename Def>
61 string doc = Def::doc;
62 ReplaceAll(doc,
"{op}", Def::OpDef::name);
63 ReplaceAll(doc,
"{op_doc}", Def::OpDef::doc);
69 #define REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(...) \ 70 OPERATOR_SCHEMA_STR( \ 71 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name)) \ 72 .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \ 74 .SetDoc(FormatDoc<__VA_ARGS__>()) \ 75 .Output(0, "OUTPUT", "Aggregated tensor") \ 76 .FillUsing(__VA_ARGS__::PopulateSchema); \ 77 REGISTER_CPU_OPERATOR_STR( \ 78 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient", \ 79 __VA_ARGS__::BackwardOp); \ 80 OPERATOR_SCHEMA_STR( \ 81 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient") \ 82 .NumInputs(__VA_ARGS__::BackwardOp::kNumInputs) \ 84 REGISTER_GRADIENT_STR( \ 85 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \ 86 __VA_ARGS__::GetGradient) 88 #define REGISTER_SEGMENT_DEF(...) \ 89 REGISTER_CPU_OPERATOR_STR( \ 90 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \ 91 __VA_ARGS__::ForwardOp); \ 92 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(__VA_ARGS__) 95 AbstractSortedSegmentRangeDef<float, int, CPUContext, SumRangeReducerDef>);
96 REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
100 LogSumExpRangeReducerDef>);
101 REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
105 LogMeanExpRangeReducerDef>);
106 REGISTER_SEGMENT_DEF(
107 AbstractSortedSegmentRangeDef<float, int, CPUContext, MeanRangeReducerDef>);
108 REGISTER_SEGMENT_DEF(
109 AbstractSortedSegmentRangeDef<float, int, CPUContext, MaxRangeReducerDef>);
111 #define REGISTER_REDUCER_WITH_OPS(reducer_def) \ 112 REGISTER_SEGMENT_DEF( \ 113 AbstractSortedSegmentDef<float, int, CPUContext, reducer_def>); \ 114 REGISTER_SEGMENT_DEF( \ 115 AbstractSparseSortedSegmentDef<float, int, CPUContext, reducer_def>); \ 116 REGISTER_SEGMENT_DEF( \ 117 AbstractUnsortedSegmentDef<float, int, CPUContext, reducer_def>); \ 118 REGISTER_SEGMENT_DEF( \ 119 AbstractSparseUnsortedSegmentDef<float, int, CPUContext, reducer_def>) 121 #define REGISTER_REDUCER_WITH_LENGTH_OPS(reducer_def, GradientNeedIndices) \ 122 REGISTER_SEGMENT_DEF(AbstractLengthsDef< \ 127 GradientNeedIndices>) 129 #define REGISTER_REDUCER_WITH_ALL_OPS(reducer_def) \ 130 REGISTER_SEGMENT_DEF( \ 131 AbstractReduceFrontDef<float, CPUContext, reducer_def>); \ 132 REGISTER_REDUCER_WITH_OPS(reducer_def) \ 133 REGISTER_REDUCER_WITH_LENGTH_OPS(reducer_def, false) 135 REGISTER_REDUCER_WITH_OPS(SumReducerDef);
136 REGISTER_REDUCER_WITH_LENGTH_OPS(SumReducerDef,
true);
138 REGISTER_REDUCER_WITH_OPS(MeanReducerDef);
139 REGISTER_REDUCER_WITH_LENGTH_OPS(MeanReducerDef,
false);
141 REGISTER_REDUCER_WITH_ALL_OPS(WeightedSumReducerDef);
145 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
151 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
155 WeightedSumReducerDef,
158 REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(
159 AbstractSparseLengthsDef<
float,
int, CPUContext, MeanReducerDef>)
162 #define REGISTER_GRADIENT_WITH_MAIN_INPUT(...) \ 163 REGISTER_CPU_OPERATOR_STR( \ 164 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \ 165 "WithMainInputGradient", \ 166 __VA_ARGS__::WithMainInputBackwardOp); \ 167 OPERATOR_SCHEMA_STR( \ 168 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \ 169 "WithMainInputGradient") \ 170 .NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs) \ 171 .NumOutputs(1, INT_MAX) 172 REGISTER_GRADIENT_WITH_MAIN_INPUT(
173 AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
174 REGISTER_GRADIENT_WITH_MAIN_INPUT(
175 AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
177 #define REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(...) \ 178 REGISTER_CPU_OPERATOR_STR( \ 179 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \ 180 "WithMainInputAndForwardOutputGradient", \ 181 __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp); \ 182 OPERATOR_SCHEMA_STR( \ 183 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \ 184 "WithMainInputAndForwardOutputGradient") \ 186 __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp::kNumInputs) \ 187 .NumOutputs(1, INT_MAX) 189 #define REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(...) \ 190 OPERATOR_SCHEMA_STR( \ 191 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name)) \ 192 .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \ 194 .SetDoc(FormatDoc<__VA_ARGS__>()) \ 195 .Output(0, "OUTPUT", "Aggregated tensor") \ 196 .FillUsing(__VA_ARGS__::PopulateSchema); \ 197 REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(__VA_ARGS__); \ 198 REGISTER_GRADIENT_STR( \ 199 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \ 200 __VA_ARGS__::GetGradient) 204 #define REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(...) \ 205 REGISTER_CPU_OPERATOR_STR( \ 206 string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \ 207 __VA_ARGS__::ForwardOp); \ 208 REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(__VA_ARGS__) 210 REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(
211 AbstractLengthsDef<float, int, CPUContext, MaxReducerDef>);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...