1 #ifndef CAFFE2_OPERATORS_GRU_UNIT_OP_H_ 2 #define CAFFE2_OPERATORS_GRU_UNIT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 12 inline T sigmoid(T x) {
// Logistic sigmoid: 1 / (1 + exp(-x)).
// The float literals mix with T through the usual arithmetic conversions.
// exp is called unqualified, so the overload that runs depends on T.
13 return 1.0f / (1.0f + exp(-x));
// Hyperbolic tangent via the identity tanh(x) = 2 * sigmoid(2x) - 1.
// It reuses sigmoid() rather than calling a library tanh.
17 inline T host_tanh(T x) {
18 return 2.0f * sigmoid(2.0f * x) - 1.0f;
// Forward kernel: one GRU step at timestep t over N batch rows of width D.
// Per element it reads the update-gate and output pre-activations from X
// and writes
//   H = H_prev * sigmoid(update) + tanh(output) * (1 - sigmoid(update)).
// NOTE(review): these parts are not shown in this excerpt and should be
// confirmed against the full definition:
//   - the full parameter list;
//   - how the row pointers advance between rows;
//   - the handling of rows with !valid;
//   - the use of drop_states.
21 template <
typename T,
typename Context>
28 const int32_t* seqLengths,
// Iterate over batch rows.
32 for (
int n = 0; n < N; ++n) {
// A row is active at step t when no lengths were supplied (nullptr) or t
// is still below that row's sequence length.
33 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
35 for (
int d = 0; d < D; ++d) {
// X packs three gate blocks of width D. The update block starts at 1*D and
// the output block at 2*D.
// (Assumes X already points at row n; the advancement is not visible here.)
43 const T update = X[1 * D + d];
44 const T output = X[2 * D + d];
45 T sigmoid_update = sigmoid(update);
// Blend the previous state with the tanh of the output pre-activation.
// The weights are u = sigmoid(update) and (1 - u).
46 H[d] = H_prev[d] * sigmoid_update +
47 host_tanh(output) * (1.0f - sigmoid_update);
// Backward kernel for the GRU step at timestep t.
// Per element it writes the gradient w.r.t. the previous hidden state
// (H_prev_diff) and w.r.t. the three gate pre-activation blocks of X_diff:
// reset at 0*D, update at 1*D, output at 2*D.
// NOTE(review): these parts are not shown in this excerpt and should be
// confirmed against the full definition:
//   - how the row pointers advance between rows;
//   - how reset_diff is filled;
//   - the condition that selects the pass-through branch.
57 template <
typename T,
typename Context>
64 const int32_t* seqLengths,
71 for (
int n = 0; n < N; ++n) {
// Same validity rule as the forward kernel: nullptr lengths, or t below
// this row's length.
72 const bool valid = seqLengths ==
nullptr || t < seqLengths[n];
74 for (
int d = 0; d < D; ++d) {
// Output slots for this element: one in H_prev_diff and one per gate block
// in X_diff.
75 T* h_prev_diff = H_prev_diff + d;
76 T* reset_diff = X_diff + 0 * D + d;
77 T* update_diff = X_diff + 1 * D + d;
78 T* output_diff = X_diff + 2 * D + d;
// Pass-through: the incoming gradient is forwarded unchanged to H_prev.
// Presumably this is the branch for rows with !valid — TODO confirm.
84 *h_prev_diff = H_diff[d];
// Recompute the forward activations: u = sigmoid(update), o = tanh(output).
91 const T u = sigmoid(X[1 * D + d]);
92 const T o = host_tanh(X[2 * D + d]);
// The derivatives below follow from H = H_prev * u + o * (1 - u).
// dH/dH_prev = u
94 *h_prev_diff = H_diff[d] * u;
// dH/du = H_prev - o, chained with d sigmoid/dx = u * (1 - u)
96 *update_diff = (H_diff[d] * H_prev[d] - H_diff[d] * o) * u * (1.0f - u);
// dH/do = 1 - u, chained with d tanh/dx = 1 - o^2
97 *output_diff = H_diff[d] * (1.0f - u) * (1.0f - o * o);
// Forward operator wrapping detail::GRUUnit. RunOnDevice:
//   - checks the gate/hidden shapes;
//   - locates the timestep input;
//   - sizes HIDDEN_T like HIDDEN_T_M_1;
//   - runs one GRU step.
112 template <
typename T,
typename Context>
// Two boolean operator arguments.
// Their names and defaults are on lines not shown in this excerpt.
117 drop_states_(OperatorBase::template GetSingleArgument<bool>(
120 sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
123 USE_OPERATOR_CONTEXT_FUNCTIONS;
125 bool RunOnDevice()
override {
// SEQ_LENGTHS is optional.
// When sequence_lengths_ is set, the timestep input comes right after
// SEQ_LENGTHS; otherwise it occupies SEQ_LENGTHS's slot.
127 const size_t TIMESTEP = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
// Batch size comes from dim 1 of the previous hidden state.
// (Assumes a (1, N, D) layout — TODO confirm.)
130 const auto N = Input(HIDDEN_T_M_1).dim(1);
// G is the gate width and D the hidden width.
// The gates must pack exactly three blocks of D.
133 const auto G = Input(GATES).dim(2);
134 const auto D = Input(HIDDEN_T_M_1).dim(2);
136 CAFFE_ENFORCE_EQ(3 * D, G);
137 const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
138 const auto* X = Input(GATES).template data<T>();
// Per-row lengths are read only when enabled; there must be one per row.
// A nullptr makes GRUUnit treat every row as valid.
140 const int32_t* seqLengths =
nullptr;
141 if (sequence_lengths_) {
142 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
143 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// The new hidden state has the same shape as the previous one.
148 Output(HIDDEN_T)->ResizeLike(Input(HIDDEN_T_M_1));
149 auto* H = Output(HIDDEN_T)->template mutable_data<T>();
// t is presumably read from Input(TIMESTEP) on lines not shown — confirm.
151 detail::GRUUnit<T, Context>(
152 N, D, t, H_prev, X, seqLengths, drop_states_, H, &context_);
// Only the fixed-position inputs are tagged.
// TIMESTEP's index depends on sequence_lengths_, so RunOnDevice computes it.
157 INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
160 OUTPUT_TAGS(HIDDEN_T);
164 bool sequence_lengths_;
// Gradient operator wrapping detail::GRUUnitGradient. RunOnDevice:
//   - checks shapes;
//   - locates the runtime-positioned inputs;
//   - sizes the gradient outputs like their inputs;
//   - runs one backward GRU step.
167 template <
typename T,
typename Context>
// Two boolean operator arguments, read the same way as in the forward op.
// Their names and defaults are on lines not shown in this excerpt.
172 drop_states_(OperatorBase::template GetSingleArgument<bool>(
175 sequence_lengths_(OperatorBase::template GetSingleArgument<bool>(
178 USE_OPERATOR_CONTEXT_FUNCTIONS;
180 bool RunOnDevice()
override {
// These inputs follow the optional SEQ_LENGTHS slot, in this order:
// TIMESTEP, HIDDEN_T, HIDDEN_T_GRAD.
// When sequence_lengths_ is false they shift down by one and start in
// SEQ_LENGTHS's slot.
182 const size_t inputOffset = SEQ_LENGTHS + (sequence_lengths_ ? 1 : 0);
183 const size_t TIMESTEP = inputOffset;
184 const size_t HIDDEN_T = inputOffset + 1;
185 const size_t HIDDEN_T_GRAD = inputOffset + 2;
// Batch size comes from dim 1 of the previous hidden state.
// (Assumes a (1, N, D) layout — TODO confirm.)
188 const auto N = Input(HIDDEN_T_M_1).dim(1);
// The gate width must be exactly three hidden widths.
191 const auto G = Input(GATES).dim(2);
192 const auto D = Input(HIDDEN_T_M_1).dim(2);
194 CAFFE_ENFORCE_EQ(3 * D, G);
195 const auto* H_prev = Input(HIDDEN_T_M_1).template data<T>();
196 const auto* X = Input(GATES).template data<T>();
// Read HIDDEN_T and HIDDEN_T_GRAD at their runtime-computed indices.
199 const auto* H = Input(HIDDEN_T).template data<T>();
200 const auto* H_diff = Input(HIDDEN_T_GRAD).template data<T>();
// Per-row lengths are read only when enabled; there must be one per row.
// The pointer is presumably passed to GRUUnitGradient, where nullptr marks
// every row as valid.
202 const int32_t* seqLengths =
nullptr;
203 if (sequence_lengths_) {
204 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
205 seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// Each gradient output takes the shape of the input it differentiates.
208 Output(HIDDEN_T_M_1_GRAD)->ResizeLike(Input(HIDDEN_T_M_1));
209 auto* H_prev_diff = Output(HIDDEN_T_M_1_GRAD)->template mutable_data<T>();
210 Output(GATES_GRAD)->ResizeLike(Input(GATES));
211 auto* X_diff = Output(GATES_GRAD)->template mutable_data<T>();
// The call's arguments (including t) are on lines not shown in this excerpt.
213 detail::GRUUnitGradient<T, Context>(
// Only the fixed-position inputs are tagged.
// TIMESTEP, HIDDEN_T and HIDDEN_T_GRAD are located at runtime above.
230 INPUT_TAGS(HIDDEN_T_M_1, GATES, SEQ_LENGTHS);
231 OUTPUT_TAGS(HIDDEN_T_M_1_GRAD, GATES_GRAD);
235 bool sequence_lengths_;
240 #endif // CAFFE2_OPERATORS_GRU_UNIT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...