Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout_op.cc
1 #include "caffe2/operators/dropout_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool DropoutOp<float, CPUContext>::RunOnDevice() {
7  auto& X = Input(0);
8  auto* Y = Output(0);
9  Y->Resize(X.dims());
10  if (is_test_) {
11  if (Y != &X) {
12  context_.Copy<float, CPUContext, CPUContext>(
13  X.size(), X.data<float>(), Y->mutable_data<float>());
14  }
15  return true;
16  } else {
17  float scale = 1. / (1. - ratio_);
18  // mask=true means keep, and mask=false means not keep, so we will
19  // generate probability depending on 1-ratio.
20  std::bernoulli_distribution dist(1. - ratio_);
21  const float* Xdata = X.data<float>();
22  float* Ydata = Y->mutable_data<float>();
23  auto mask = Output(1);
24  mask->Resize(X.dims());
25  bool* mask_data = mask->mutable_data<bool>();
26  auto& gen = context_.RandGenerator();
27  for (int i = 0; i < X.size(); ++i) {
28  mask_data[i] = dist(gen);
29  Ydata[i] = Xdata[i] * scale * mask_data[i];
30  }
31  return true;
32  }
33 }
34 
35 template <>
36 bool DropoutGradientOp<float, CPUContext>::RunOnDevice() {
37  auto& dY = Input(0);
38  auto* dX = Output(0);
39  dX->Resize(dY.dims());
40  if (is_test_) {
41  if (dX != &dY) {
42  context_.Copy<float, CPUContext, CPUContext>(
43  dY.size(), dY.data<float>(), dX->mutable_data<float>());
44  }
45  return true;
46  } else {
47  auto& mask = Input(1);
48  CAFFE_ENFORCE_EQ(dY.size(), mask.size());
49  const float* dYdata = dY.data<float>();
50  const bool* mask_data = mask.data<bool>();
51  float* dXdata = dX->mutable_data<float>();
52  float scale = 1. / (1. - ratio_);
53  for (int i = 0; i < dY.size(); ++i) {
54  dXdata[i] = dYdata[i] * mask_data[i] * scale;
55  }
56  return true;
57  }
58 }
59 
60 REGISTER_CPU_OPERATOR(Dropout, DropoutOp<float, CPUContext>);
61 REGISTER_CPU_OPERATOR(DropoutGrad, DropoutGradientOp<float, CPUContext>);
62 
63 OPERATOR_SCHEMA(Dropout)
64  .NumInputs(1)
65  .NumOutputs(1, 2)
66  .AllowInplace({{0, 0}})
67  .TensorInferenceFunction([](const OperatorDef& def,
68  const vector<TensorShape>& in) {
69  CAFFE_ENFORCE_EQ(1, in.size());
70  vector<TensorShape> out;
71  ArgumentHelper argsHelper(def);
72  out.push_back(in[0]);
73  auto output_mask = !argsHelper.GetSingleArgument<bool>("is_test", 0);
74  if (output_mask) {
75  out.push_back(in[0]);
76  out[1].set_data_type(TensorProto_DataType_BOOL);
77  }
78  return out;
79  })
80  .SetDoc(R"DOC(
81 Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
82 output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
83 test mode or not, the output Y will either be a random dropout, or a simple
84 copy of the input. Note that our implementation of Dropout does scaling in
85 the training phase, so during testing nothing needs to be done.
86 )DOC")
87  .Arg("ratio", "(float, default 0.5) the ratio of random dropout")
88  .ArgIsTest(
89  "(int) if nonzero, run dropout in test mode where "
90  "the output is simply Y = X.")
91  .Input(0, "data", "The input data as Tensor.")
92  .Output(0, "output", "The output.")
93  .Output(
94  1,
95  "mask",
96  "The output mask. If is_test is nonzero, this output is not filled.")
97  .InheritOnnxSchema("Dropout");
98 
99 OPERATOR_SCHEMA(DropoutGrad)
100  .NumInputs(1, 2)
101  .NumOutputs(1)
102  .AllowInplace({{0, 0}});
103 
104 class GetDropoutGradient : public GradientMakerBase {
105  using GradientMakerBase::GradientMakerBase;
106  vector<OperatorDef> GetGradientDefs() override {
107  ArgumentHelper argshelper(def_);
108  auto is_test = argshelper.GetSingleArgument<bool>("is_test", 0);
109  if (is_test) {
110  return SingleGradientDef(
111  "DropoutGrad", "", vector<string>{GO(0)}, vector<string>{GI(0)});
112  } else {
113  return SingleGradientDef(
114  "DropoutGrad",
115  "",
116  vector<string>{GO(0), O(1)},
117  vector<string>{GI(0)});
118  }
119  }
120 };
121 REGISTER_GRADIENT(Dropout, GetDropoutGradient);
122 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...