Caffe2 - C++ API
A deep learning, cross platform ML framework
gru_unit_op.h
1 #ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_
2 #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
namespace detail {

// Logistic function 1 / (1 + e^-x) for a scalar of type T.
template <typename T>
inline T sigmoid(T x) {
  const auto denominator = 1.0f + exp(-x);
  return 1.0f / denominator;
}

// Hyperbolic tangent expressed through the logistic function:
// tanh(x) = 2 * sigmoid(2x) - 1.
template <typename T>
inline T host_tanh(T x) {
  const auto doubled = 2.0f * x;
  return 2.0f * sigmoid(doubled) - 1.0f;
}

// Forward pass of the GRU cell over N rows of width D.
//
// Each row of X is 3 * D wide. Slice [D, 2D) holds the pre-activation update
// gate and slice [2D, 3D) the pre-activation candidate output; slice [0, D)
// is not read by this function.
//
// For a row whose sequence is still running (seqLengths is null, or
// t < seqLengths[row]) the new hidden state is
//   H = H_prev * u + tanh(output) * (1 - u),   u = sigmoid(update).
// For a finished row, H is zeroed when drop_states is set and copied from
// H_prev otherwise.
template <typename T, typename Context>
void GRUUnit(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    bool drop_states,
    T* H,
    Context* /*context*/) {
  // All three row cursors move forward together, one row per iteration.
  for (int row = 0; row < N; ++row, H_prev += D, X += 3 * D, H += D) {
    const bool past_end = seqLengths != nullptr && t >= seqLengths[row];

    if (past_end) {
      for (int col = 0; col < D; ++col) {
        H[col] = drop_states ? T(0) : H_prev[col];
      }
      continue;
    }

    const T* update_gate = X + D;
    const T* output_gate = X + 2 * D;
    for (int col = 0; col < D; ++col) {
      const T keep = sigmoid(update_gate[col]);
      const T candidate = host_tanh(output_gate[col]);
      H[col] = H_prev[col] * keep + candidate * (1.0f - keep);
    }
  }
}

// Backward pass matching GRUUnit.
//
// Writes the gradient w.r.t. the previous hidden state into H_prev_diff
// (N x D) and the gradient w.r.t. the gates into X_diff (N x 3D, same slice
// layout as X). The reset slice of X_diff is always written as zero.
// H is accepted for signature symmetry but its contents are never read.
template <typename T, typename Context>
void GRUUnitGradient(
    int N,
    int D,
    int t,
    const T* H_prev,
    const T* X,
    const int32_t* seqLengths,
    const T* H,
    const T* H_diff,
    bool drop_states,
    T* H_prev_diff,
    T* X_diff,
    Context* /*context*/) {
  for (int row = 0; row < N; ++row,
           H_prev += D,
           X += 3 * D,
           H += D,
           H_diff += D,
           X_diff += 3 * D,
           H_prev_diff += D) {
    T* reset_grad = X_diff;
    T* update_grad = X_diff + D;
    T* output_grad = X_diff + 2 * D;

    const bool past_end = seqLengths != nullptr && t >= seqLengths[row];
    if (past_end) {
      // Finished sequence: the incoming gradient is either dropped or passed
      // straight through, and no gate receives any gradient.
      for (int col = 0; col < D; ++col) {
        H_prev_diff[col] = drop_states ? T(0) : H_diff[col];
        reset_grad[col] = 0;
        update_grad[col] = 0;
        output_grad[col] = 0;
      }
      continue;
    }

    for (int col = 0; col < D; ++col) {
      // Recompute the gate activations from the stored pre-activations.
      const T u = sigmoid(X[D + col]);
      const T o = host_tanh(X[2 * D + col]);

      H_prev_diff[col] = H_diff[col] * u;
      // The reset gate does not take part in this operation.
      reset_grad[col] = 0;
      // u * (1 - u) is the sigmoid derivative; (1 - o * o) the tanh one.
      update_grad[col] =
          (H_diff[col] * H_prev[col] - H_diff[col] * o) * u * (1.0f - u);
      output_grad[col] = H_diff[col] * (1.0f - u) * (1.0f - o * o);
    }
  }
}

} // namespace detail
111 
112 template <typename T, typename Context>
113 class GRUUnitOp : public Operator<Context> {
114  public:
115  GRUUnitOp(const OperatorDef& operator_def, Workspace* ws)
116  : Operator<Context>(operator_def, ws),
117  drop_states_(OperatorBase::template GetSingleArgument<bool>(
118  "drop_states",
119  false)),
120  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
121  "sequence_lengths",
122  true)) {}
123  USE_OPERATOR_CONTEXT_FUNCTIONS;
124 
125  bool RunOnDevice() override {
126  // handle potentially-missing sequence lengths input
127  const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
128 
129  // Extract N
130  const auto N = Input(HIDDEN_T_M_1).dim(1);
131 
132  // Gates: 1xNxG
133  const auto G = Input(GATES).dim(2);
134  const auto D = Input(HIDDEN_T_M_1).dim(2);
135 
136  CAFFE_ENFORCE_EQ(3 * D, G);
137  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
138  const auto* X = Input(GATES).template data<T>();
139 
140  const int32_t* seqLengths = nullptr;
141  if (sequence_lengths_) {
142  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
143  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
144  }
145 
146  const auto t = static_cast<OperatorBase*>(this)->
147  Input<Tensor<CPUContext>>(TIMESTEP).template data<int32_t>()[0];
148  Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
149  auto* H = Output(HIDDEN_T)->template mutable_data<T>();
150 
151  detail::GRUUnit<T, Context>(
152  N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
153  return true;
154  }
155 
156  protected:
157  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
158  // additional input tags are determined dynamically based on whether
159  // sequence_lengths is present.
160  OUTPUT_TAGS(HIDDEN_T);
161 
162  private:
163  bool drop_states_;
164  bool sequence_lengths_;
165 };
166 
167 template <typename T, typename Context>
168 class GRUUnitGradientOp : public Operator<Context> {
169  public:
170  GRUUnitGradientOp(const OperatorDef& operator_def, Workspace* ws)
171  : Operator<Context>(operator_def, ws),
172  drop_states_(OperatorBase::template GetSingleArgument<bool>(
173  "drop_states",
174  false)),
175  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
176  "sequence_lengths",
177  true)) {}
178  USE_OPERATOR_CONTEXT_FUNCTIONS;
179 
180  bool RunOnDevice() override {
181  // handle potentially-missing sequence lengths input
182  const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
183  const size_t TIMESTEP = inputOffset;
184  const size_t HIDDEN_T = inputOffset + 1;
185  const size_t HIDDEN_T_GRAD = inputOffset + 2;
186 
187  // Extract N
188  const auto N = Input(HIDDEN_T_M_1).dim(1);
189 
190  // Gates: 1xNxG
191  const auto G = Input(GATES).dim(2);
192  const auto D = Input(HIDDEN_T_M_1).dim(2);
193 
194  CAFFE_ENFORCE_EQ(3 * D, G);
195  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
196  const auto* X = Input(GATES).template data<T>();
197  const auto t = static_cast<OperatorBase*>(this)->
198  Input<Tensor<CPUContext>>(TIMESTEP).template data<int32_t>()[0];
199  const auto* H = Input(HIDDEN_T).template data<T>();
200  const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
201 
202  const int32_t* seqLengths = nullptr;
203  if (sequence_lengths_) {
204  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
205  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
206  }
207 
208  Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
209  auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
210  Output(GATES_GRAD)->ResizeLike(Input(GATES));
211  auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
212 
213  detail::GRUUnitGradient<T, Context>(
214  N,
215  D,
216  t,
217  H_prev,
218  X,
219  seqLengths,
220  H,
221  H_diff,
222  drop_states_,
223  H_prev_diff,
224  X_diff,
225  &context_);
226  return true;
227  }
228 
229  protected:
230  INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
231  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);
232 
233  private:
234  bool drop_states_;
235  bool sequence_lengths_;
236 };
237 
238 } // namespace caffe2
239 
240 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...