1 #ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_ 2 #define CAFFE2_CORE_OPERATOR_GRADIENT_H_ 4 #include "caffe2/core/operator_schema.h" 5 #include "caffe2/core/registry.h" 6 #include "caffe2/proto/caffe2.pb.h" 7 #include "caffe2/utils/proto_utils.h" 22 inline bool IsDense()
const {
23 return (dense_.size() != 0);
25 inline bool IsSparse()
const {
26 return (indices_.size() != 0 || values_.size() != 0);
28 inline bool IsEmpty()
const {
29 return (!IsDense() && !IsSparse());
37 vector<OperatorDef> ops_;
38 vector<GradientWrapper> g_input_;
42 const vector<OperatorDef>& ops,
43 const vector<GradientWrapper>& v)
44 : ops_(ops), g_input_(v) {}
50 const OperatorDef& def,
51 const vector<GradientWrapper>& g_output)
52 : def_(def), g_output_(g_output), g_input_(def.input_size()){};
54 virtual bool CopyDeviceOption()
const {
57 virtual bool CopyEngine()
const {
60 virtual bool CopyArguments()
const {
64 virtual void VerifyOp()
const {
65 auto* schema = OpSchemaRegistry::Schema(def_.type());
69 "(GradientMaker) Operator def did not pass schema checking: ",
70 ProtoDebugString(def_));
87 vector<OperatorDef> new_defs = GetGradientDefs();
88 for (
auto& opdef : new_defs) {
89 opdef.set_is_gradient_op(
true);
94 const OperatorDef& Def()
const {
99 virtual vector<OperatorDef> GetGradientDefs() {
100 CAFFE_NOT_IMPLEMENTED;
109 string I(
const int i) {
110 CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
111 return def_.input(i);
113 string O(
const int i) {
114 CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
115 return def_.output(i);
117 string GI(
const int i) {
119 !g_input_.at(i).IsSparse(),
122 " already set to sparse.");
123 g_input_.at(i).dense_ = GradientName(def_.input(i));
124 return GradientName(def_.input(i));
126 string GI_I(
const int i) {
128 !g_input_.at(i).IsDense(),
131 " already set to dense.");
132 g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
133 return GradientSliceIndices(def_.input(i));
135 string GI_V(
const int i) {
137 !g_input_.at(i).IsDense(),
140 " already set to dense.");
141 g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
142 return GradientSliceValues(def_.input(i));
144 string GO(
const int i) {
146 g_output_.at(i).IsDense(),
147 "Gradient of output ",
149 (g_output_.at(i).IsSparse() ?
" is sparse (expected dense)." 150 :
" is not provided!"));
151 return g_output_.at(i).dense_;
153 string GO_I(
const int i) {
155 g_output_.at(i).IsSparse(),
156 "Gradient of output ",
158 (g_output_.at(i).IsDense() ?
" is dense (expected sparse)." 159 :
" is not provided!"));
160 return g_output_.at(i).indices_;
162 string GO_V(
const int i) {
164 g_output_.at(i).IsSparse(),
165 "Gradient of output ",
167 (g_output_.at(i).IsDense() ?
" is dense (expected sparse)." 168 :
" is not provided!"));
169 return g_output_.at(i).values_;
172 return g_output_.at(i);
176 void SetDense(
const int i,
const string& name) {
178 !g_input_.at(i).IsSparse(),
181 " already set to sparse.");
182 g_input_.at(i).dense_ = name;
184 void SetSparse(
const int i,
const string& indices,
const string& values) {
186 !g_input_.at(i).IsDense(),
189 " already set to dense.");
190 g_input_.at(i).indices_ = indices;
191 g_input_.at(i).values_ = values;
198 template <
class... Args>
200 return vector<OperatorDef>{CreateOperatorDef(args...)};
209 CaffeMap<string, string> m;
210 for (
auto& out : op.output()) {
211 if (IsGradientBlob(out)) {
212 m[out] = out.substr(0, out.length() - 5);
// Builds the canonical gradient blob name for `name`: "<name>_grad".
static string GradientName(const string& name) {
  string grad(name);
  grad += "_grad";
  return grad;
}
// True if `name` ends with the "_grad" suffix (and is longer than the
// suffix alone). Bug fix: the original tested
// `name.find("_grad") == name.length() - 5`, but find() locates the
// FIRST occurrence, so a name like "w_grad_grad" (GradientName applied
// to a gradient blob) was misclassified as not a gradient. A direct
// suffix comparison handles repeated suffixes correctly.
static bool IsGradientBlob(const string& name) {
  return name.length() > 5 &&
      name.compare(name.length() - 5, 5, "_grad") == 0;
}
229 static string GradientNameToParam(
const string& name) {
230 CHECK(IsGradientBlob(name));
231 return name.substr(0, name.length() - 5);
// Name of the indices blob for a sparse gradient slice of `name`.
static string GradientSliceIndices(const string& name) {
  string indices(name);
  indices += "_grad_indices";
  return indices;
}
// Name of the values blob for a sparse gradient slice of `name`.
static string GradientSliceValues(const string& name) {
  string values(name);
  values += "_grad_values";
  return values;
}
// The forward operator whose gradient is being generated.
const OperatorDef& def_;
// Gradients of the forward op's outputs, supplied by the constructor's
// caller; read by GO()/GO_I()/GO_V().
const vector<GradientWrapper>& g_output_;
// Gradients of the forward op's inputs, one wrapper per input, filled in
// via GI()/GI_I()/GI_V() or SetDense()/SetSparse().
vector<GradientWrapper> g_input_;
260 using GradientMakerBase::GradientMakerBase;
261 vector<OperatorDef> GetGradientDefs()
override {
262 return vector<OperatorDef>();
273 using GradientMakerBase::GradientMakerBase;
276 false,
"One should not call gradient for operator ", def_.type(),
".");
288 using GradientMakerBase::GradientMakerBase;
294 " should have a gradient but is not implemented yet.");
298 CAFFE_DECLARE_REGISTRY(
302 const vector<GradientWrapper>&);
304 #define REGISTER_GRADIENT(name, ...) \ 305 CAFFE_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__) 306 #define REGISTER_GRADIENT_STR(str_name, ...) \ 307 CAFFE_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__) 310 #define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient) 315 #define SHOULD_NOT_DO_GRADIENT(name) \ 316 REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled) 318 #define GRADIENT_NOT_IMPLEMENTED_YET(name) \ 319 REGISTER_GRADIENT(name, GradientNotImplementedYet) 325 const OperatorDef& def,
326 const vector<GradientWrapper>& g_output);
330 #endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_ A helper class to indicate that the gradient mechanism is not ready.
static CaffeMap< string, string > MatchGradsToParams(const OperatorDef &op)
Returns map that returns the parameters that the gradients are for.
A helper class to indicate that the operator should have no gradient.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...
virtual GradientOpsMeta Get()
Returns the gradient ops meta.
GradientOpsMeta GetGradientForOp(const OperatorDef &def, const vector< GradientWrapper > &g_output)
Gets the GradientOpsMeta for the given operator def.
GradientOpsMeta Get() override
Returns the gradient ops meta.
A helper class to indicate that the operator does not need gradient computation.
GradientOpsMeta Get() override
Returns the gradient ops meta.