Caffe2 - C++ API
A deep learning, cross platform ML framework
lstm_unit_op.h
1 #ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
2 #define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
3 
#include <cmath>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"
7 
8 namespace caffe2 {
9 namespace detail {
/// Logistic sigmoid: 1 / (1 + e^{-x}).
///
/// The exponential is taken through the `std::exp` overload set so that the
/// overload matching `T` is always selected. An unqualified `exp` call would
/// resolve against whatever happens to be declared in the global namespace,
/// which for `float` may be only the C `double exp(double)` depending on
/// which headers were included first.
///
/// The surrounding arithmetic uses `double` literals (as it always has); the
/// quotient is narrowed back to `T` explicitly on return.
///
/// @param x  Input value.
/// @return   Sigmoid of `x`, converted to `T`.
template <typename T>
inline T sigmoid(T x) {
  return static_cast<T>(1. / (1. + std::exp(-x)));
}
14 
/// Hyperbolic tangent expressed through the logistic function:
/// tanh(x) = 2 * sigmoid(2x) - 1.
///
/// Note that `2. * x` is a `double` expression for `float`/`double` inputs,
/// so `sigmoid` is instantiated on the type of that product, not on `T`.
/// The final value is converted to `T` on return.
///
/// @param x  Input value.
/// @return   tanh of `x`, converted to `T`.
template <typename T>
inline T host_tanh(T x) {
  const auto logistic_of_2x = sigmoid(2. * x);
  return 2. * logistic_of_2x - 1.;
}
19 
20 template <typename T, typename Context>
21 void LSTMUnit(
22  int N,
23  int D,
24  int t,
25  const T* H_prev,
26  const T* C_prev,
27  const T* X,
28  const int32_t* seqLengths,
29  bool drop_states,
30  T* C,
31  T* H,
32  const float forget_bias,
33  Context* /*context*/) {
34  for (int n = 0; n < N; ++n) {
35  const bool valid = seqLengths == nullptr || t < seqLengths[n];
36 
37  for (int d = 0; d < D; ++d) {
38  if (!valid) {
39  if (drop_states) {
40  H[d] = 0;
41  C[d] = 0;
42  } else {
43  H[d] = H_prev[d];
44  C[d] = C_prev[d];
45  }
46  } else {
47  const T i = sigmoid(X[d]);
48  const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
49  const T o = sigmoid(X[2 * D + d]);
50  const T g = host_tanh(X[3 * D + d]);
51  const T c_prev = C_prev[d];
52  const T c = f * c_prev + i * g;
53  C[d] = c;
54  const T host_tanh_c = host_tanh(c);
55  H[d] = o * host_tanh_c;
56  }
57  }
58  H_prev += D;
59  C_prev += D;
60  X += 4 * D;
61  C += D;
62  H += D;
63  }
64 }
65 
66 template <typename T, typename Context>
67 void LSTMUnitGradient(
68  int N,
69  int D,
70  int t,
71  const T* C_prev,
72  const T* X,
73  const int32_t* seqLengths,
74  const T* C,
75  const T* H,
76  const T* C_diff,
77  const T* H_diff,
78  bool drop_states,
79  T* H_prev_diff,
80  T* C_prev_diff,
81  T* X_diff,
82  const float forget_bias,
83  Context* /*context*/) {
84  for (int n = 0; n < N; ++n) {
85  const bool valid = seqLengths == nullptr || t < seqLengths[n];
86 
87  for (int d = 0; d < D; ++d) {
88  T* c_prev_diff = C_prev_diff + d;
89  T* h_prev_diff = H_prev_diff + d;
90  T* i_diff = X_diff + d;
91  T* f_diff = X_diff + 1 * D + d;
92  T* o_diff = X_diff + 2 * D + d;
93  T* g_diff = X_diff + 3 * D + d;
94 
95  if (!valid) {
96  if (drop_states) {
97  *h_prev_diff = 0;
98  *c_prev_diff = 0;
99  } else {
100  *h_prev_diff = H_diff[d];
101  *c_prev_diff = C_diff[d];
102  }
103  *i_diff = 0;
104  *f_diff = 0;
105  *o_diff = 0;
106  *g_diff = 0;
107  } else {
108  const T i = sigmoid(X[d]);
109  const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
110  const T o = sigmoid(X[2 * D + d]);
111  const T g = host_tanh(X[3 * D + d]);
112  const T c_prev = C_prev[d];
113  const T c = C[d];
114  const T host_tanh_c = host_tanh(c);
115  const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
116  *c_prev_diff = c_term_diff * f;
117  *h_prev_diff = 0; // not used in 'valid' case
118  *i_diff = c_term_diff * g * i * (1 - i);
119  *f_diff = c_term_diff * c_prev * f * (1 - f);
120  *o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
121  *g_diff = c_term_diff * i * (1 - g * g);
122  }
123  }
124  C_prev += D;
125  X += 4 * D;
126  C += D;
127  H += D;
128  C_diff += D;
129  H_diff += D;
130  X_diff += 4 * D;
131  H_prev_diff += D;
132  C_prev_diff += D;
133  }
134 }
135 } // namespace detail
136 
137 template <typename Context>
138 class LSTMUnitOp : public Operator<Context> {
139  public:
140  LSTMUnitOp(const OperatorDef& operator_def, Workspace* ws)
141  : Operator<Context>(operator_def, ws),
142  forget_bias_(
143  static_cast<float>(OperatorBase::template GetSingleArgument<float>(
144  "forget_bias",
145  0.0))),
146  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
147  "sequence_lengths",
148  true)),
149  drop_states_(OperatorBase::template GetSingleArgument<bool>(
150  "drop_states",
151  false)) {}
152  USE_OPERATOR_CONTEXT_FUNCTIONS;
154 
155  template <typename T>
156  bool DoRunWithType() {
157  // handle potentially-missing sequence lengths input
158  const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
159 
160  // Extract N
161  const auto N = Input(CELL_T_M_1).dim(1);
162 
163  // Gates: 1xNxG
164  const auto G = Input(GATES).dim(2);
165  const auto D = Input(CELL_T_M_1).dim(2);
166 
167  CAFFE_ENFORCE_EQ(4 * D, G);
168  const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
169  const auto* C_prev = Input(CELL_T_M_1).template data<T>();
170  const auto* X = Input(GATES).template data<T>();
171 
172  const int32_t* seqLengths = nullptr;
173  if (sequence_lengths_) {
174  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
175  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
176  }
177 
178  const auto t = static_cast<OperatorBase*>(this)
179  ->Input<Tensor<CPUContext>>(TIMESTEP)
180  .template data<int32_t>()[0];
181  Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
182  auto* C = Output(CELL_T)->template mutable_data<T>();
183  Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
184  auto* H = Output(HIDDEN_T)->template mutable_data<T>();
185  detail::LSTMUnit<T, Context>(
186  N,
187  D,
188  t,
189  H_prev,
190  C_prev,
191  X,
192  seqLengths,
193  drop_states_,
194  C,
195  H,
196  forget_bias_,
197  &context_);
198  return true;
199  }
200 
201  bool RunOnDevice() override {
202  return DoRunWithType<float>();
203  }
204 
205  protected:
206  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
207  // additional input tags are determined dynamically based on whether
208  // sequence_lengths is present.
209  OUTPUT_TAGS(HIDDEN_T, CELL_T);
210 
211  float forget_bias_;
212  bool sequence_lengths_;
213 
214  private:
215  bool drop_states_;
216 };
217 
218 template <typename Context>
219 class LSTMUnitGradientOp : public Operator<Context> {
220  public:
221  LSTMUnitGradientOp(const OperatorDef& operator_def, Workspace* ws)
222  : Operator<Context>(operator_def, ws),
223  forget_bias_(
224  static_cast<float>(OperatorBase::template GetSingleArgument<float>(
225  "forget_bias",
226  0.0))),
227  sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
228  "sequence_lengths",
229  true)),
230  drop_states_(OperatorBase::template GetSingleArgument<bool>(
231  "drop_states",
232  false)) {}
233  USE_OPERATOR_CONTEXT_FUNCTIONS;
234 
235  template <typename T>
236  bool DoRunWithType() {
237  // handle potentially-missing sequence lengths input
238  const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
239  const size_t TIMESTEP = inputOffset;
240  const size_t HIDDEN_T = inputOffset + 1;
241  const size_t CELL_T = inputOffset + 2;
242  const size_t HIDDEN_T_GRAD = inputOffset + 3;
243  const size_t CELL_T_GRAD = inputOffset + 4;
244 
245  // Extract N
246  const auto N = Input(CELL_T_M_1).dim(1);
247 
248  // Gates: 1xNxG
249  const auto G = Input(GATES).dim(2);
250  const auto D = Input(CELL_T_M_1).dim(2);
251 
252  CAFFE_ENFORCE_EQ(4 * D, G);
253  const auto* C_prev = Input(CELL_T_M_1).template data<T>();
254  const auto* X = Input(GATES).template data<T>();
255  const auto t = static_cast<OperatorBase*>(this)
256  ->Input<Tensor<CPUContext>>(TIMESTEP)
257  .template data<int32_t>()[0];
258  const auto* C = Input(CELL_T).template data<T>();
259  const auto* H = Input(HIDDEN_T).template data<T>();
260  const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
261  const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
262 
263  const int32_t* seqLengths = nullptr;
264  if (sequence_lengths_) {
265  CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
266  seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
267  }
268 
269  Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
270  auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
271  Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
272  auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
273  Output(GATES_GRAD)->ResizeLike(Input(GATES));
274  auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
275 
276  detail::LSTMUnitGradient<T, Context>(
277  N,
278  D,
279  t,
280  C_prev,
281  X,
282  seqLengths,
283  C,
284  H,
285  C_diff,
286  H_diff,
287  drop_states_,
288  H_prev_diff,
289  C_prev_diff,
290  X_diff,
291  forget_bias_,
292  &context_);
293  return true;
294  }
295 
296  bool RunOnDevice() override {
297  return DoRunWithType<float>();
298  }
299 
300  protected:
301  INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
302  // additional input tags are determined dynamically based on whether
303  // sequence_lengths is present.
304  OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);
305 
306  float forget_bias_;
307  bool sequence_lengths_;
308 
309  private:
310  bool drop_states_;
311 };
312 } // namespace caffe2
313 
314 #endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...