Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_op_cudnn.cc
#include "caffe2/operators/rnn/recurrent_op_cudnn.h"
#include "caffe2/utils/math.h"

#include <map>

namespace caffe2 {

namespace detail {

template <typename T>
TensorDescriptors<T>::TensorDescriptors(
    size_t n,
    const std::vector<int>& dim,
    const std::vector<int>& stride) {
  descs_.resize(n);
  CAFFE_ENFORCE_EQ(dim.size(), stride.size());
  for (size_t i = 0; i < n; ++i) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&descs_[i]));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        descs_[i],
        cudnnTypeWrapper<T>::type,
        dim.size(),
        dim.data(),
        stride.data()));
  }
}

template <typename T>
TensorDescriptors<T>::~TensorDescriptors() {
  for (auto desc : descs_) {
    cudnnDestroyTensorDescriptor(desc);
  }
}

} // namespace detail
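
// Note: the cuDNN RNN entry points take an *array* of tensor descriptors, one
// per time step, for the packed sequence input and output (xDesc_ and yDesc_
// below); detail::TensorDescriptors above owns such an array and releases it
// on destruction.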

template <typename T>
RecurrentBaseOp<T>::RecurrentBaseOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
  CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
  CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
  CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cxDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hyDesc_));
  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cyDesc_));
}

template <typename T>
RecurrentBaseOp<T>::~RecurrentBaseOp() {
  CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
  CUDNN_ENFORCE(cudnnDestroyRNNDescriptor(rnnDesc_));
  CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(wDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hxDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cxDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hyDesc_));
  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cyDesc_));
}

template <typename T>
void RecurrentBaseOp<T>::initialize(
    const Tensor<CUDAContext>& input,
    Tensor<CUDAContext>* dropoutStates,
    Tensor<CUDAContext>* output,
    Tensor<CUDAContext>* hiddenOutput,
    Tensor<CUDAContext>* cellOutput) {
  static_assert(sizeof(T) == 4, ""); // workaround clang bug
  CAFFE_ENFORCE_GE(input.ndim(), 3);
  const int seqLength = input.dim(0);
  const int batchSize = input.dim(1);
  const int inputDim = input.dim(2);
  const int hiddenSize = OperatorBase::GetSingleArgument<int>("hidden_size", 0);
  CAFFE_ENFORCE_GT(hiddenSize, 0);
  const auto bidirectional =
      OperatorBase::GetSingleArgument<int>("bidirectional", 0);
  CAFFE_ENFORCE(bidirectional == 0 || bidirectional == 1);
  const auto numDirections = bidirectional == 1 ? 2 : 1;
  const auto outputDim = hiddenSize * numDirections;
  const auto rnnDirection =
      bidirectional == 1 ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
  const auto numLayers = OperatorBase::GetSingleArgument<int>("num_layers", 0);
  CAFFE_ENFORCE_GT(numLayers, 0);
  const auto& rnnModeStr =
      OperatorBase::GetSingleArgument<string>("rnn_mode", "");
  CAFFE_ENFORCE(rnnModeStr == "lstm" || rnnModeStr == "gru");
  const auto rnnMode = rnnModeStr == "lstm" ? CUDNN_LSTM : CUDNN_GRU;
  const auto& rnnInputStr =
      OperatorBase::GetSingleArgument<string>("input_mode", "");
  CAFFE_ENFORCE(rnnInputStr == "linear" || rnnInputStr == "skip");
  const auto rnnInput =
      rnnInputStr == "linear" ? CUDNN_LINEAR_INPUT : CUDNN_SKIP_INPUT;

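  // The "dropout" argument is used as a sentinel: at its default of 1.0 the
  // dropout descriptor is left unconfigured and no dropout is applied between
  // stacked layers, while a value in [0, 1) is passed to cuDNN as the
  // probability of zeroing an activation.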
  // Dropout setup
  {
    if (dropoutStates) {
      size_t stateSize;
      float dropout_param =
          OperatorBase::GetSingleArgument<float>("dropout", 1.0);
      if (dropout_param < 1.0) {
        CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
            cudnn_wrapper_.inline_cudnn_handle(), &stateSize));
        dropoutStates->Resize(std::vector<int>{static_cast<int>(
            stateSize / 4 /* sizeof(T) - workaround clang bug */)});
        CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
            dropoutDesc_,
            cudnn_wrapper_.inline_cudnn_handle(),
            dropout_param,
            dropoutStates->template mutable_data<T>(),
            stateSize,
            OperatorBase::GetSingleArgument<int>("seed", 0)));
      }
    }
  }

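  // cuDNN 7 extended cudnnSetRNNDescriptor with an explicit handle and an
  // algorithm selector, hence the two call signatures below.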
  // RNN setup
  {
#if CUDNN_VERSION_MIN(7, 0, 0)
    CUDNN_ENFORCE(cudnnSetRNNDescriptor(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        hiddenSize,
        numLayers,
        dropoutDesc_,
        rnnInput,
        rnnDirection,
        rnnMode,
        CUDNN_RNN_ALGO_STANDARD, // TODO: verify correctness / efficiency.
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetRNNDescriptor(
        rnnDesc_,
        hiddenSize,
        numLayers,
        dropoutDesc_,
        rnnInput,
        rnnDirection,
        rnnMode,
        cudnnTypeWrapper<T>::type));
#endif
  }
  // X setup
  {
    xDesc_.reset(new detail::TensorDescriptors<T>(
        seqLength,
        // Third dimension is unused
        {batchSize, inputDim, 1},
        // Fully-packed
        {inputDim, 1, 1}));
  }
  // Y setup
  {
    yDesc_.reset(new detail::TensorDescriptors<T>(
        seqLength,
        // Third dimension is unused
        {batchSize, hiddenSize * numDirections, 1},
        // Fully-packed
        {numDirections * hiddenSize, 1, 1}));

    if (output) {
      output->Resize(std::vector<int>{seqLength, batchSize, outputDim});
    }
  }

  // Hidden/Cell setup
  {
    const std::array<int, 3> dim{
        numLayers * numDirections, batchSize, hiddenSize};
    const std::array<int, 3> stride{batchSize * hiddenSize, hiddenSize, 1};
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        hxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        cxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        hyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        cyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));

    if (hiddenOutput) {
      hiddenOutput->Resize(
          std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
    }

    if (cellOutput) {
      cellOutput->Resize(
          std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
    }
  }

  // Weights setup
  {
    size_t weightsSize;
    CUDNN_ENFORCE(cudnnGetRNNParamsSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        xDesc_->descs()[0],
        &weightsSize,
        cudnnTypeWrapper<T>::type));
    const std::array<int, 3> dims{
        static_cast<int>(
            weightsSize / 4 /* sizeof(T) - workaround clang bug */),
        1,
        1};
    CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
        wDesc_, cudnnTypeWrapper<T>::type, CUDNN_TENSOR_NCHW, 3, dims.data()));
  }

  // RNN workspace size
  {
    CUDNN_ENFORCE(cudnnGetRNNWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        seqLength,
        xDesc_->descs(),
        &cudnnWsNbytes_));
  }
}
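
// Worked example of the shapes initialize() produces: with seqLength T = 10,
// batchSize N = 32, inputDim D = 64, hidden_size = 256, num_layers = 2 and
// bidirectional = 1 (so numDirections = 2), the forward op emits
//   OUTPUT        : {10, 32, 512} (T x N x hidden_size * numDirections)
//   HIDDEN_OUTPUT : { 4, 32, 256} (num_layers * numDirections x N x hidden_size)
//   CELL_OUTPUT   : { 4, 32, 256}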

template <typename T>
bool RecurrentOp<T>::RunOnDevice() {
  const int seqLength = Input(INPUT).dim32(0);
  if (Input(INPUT).dims() != cachedInputDims_) {
    initialize(
        Input(INPUT),
        Output(DROPOUT_STATES),
        Output(OUTPUT),
        Output(HIDDEN_OUTPUT),
        Output(CELL_OUTPUT));
    cachedInputDims_ = Input(INPUT).dims();
  }

  // Validation checks
  size_t weightsSize;
  CUDNN_ENFORCE(cudnnGetRNNParamsSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      xDesc_->descs()[0],
      &weightsSize,
      cudnnTypeWrapper<T>::type));
  CAFFE_ENFORCE_EQ(Input(WEIGHT).nbytes(), weightsSize);

  // Training reserve size
  CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      seqLength,
      xDesc_->descs(),
      &reserveNbytes_));
  Output(RNN_SCRATCH)
      ->Resize(std::vector<int>{static_cast<int>(
          reserveNbytes_ / 4)}); // sizeof(T) - workaround clang bug
  Output(RNN_SCRATCH)->template mutable_data<T>();

  auto InputData = [this](int i) { return this->Input(i).template data<T>(); };
  auto OutputData = [this](int i) {
    return this->Output(i)->template mutable_data<T>();
  };

  if (OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
    cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
      CUDNN_ENFORCE(cudnnRNNForwardInference(
          state->cudnn_handle(),
          rnnDesc_,
          seqLength,
          xDesc_->descs(),
          InputData(INPUT),
          hxDesc_,
          InputData(HIDDEN_INPUT),
          cxDesc_,
          InputData(CELL_INPUT),
          wDesc_,
          InputData(WEIGHT),
          yDesc_->descs(),
          OutputData(OUTPUT),
          hyDesc_,
          OutputData(HIDDEN_OUTPUT),
          cyDesc_,
          OutputData(CELL_OUTPUT),
          state->workspace().get(cudnnWsNbytes_),
          cudnnWsNbytes_));
    });
  } else {
    cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
      CUDNN_ENFORCE(cudnnRNNForwardTraining(
          state->cudnn_handle(),
          rnnDesc_,
          seqLength,
          xDesc_->descs(),
          InputData(INPUT),
          hxDesc_,
          InputData(HIDDEN_INPUT),
          cxDesc_,
          InputData(CELL_INPUT),
          wDesc_,
          InputData(WEIGHT),
          yDesc_->descs(),
          OutputData(OUTPUT),
          hyDesc_,
          OutputData(HIDDEN_OUTPUT),
          cyDesc_,
          OutputData(CELL_OUTPUT),
          state->workspace().get(cudnnWsNbytes_),
          cudnnWsNbytes_,
          OutputData(RNN_SCRATCH),
          reserveNbytes_));
    });
  }

  return true;
}

template <typename T>
bool RecurrentGradientOp<T>::RunOnDevice() {
  const int seqLength = Input(INPUT).dim32(0);
  if (Input(INPUT).dims() != cachedInputDims_) {
    initialize(Input(INPUT), Output(DROPOUT_STATES));
    cachedInputDims_ = Input(INPUT).dims();
  }
  CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
      cudnn_wrapper_.inline_cudnn_handle(),
      rnnDesc_,
      seqLength,
      xDesc_->descs(),
      &reserveNbytes_));
  CAFFE_ENFORCE_EQ(reserveNbytes_, Input(RNN_SCRATCH).nbytes());
  Output(GRAD_INPUT)->ResizeLike(Input(INPUT));
  Output(GRAD_HIDDEN_INPUT)->ResizeLike(Input(HIDDEN_INPUT));
  Output(GRAD_CELL_INPUT)->ResizeLike(Input(CELL_INPUT));

  Output(GRAD_WEIGHT)->ResizeLike(Input(WEIGHT));
  math::Set<T, CUDAContext>(
      Output(GRAD_WEIGHT)->size(),
      0.0,
      Output(GRAD_WEIGHT)->template mutable_data<T>(),
      &context_);

#if CUDNN_VERSION_MIN(6, 0, 0)
  auto* reserve = Output(RNN_SCRATCH_OUT)->template mutable_data<T>();
#else
  const auto* reserve = Output(RNN_SCRATCH_OUT)->template data<T>();
#endif
  auto InputData = [this](int i) { return this->Input(i).template data<T>(); };
  auto OutputData = [this](int i) {
    return this->Output(i)->template mutable_data<T>();
  };

  cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnRNNBackwardData(
        state->cudnn_handle(),
        rnnDesc_,
        seqLength,
        yDesc_->descs(),
        InputData(OUTPUT),
        yDesc_->descs(),
        InputData(GRAD_OUTPUT),
        hyDesc_,
        // Note: like CNTK, ignore these gradient inputs. t16675365 to
        // reconsider.
        nullptr,
        cyDesc_,
        nullptr,
        wDesc_,
        InputData(WEIGHT),
        hxDesc_,
        InputData(HIDDEN_INPUT),
        cxDesc_,
        InputData(CELL_INPUT),
        xDesc_->descs(),
        OutputData(GRAD_INPUT),
        hxDesc_,
        OutputData(GRAD_HIDDEN_INPUT),
        cxDesc_,
        OutputData(GRAD_CELL_INPUT),
        state->workspace().get(cudnnWsNbytes_),
        cudnnWsNbytes_,
        reserve,
        reserveNbytes_));
    CUDNN_ENFORCE(cudnnRNNBackwardWeights(
        state->cudnn_handle(),
        rnnDesc_,
        seqLength,
        xDesc_->descs(),
        InputData(INPUT),
        hxDesc_,
        InputData(HIDDEN_INPUT),
        yDesc_->descs(),
        InputData(OUTPUT),
        state->workspace().get(cudnnWsNbytes_),
        cudnnWsNbytes_,
        wDesc_,
        OutputData(GRAD_WEIGHT),
        reserve,
        reserveNbytes_));
  });

  return true;
}

template <typename T, RecurrentParamOpMode mode>
bool RecurrentParamAccessOp<T, mode>::RunOnDevice() {
  initialize(Input(0));

  if (mode == SET_PARAM) {
    size_t paramsSize;
    CUDNN_ENFORCE(cudnnGetRNNParamsSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        xDesc_->descs()[0],
        &paramsSize,
        cudnnTypeWrapper<T>::type));

    CAFFE_ENFORCE_EQ(
        paramsSize / 4, Input(1).size(), "Incorrect weight initialization");
  }

  int layer = OperatorBase::GetSingleArgument<int>("layer", 0);
  std::string param_type =
      OperatorBase::GetSingleArgument<string>("param_type", "");
  std::string input_type =
      OperatorBase::GetSingleArgument<string>("input_type", "");

  // Mapping to CUDNN constants
  std::map<string, int> weight_constants = {{"input_gate_w", 0},
                                            {"forget_gate_w", 1},
                                            {"cell_w", 2},
                                            {"output_gate_w", 3}};
  std::map<string, int> bias_constants = {{"input_gate_b", 0},
                                          {"forget_gate_b", 1},
                                          {"cell_b", 2},
                                          {"output_gate_b", 3}};
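  // cuDNN's LSTM linear-layer ids 0-3 address the input-to-hidden parameters
  // for the input, forget, cell and output gates; ids 4-7 are the matching
  // hidden-to-hidden ("recurrent") ones. That is why param_id below adds 4
  // when input_type == "recurrent".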
  if (bias_constants.find(param_type) != bias_constants.end()) {
    int param_id = bias_constants[param_type] + 4 * (input_type == "recurrent");

    cudnnFilterDescriptor_t biasDesc;
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&biasDesc));
    void* bias;

    CUDNN_ENFORCE(cudnnGetRNNLinLayerBiasParams(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        layer,
        xDesc_->descs()[0],
        wDesc_,
        Input(1).template data<T>(),
        param_id, // bias selected by param_type/input_type
        biasDesc,
        &bias));
    int numBiasDims;
    std::vector<int> biasDims(3);
    cudnnDataType_t dt;
    cudnnTensorFormat_t tf;
    // For some reason, the CuDNN Bias tensor is 3 dimensional
    CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
        biasDesc, 3, &dt, &tf, &numBiasDims, biasDims.data()));
    CAFFE_ENFORCE_EQ(numBiasDims, 3);

    if (mode == SET_PARAM) {
      CAFFE_ENFORCE_EQ(
          biasDims[0] * biasDims[1] * biasDims[2], Input(2).size());
      context_.template Copy<T, CUDAContext, CUDAContext>(
          biasDims[0] * biasDims[1] * biasDims[2],
          Input(2).template data<T>(),
          static_cast<T*>(bias));
    } else {
      Output(0)->Resize(biasDims);
      context_.template Copy<T, CUDAContext, CUDAContext>(
          biasDims[0] * biasDims[1] * biasDims[2],
          static_cast<T*>(bias),
          Output(0)->template mutable_data<T>());
    }
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(biasDesc));
  } else if (weight_constants.find(param_type) != weight_constants.end()) {
    int param_id =
        weight_constants[param_type] + 4 * (input_type == "recurrent");
    cudnnFilterDescriptor_t matrixParamDesc;
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&matrixParamDesc));
    void* pmatrix;
    CUDNN_ENFORCE(cudnnGetRNNLinLayerMatrixParams(
        cudnn_wrapper_.inline_cudnn_handle(),
        rnnDesc_,
        layer,
        xDesc_->descs()[0],
        wDesc_,
        Input(1).template data<T>(),
        param_id, // weight matrix selected by param_type/input_type
        matrixParamDesc,
        &pmatrix));
    int numDims;
    std::vector<int> matDims(3);
    cudnnDataType_t dt;
    cudnnTensorFormat_t tf;

    CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
        matrixParamDesc, 3, &dt, &tf, &numDims, matDims.data()));
    CAFFE_ENFORCE_EQ(numDims, 3);
    if (mode == SET_PARAM) {
      CAFFE_ENFORCE_EQ(matDims[0] * matDims[1] * matDims[2], Input(2).size());
      context_.template Copy<T, CUDAContext, CUDAContext>(
          matDims[0] * matDims[1] * matDims[2],
          Input(2).template data<T>(),
          static_cast<T*>(pmatrix));
    } else {
      Output(0)->Resize(matDims);
      context_.template Copy<T, CUDAContext, CUDAContext>(
          matDims[0] * matDims[1] * matDims[2],
          static_cast<T*>(pmatrix),
          Output(0)->template mutable_data<T>());
    }
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(matrixParamDesc));
  } else {
    CAFFE_ENFORCE(false, "Unknown param type:", param_type);
  }

  return true;
}

REGISTER_CUDNN_OPERATOR(Recurrent, RecurrentOp<float>);
OPERATOR_SCHEMA(Recurrent).NumInputs(4).NumOutputs(5).SetDoc(R"DOC(

Recurrent wraps the CuDNN R5 RNN implementation. See the CuDNN R5
documentation for more information.

In general, the implementation takes an input (TxNxD) tensor, the
hidden state input (NxD), the cell input (NxD), and a weight tensor
(effectively an opaque blob, whose size and layout are dictated by
CuDNN).

The outputs are the output (again, TxNxD) and the final hidden and
cell states (each NxD). These can be reset (at sequence boundaries
across minibatches) by multiplying by zero.

The CuDNN arguments (hidden_size, bidirectional, num_layers, rnn_mode,
input_mode) are passed directly through to CuDNN.

)DOC");
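
// Usage sketch (illustrative, not part of the original file): assembling a
// Recurrent OperatorDef in C++. The blob names are hypothetical, and
// MakeArgument is assumed from caffe2/utils/proto_utils.h; in practice the op
// is usually assembled from Python.
//
//   OperatorDef def;
//   def.set_type("Recurrent");
//   def.set_engine("CUDNN");
//   for (const char* in : {"input", "hidden_init", "cell_init", "weights"}) {
//     def.add_input(in);
//   }
//   for (const char* out : {"output", "hidden_out", "cell_out",
//                           "rnn_scratch", "dropout_states"}) {
//     def.add_output(out);
//   }
//   def.add_arg()->CopyFrom(MakeArgument<int>("hidden_size", 256));
//   def.add_arg()->CopyFrom(MakeArgument<int>("num_layers", 2));
//   def.add_arg()->CopyFrom(MakeArgument<int>("bidirectional", 0));
//   def.add_arg()->CopyFrom(MakeArgument<string>("rnn_mode", "lstm"));
//   def.add_arg()->CopyFrom(MakeArgument<string>("input_mode", "linear"));
//   def.add_arg()->CopyFrom(MakeArgument<float>("dropout", 1.0f));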

REGISTER_CUDNN_OPERATOR(RecurrentGradient, RecurrentGradientOp<float>);
OPERATOR_SCHEMA(RecurrentGradient)
    .NumInputs(7)
    .NumOutputs(6)
    .AllowInplace({{4, 5}});

REGISTER_CUDNN_OPERATOR(
    RecurrentParamSet,
    RecurrentParamAccessOp<float, SET_PARAM>);
OPERATOR_SCHEMA(RecurrentParamSet)
    .NumInputs(3)
    .NumOutputs(1)
    .EnforceInplace({{1, 0}})
    .SetDoc("Set individual parameters of a recurrent net.")
    .Arg("param_type", R"DOC(Type of param to be set:
  "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w",
  "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b"
  )DOC")
    .Arg("input_type", "'recurrent' or 'input'")
    .Arg("layer", "layer index (starting from 0)")
    .Input(0, "input", R"DOC(Input blob. Needed for inferring the shapes.
  A dummy tensor matching the input shape is ok.)DOC")
    .Input(1, "all_params", "Blob holding all the parameters")
    .Input(2, "param", "Values for the specified parameter")
    .Output(
        0,
        "all_params",
        "Blob holding all the parameters (same as input(1))");
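
// Sketch (hypothetical blob names): overwriting one layer's forget-gate bias
// inside the packed cuDNN weight blob.
//
//   OperatorDef set_def;
//   set_def.set_type("RecurrentParamSet");
//   set_def.set_engine("CUDNN");
//   set_def.add_input("input");    // dummy input, used only to infer shapes
//   set_def.add_input("weights");  // the packed cuDNN parameter blob
//   set_def.add_input("new_bias"); // values to write
//   set_def.add_output("weights"); // in-place, per EnforceInplace({{1, 0}})
//   set_def.add_arg()->CopyFrom(
//       MakeArgument<string>("param_type", "forget_gate_b"));
//   set_def.add_arg()->CopyFrom(MakeArgument<string>("input_type", "recurrent"));
//   set_def.add_arg()->CopyFrom(MakeArgument<int>("layer", 0));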

REGISTER_CUDNN_OPERATOR(
    RecurrentParamGet,
    RecurrentParamAccessOp<float, GET_PARAM>);
OPERATOR_SCHEMA(RecurrentParamGet)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc("Retrieve individual parameters of a recurrent net op.")
    .Arg("param_type", R"DOC(Type of param to be retrieved:
  "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w",
  "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b"
  )DOC")
    .Arg("input_type", "'recurrent' or 'input'")
    .Arg("layer", "layer index (starting from 0)")
    .Input(0, "input", R"DOC(Input blob. Needed for inferring the shapes.
  A dummy tensor matching the input shape is ok.)DOC")
    .Input(1, "all_params", "Blob holding all the parameters")
    .Output(0, "param", "Blob holding the requested values");
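
// The read direction mirrors the sketch above: RecurrentParamGet takes the
// dummy input and the packed blob, and writes the requested slice to its
// single output.
//
//   OperatorDef get_def;
//   get_def.set_type("RecurrentParamGet");
//   get_def.set_engine("CUDNN");
//   get_def.add_input("input");
//   get_def.add_input("weights");
//   get_def.add_output("forget_gate_b_out"); // hypothetical name
//   get_def.add_arg()->CopyFrom(
//       MakeArgument<string>("param_type", "forget_gate_b"));
//   get_def.add_arg()->CopyFrom(MakeArgument<string>("input_type", "recurrent"));
//   get_def.add_arg()->CopyFrom(MakeArgument<int>("layer", 0));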

struct GetRecurrentGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "RecurrentGradient",
        "",
        vector<string>{I(0), // INPUT
                       I(1), // HIDDEN_INPUT
                       I(2), // CELL_INPUT
                       I(3), // WEIGHT
                       O(3), // RNN_SCRATCH
                       O(0), // OUTPUT
                       GO(0)}, // GRAD_OUTPUT
        // TODO: not currently using these gradients, investigate t16675365
        // GO(1), // GRAD_HIDDEN_OUTPUT
        // GO(2)}, // GRAD_CELL_OUTPUT
        vector<string>{
            GI(0), // GRAD_INPUT
            GI(1), // GRAD_HIDDEN_INPUT
            GI(2), // GRAD_CELL_INPUT
            GI(3), // GRAD_WEIGHT
            O(4), // DROPOUT_STATES
            O(3) // RNN_SCRATCH
        });
  }
};
REGISTER_GRADIENT(Recurrent, GetRecurrentGradient);

} // namespace caffe2