1 #ifndef CAFFE2_OPERATORS_LSTM_UNIT_OP_H_ 2 #define CAFFE2_OPERATORS_LSTM_UNIT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/conversions.h" 11 inline T sigmoid(T x) {
// Logistic sigmoid: 1 / (1 + exp(-x)).
// The `1.` literals are double, so for float T the division is carried out in
// (at least) double precision and the result is converted back to T on return.
// NOTE(review): the `template <typename T>` header and the closing brace of
// this function are not present in this listing.
12 return 1. / (1. + exp(-x));
// Hyperbolic tangent expressed through sigmoid(), via the identity
// tanh(x) = 2 * sigmoid(2x) - 1.
// NOTE(review): for inputs very close to 0 this subtracts two nearly equal
// values (2 * 0.5 and 1), so it can lose relative precision compared with a
// direct tanh; kept as-is here.
16 inline T host_tanh(T x) {
17 return 2. * sigmoid(2. * x) - 1.;
// Forward step of an LSTM cell, called by the forward operator as
// detail::LSTMUnit<T, Context>(...). Iterates over N rows and D units per row.
// Within a row, X holds four D-wide pre-activation slices, read as
//   i = X[0*D + d], f = X[1*D + d], o = X[2*D + d], g = X[3*D + d].
// NOTE(review): this listing elides part of the parameter list, any per-row
// pointer updates, and the handling of rows where `valid` is false.
20 template <
typename T,
typename Context>
// Optional per-row sequence lengths; may be nullptr (see `valid` below).
28 const int32_t* seqLengths,
// Added to the f-slice pre-activation before its sigmoid.
32 const float forget_bias,
34 for (
int n = 0; n < N; ++n) {
// A row is active at step t unless lengths were supplied and t >= seqLengths[n].
35 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
37 for (
int d = 0; d < D; ++d) {
// Gate activations: sigmoid for i, f and o; tanh for the candidate g.
47 const T i = sigmoid(X[d]);
48 const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
49 const T o = sigmoid(X[2 * D + d]);
50 const T g = host_tanh(X[3 * D + d]);
51 const T c_prev = C_prev[d];
// New cell state: previous state scaled by f plus candidate scaled by i.
52 const T c = f * c_prev + i * g;
54 const T host_tanh_c = host_tanh(c);
// Hidden output: o * tanh(c).
55 H[d] = o * host_tanh_c;
// Backward step of the LSTM cell. It recomputes the gate activations from X and
// forget_bias, using the same slice layout and the same sigmoid/tanh choices as
// LSTMUnit. It then writes gradients to C_prev_diff, H_prev_diff and the four
// D-wide slices (i, f, o, g) of X_diff.
// NOTE(review): several parameters and the branch structure are elided in
// this listing.
66 template <
typename T,
typename Context>
67 void LSTMUnitGradient(
// Optional per-row sequence lengths; may be nullptr (see `valid` below).
73 const int32_t* seqLengths,
82 const float forget_bias,
84 for (
int n = 0; n < N; ++n) {
85 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
87 for (
int d = 0; d < D; ++d) {
// Destinations for this unit: previous-state gradients and the i/f/o/g
// slices of the gate gradient.
88 T* c_prev_diff = C_prev_diff + d;
89 T* h_prev_diff = H_prev_diff + d;
90 T* i_diff = X_diff + d;
91 T* f_diff = X_diff + 1 * D + d;
92 T* o_diff = X_diff + 2 * D + d;
93 T* g_diff = X_diff + 3 * D + d;
// Pass the incoming gradients straight through to the previous step.
// NOTE(review): the condition guarding these writes is elided here;
// presumably it covers rows where `valid` is false — confirm.
100 *h_prev_diff = H_diff[d];
101 *c_prev_diff = C_diff[d];
// Recompute gate activations exactly as the forward step does.
108 const T i = sigmoid(X[d]);
109 const T f = sigmoid(X[1 * D + d] + convert::To<float, T>(forget_bias));
110 const T o = sigmoid(X[2 * D + d]);
111 const T g = host_tanh(X[3 * D + d]);
112 const T c_prev = C_prev[d];
// NOTE(review): `c` is defined on a line not shown in this listing
// (presumably the forward cell-state value for this unit).
114 const T host_tanh_c = host_tanh(c);
// dL/dc: direct cell gradient plus the path through h = o * tanh(c),
// using tanh'(c) = 1 - tanh(c)^2.
115 const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c);
// With c = f * c_prev + i * g, dc/dc_prev = f.
116 *c_prev_diff = c_term_diff * f;
// Chain rule through each activation: sigmoid'(s) = s * (1 - s) for i, f, o
// and tanh'(g) = 1 - g^2 for g.
118 *i_diff = c_term_diff * g * i * (1 - i);
119 *f_diff = c_term_diff * c_prev * f * (1 - f);
120 *o_diff = H_diff[d] * host_tanh_c * o * (1 - o);
121 *g_diff = c_term_diff * i * (1 - g * g);
// Forward LSTM-unit operator. Its class declaration line is elided in this
// listing.
// Inputs: HIDDEN_T_M_1, CELL_T_M_1, GATES, then SEQ_LENGTHS only when
// sequence_lengths_ is set, followed by the timestep input.
// Outputs: HIDDEN_T, CELL_T.
137 template <
typename Context>
// Constructor reads a float argument (its name is elided here; presumably the
// forget bias) and the bool flags sequence_lengths_ and drop_states_.
143 static_cast<float>(OperatorBase::template GetSingleArgument<float>(
146 sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
149 drop_states_(OperatorBase::template GetSingleArgument<bool>(
152 USE_OPERATOR_CONTEXT_FUNCTIONS;
// One forward step for element type T. It validates shapes, resolves the
// optional sequence lengths and the timestep, sizes the outputs, and calls
// detail::LSTMUnit.
155 template <
typename T>
156 bool DoRunWithType() {
// The timestep input follows SEQ_LENGTHS when lengths are supplied; otherwise
// it occupies SEQ_LENGTHS' index.
158 const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
// N and D come from dims 1 and 2 of CELL_T_M_1. GATES' dim 2 must equal
// 4 * D (one D-wide slice per gate).
161 const auto N = Input(CELL_T_M_1).dim(1);
164 const auto G = Input(GATES).dim(2);
165 const auto D = Input(CELL_T_M_1).dim(2);
167 CAFFE_ENFORCE_EQ(4 * D, G);
168 const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
169 const auto* C_prev = Input(CELL_T_M_1).template data<T>();
170 const auto* X = Input(GATES).template data<T>();
// Per-row lengths are optional; when enabled there must be exactly N of them.
172 const int32_t* seqLengths =
nullptr;
173 if (sequence_lengths_) {
174 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
175 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
180 .template data<int32_t>()[0];
// Both outputs take CELL_T_M_1's shape.
// NOTE(review): HIDDEN_T is sized from CELL_T_M_1 rather than HIDDEN_T_M_1,
// which assumes both state inputs share a shape — confirm.
181 Output(CELL_T)->ResizeLike(Input(CELL_T_M_1));
182 auto* C = Output(CELL_T)->template mutable_data<T>();
183 Output(HIDDEN_T)->ResizeLike(Input(CELL_T_M_1));
184 auto* H = Output(HIDDEN_T)->template mutable_data<T>();
185 detail::LSTMUnit<T, Context>(
// Only float is dispatched.
201 bool RunOnDevice()
override {
202 return DoRunWithType<float>();
// SEQ_LENGTHS is a real input only when sequence_lengths_ is true (see TIMESTEP).
206 INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
209 OUTPUT_TAGS(HIDDEN_T, CELL_T);
212 bool sequence_lengths_;
// Gradient LSTM-unit operator. Its class declaration line is elided in this
// listing.
// Inputs: HIDDEN_T_M_1, CELL_T_M_1, GATES, optional SEQ_LENGTHS, then (starting
// at inputOffset) TIMESTEP, HIDDEN_T, CELL_T, HIDDEN_T_GRAD, CELL_T_GRAD.
// Outputs: gradients for HIDDEN_T_M_1, CELL_T_M_1 and GATES, each shaped like
// the corresponding input.
218 template <
typename Context>
// Constructor reads a float argument (its name is elided here; presumably the
// forget bias) and the bool flags sequence_lengths_ and drop_states_.
224 static_cast<float>(OperatorBase::template GetSingleArgument<float>(
227 sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
230 drop_states_(OperatorBase::template GetSingleArgument<bool>(
233 USE_OPERATOR_CONTEXT_FUNCTIONS;
// One backward step for element type T. It validates shapes, gathers the
// forward state and incoming gradients, sizes the gradient outputs, and calls
// detail::LSTMUnitGradient.
235 template <
typename T>
236 bool DoRunWithType() {
// Indices of the trailing inputs; they shift by one when SEQ_LENGTHS is present.
238 const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
239 const size_t TIMESTEP = inputOffset;
240 const size_t HIDDEN_T = inputOffset + 1;
241 const size_t CELL_T = inputOffset + 2;
242 const size_t HIDDEN_T_GRAD = inputOffset + 3;
243 const size_t CELL_T_GRAD = inputOffset + 4;
// N and D come from dims 1 and 2 of CELL_T_M_1. GATES' dim 2 must equal 4 * D.
246 const auto N = Input(CELL_T_M_1).dim(1);
249 const auto G = Input(GATES).dim(2);
250 const auto D = Input(CELL_T_M_1).dim(2);
252 CAFFE_ENFORCE_EQ(4 * D, G);
253 const auto* C_prev = Input(CELL_T_M_1).template data<T>();
254 const auto* X = Input(GATES).template data<T>();
257 .template data<int32_t>()[0];
258 const auto* C = Input(CELL_T).template data<T>();
259 const auto* H = Input(HIDDEN_T).template data<T>();
260 const auto* C_diff = Input(CELL_T_GRAD).template data<T>();
261 const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
// Per-row lengths are optional; when enabled there must be exactly N of them.
263 const int32_t* seqLengths =
nullptr;
264 if (sequence_lengths_) {
265 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
266 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// Each gradient output is shaped like the input it differentiates.
269 Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
270 auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
271 Output(CELL_T_M_1_GRAD)->ResizeLike(Input(CELL_T_M_1));
272 auto* C_prev_diff = Output(CELL_T_M_1_GRAD)->template mutable_data<T>();
273 Output(GATES_GRAD)->ResizeLike(Input(GATES));
274 auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
276 detail::LSTMUnitGradient<T, Context>(
// Only float is dispatched.
296 bool RunOnDevice()
override {
297 return DoRunWithType<float>();
// SEQ_LENGTHS is a real input only when sequence_lengths_ is true (see inputOffset).
301 INPUT_TAGS(HIDDEN_T_M_1, CELL_T_M_1, GATES, SEQ_LENGTHS);
304 OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, CELL_T_M_1_GRAD, GATES_GRAD);
307 bool sequence_lengths_;
314 #endif // CAFFE2_OPERATORS_LSTM_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...