Caffe2 - C++ API
A deep learning, cross-platform ML framework
minmax_gradient_ops.cc
1 #include "caffe2/operators/minmax_ops.h"
2 
3 namespace caffe2 {
4 
// Register the CPU operators MaxGradient and MinGradient, instantiated for
// float on CPUContext.
5 REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
6 REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
7 
// Schemas: both operators are declared with NumInputs(3, INT_MAX) and
// NumOutputs(1, INT_MAX). RunOnDevice in this file reads Input(0), Input(1)
// and one further input per output, which is where the minimum of 3 inputs
// for 1 output comes from.
8 OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
9 OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
10 
11 template <typename T, class Context>
12 bool SelectGradientOpBase<T, Context>::RunOnDevice() {
13  auto& output = Input(0);
14  auto& grad_output = Input(1);
15  const int kInputStartOffset = 2;
16 
17  const T* data = output.template data<T>();
18  ConstEigenArrayMap<T> output_array(
19  output.template data<T>(), 1, output.size());
20  ConstEigenArrayMap<T> grad_out_array(
21  grad_output.template data<T>(), 1, grad_output.size());
22 
23  for (int i = 0; i < OutputSize(); i++) {
24  auto& input = Input(i + kInputStartOffset);
25  ConstEigenArrayMap<T> input_array(
26  input.template data<T>(), 1, input.size());
27 
28  auto* grad_input = Output(i);
29  grad_input->ResizeLike(input);
30  EigenArrayMap<T> grad_in_array(
31  grad_input->template mutable_data<T>(), 1, grad_input->size());
32  grad_in_array = grad_out_array *
33  input_array.cwiseEqual(output_array).template cast<T>();
34  }
35  return true;
36 }
37 
39  using GradientMakerBase::GradientMakerBase;
40  vector<OperatorDef> GetGradientDefs() override {
41  auto gradInputs = vector<string>();
42  auto inputs = vector<string>{O(0), GO(0)};
43  for (int i = 0; i < def_.input_size(); i++) {
44  gradInputs.push_back(GI(i));
45  inputs.push_back(I(i));
46  }
47  return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
48  }
49 };
// Register GetMaxGradient as the gradient maker for the Max operator.
50 REGISTER_GRADIENT(Max, GetMaxGradient);
51 
53  using GradientMakerBase::GradientMakerBase;
54  vector<OperatorDef> GetGradientDefs() override {
55  auto gradInputs = vector<string>();
56  auto inputs = vector<string>{O(0), GO(0)};
57  for (int i = 0; i < def_.input_size(); i++) {
58  gradInputs.push_back(GI(i));
59  inputs.push_back(I(i));
60  }
61  return SingleGradientDef("MinGradient", "", inputs, gradInputs);
62  }
63 };
// Register GetMinGradient as the gradient maker for the Min operator.
64 REGISTER_GRADIENT(Min, GetMinGradient);
65 
66 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that lets one create a single operator def, which is usually the case for many ...