// NOTE(review): this extraction is missing interior source lines (the embedded
// original line numbers jump, e.g. 11 -> 13 and 19 -> 21); code below is kept
// byte-for-byte, comments only.
// Constructor: builds one cuDNN tensor descriptor per element. `dim` and
// `stride` give the per-descriptor layout passed to
// cudnnSetTensorNdDescriptor; their sizes must match (enforced below).
// Assumes `n` is a count parameter declared on a dropped line and that
// `descs_` was sized accordingly — TODO confirm against the header.
1 #include "caffe2/operators/rnn/recurrent_op_cudnn.h" 2 #include "caffe2/utils/math.h" 11 TensorDescriptors<T>::TensorDescriptors(
13 const std::vector<int>& dim,
14 const std::vector<int>& stride) {
// Layout sanity check: every dimension needs a matching stride.
16 CAFFE_ENFORCE_EQ(dim.size(), stride.size());
17 for (
auto i = 0; i < n; ++i) {
// Create then configure each descriptor; both calls are checked via
// CUDNN_ENFORCE so a cuDNN failure aborts loudly.
18 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&descs_[i]));
19 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
21 cudnnTypeWrapper<T>::type,
// Destructor: releases every descriptor created by the constructor.
// Intentionally NOT wrapped in CUDNN_ENFORCE — destructors must not throw,
// so the return status of cudnnDestroyTensorDescriptor is ignored.
// (Closing braces are on lines dropped by the extraction.)
29 TensorDescriptors<T>::~TensorDescriptors() {
30 for (
auto desc : descs_) {
31 cudnnDestroyTensorDescriptor(desc);
// Base-op constructor: allocates all cuDNN descriptors used across the RNN
// forward/backward ops — dropout, the RNN itself, the packed weight filter,
// and the four hidden/cell state tensors (hx/cx inputs, hy/cy outputs).
// The descriptors are only *created* here; they are configured later in
// initialize() once the input shape is known.
37 RecurrentBaseOp<T>::RecurrentBaseOp(
38 const OperatorDef& operator_def,
40 : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
41 CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
42 CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
43 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
44 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
45 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cxDesc_));
46 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hyDesc_));
47 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&cyDesc_));
// Base-op destructor: mirrors the constructor, destroying each descriptor
// in the same order it was created.
// NOTE(review): unlike ~TensorDescriptors above, these destroy calls ARE
// wrapped in CUDNN_ENFORCE, so a cuDNN error here can throw from a
// destructor — inconsistent, but preserved as-is.
51 RecurrentBaseOp<T>::~RecurrentBaseOp() {
52 CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
53 CUDNN_ENFORCE(cudnnDestroyRNNDescriptor(rnnDesc_));
54 CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(wDesc_));
55 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hxDesc_));
56 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cxDesc_));
57 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hyDesc_));
58 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(cyDesc_));
// Configures every cuDNN descriptor for the current input shape, resizes the
// output blobs, and computes the workspace size. Called lazily from
// RunOnDevice() whenever the cached input dims change.
// NOTE(review): many interior lines are missing from this extraction (the
// embedded source line numbers jump, e.g. 92 -> 99, 111 -> 118); code is
// preserved byte-for-byte. Hedged comments below mark what cannot be
// confirmed from the visible text.
62 void RecurrentBaseOp<T>::initialize(
// Only 4-byte element types (float) are supported by this code path.
68 static_assert(
sizeof(T) == 4,
"");
// Input is laid out (seqLength, batchSize, inputDim).
69 CAFFE_ENFORCE_GE(input.ndim(), 3);
70 const int seqLength = input.dim(0);
71 const int batchSize = input.dim(1);
72 const int inputDim = input.dim(2);
// hidden_size is mandatory (>0); no usable default.
73 const int hiddenSize = OperatorBase::GetSingleArgument<int>(
"hidden_size", 0);
74 CAFFE_ENFORCE_GT(hiddenSize, 0);
// bidirectional doubles the number of directions and hence the output width.
75 const auto bidirectional =
76 OperatorBase::GetSingleArgument<int>(
"bidirectional", 0);
77 CAFFE_ENFORCE(bidirectional == 0 || bidirectional == 1);
78 const auto numDirections = bidirectional == 1 ? 2 : 1;
79 const auto outputDim = hiddenSize * numDirections;
80 const auto rnnDirection =
81 bidirectional == 1 ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
82 const auto numLayers = OperatorBase::GetSingleArgument<int>(
"num_layers", 0);
83 CAFFE_ENFORCE_GT(numLayers, 0);
// Only "lstm" and "gru" cell types are accepted.
84 const auto& rnnModeStr =
85 OperatorBase::GetSingleArgument<string>(
"rnn_mode",
"");
86 CAFFE_ENFORCE(rnnModeStr ==
"lstm" || rnnModeStr ==
"gru");
87 const auto rnnMode = rnnModeStr ==
"lstm" ? CUDNN_LSTM : CUDNN_GRU;
// input_mode: "linear" applies an input projection, "skip" feeds the input
// straight through (cuDNN CUDNN_LINEAR_INPUT / CUDNN_SKIP_INPUT).
88 const auto& rnnInputStr =
89 OperatorBase::GetSingleArgument<string>(
"input_mode",
"");
90 CAFFE_ENFORCE(rnnInputStr ==
"linear" || rnnInputStr ==
"skip");
92 rnnInputStr ==
"linear" ? CUDNN_LINEAR_INPUT : CUDNN_SKIP_INPUT;
// NOTE(review): "dropout" here appears to be a KEEP probability (default 1.0
// = no dropout; states only allocated when < 1.0) — unusual naming, confirm
// against the operator schema.
99 OperatorBase::GetSingleArgument<float>(
"dropout", 1.0);
100 if (dropout_param < 1.0) {
101 CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
102 cudnn_wrapper_.inline_cudnn_handle(), &stateSize));
103 dropoutStates->Resize(std::vector<int>{
static_cast<int>(
105 CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
107 cudnn_wrapper_.inline_cudnn_handle(),
109 dropoutStates->template mutable_data<T>(),
111 OperatorBase::GetSingleArgument<int>(
"seed", 0)));
// cuDNN 7 requires the v7 cudnnSetRNNDescriptor signature (handle + algo);
// the #else branch below uses the older signature.
118 #if CUDNN_VERSION_MIN(7, 0, 0) 119 CUDNN_ENFORCE(cudnnSetRNNDescriptor(
120 cudnn_wrapper_.inline_cudnn_handle(),
128 CUDNN_RNN_ALGO_STANDARD,
129 cudnnTypeWrapper<T>::type));
131 CUDNN_ENFORCE(cudnnSetRNNDescriptor(
139 cudnnTypeWrapper<T>::type));
// Per-timestep input/output descriptors (one per sequence step).
144 xDesc_.reset(
new detail::TensorDescriptors<T>(
147 {batchSize, inputDim, 1},
153 yDesc_.reset(
new detail::TensorDescriptors<T>(
156 {batchSize, hiddenSize * numDirections, 1},
158 {numDirections * hiddenSize, 1, 1}));
161 output->Resize(std::vector<int>{seqLength, batchSize, outputDim});
// Hidden/cell state descriptors share one (layers*dirs, batch, hidden)
// shape and a fully-packed stride.
167 const std::array<int, 3> dim{
168 numLayers * numDirections, batchSize, hiddenSize};
169 const std::array<int, 3> stride{batchSize * hiddenSize, hiddenSize, 1};
170 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
171 hxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
172 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
173 cxDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
174 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
175 hyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
176 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
177 cyDesc_, cudnnTypeWrapper<T>::type, 3, dim.data(), stride.data()));
180 hiddenOutput->Resize(
181 std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
186 std::vector<int>{numLayers * numDirections, batchSize, hiddenSize});
// Ask cuDNN how large the packed weight blob must be, then describe it as a
// flat NCHW filter for wDesc_.
193 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
194 cudnn_wrapper_.inline_cudnn_handle(),
198 cudnnTypeWrapper<T>::type));
199 const std::array<int, 3> dims{
204 CUDNN_ENFORCE(cudnnSetFilterNdDescriptor(
205 wDesc_, cudnnTypeWrapper<T>::type, CUDNN_TENSOR_NCHW, 3, dims.data()));
// Finally query the workspace bytes needed by forward/backward passes
// (presumably stored into cudnnWsNbytes_ on a dropped line — TODO confirm).
210 CUDNN_ENFORCE(cudnnGetRNNWorkspaceSize(
211 cudnn_wrapper_.inline_cudnn_handle(),
// Forward pass. Re-initializes descriptors only when the input shape changes,
// validates the weight blob size against cuDNN's expectation, then dispatches
// to cudnnRNNForwardInference (test mode) or cudnnRNNForwardTraining
// (training mode, which additionally fills the RNN_SCRATCH reserve space
// consumed later by the gradient op).
// NOTE(review): several argument lines of the cuDNN calls were dropped by the
// extraction; code preserved byte-for-byte.
219 template <
typename T>
220 bool RecurrentOp<T>::RunOnDevice() {
221 const int seqLength = Input(INPUT).dim32(0);
// Lazy (re)initialization keyed on the cached input dims.
222 if (Input(INPUT).dims() != cachedInputDims_) {
225 Output(DROPOUT_STATES),
227 Output(HIDDEN_OUTPUT),
228 Output(CELL_OUTPUT));
229 cachedInputDims_ = Input(INPUT).dims();
// The packed WEIGHT input must be exactly the size cuDNN expects.
234 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
235 cudnn_wrapper_.inline_cudnn_handle(),
239 cudnnTypeWrapper<T>::type));
240 CAFFE_ENFORCE_EQ(Input(WEIGHT).nbytes(), weightsSize);
// Reserve space sized by cuDNN; divided by 4 because T is 4 bytes wide
// (see the static_assert in initialize()).
243 CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
244 cudnn_wrapper_.inline_cudnn_handle(),
250 ->Resize(std::vector<int>{
static_cast<int>(
251 reserveNbytes_ / 4)});
252 Output(RNN_SCRATCH)->template mutable_data<T>();
// Small helpers to fetch typed device pointers for inputs/outputs by index.
254 auto InputData = [
this](
int i) {
return this->Input(i).template data<T>(); };
255 auto OutputData = [
this](
int i) {
256 return this->Output(i)->template mutable_data<T>();
// Test mode (Arg_IsTest): inference path, no reserve space written.
259 if (OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
260 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
261 CUDNN_ENFORCE(cudnnRNNForwardInference(
262 state->cudnn_handle(),
268 InputData(HIDDEN_INPUT),
270 InputData(CELL_INPUT),
276 OutputData(HIDDEN_OUTPUT),
278 OutputData(CELL_OUTPUT),
279 state->workspace().get(cudnnWsNbytes_),
// Training path: same call shape, plus the RNN_SCRATCH reserve buffer.
283 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
284 CUDNN_ENFORCE(cudnnRNNForwardTraining(
285 state->cudnn_handle(),
291 InputData(HIDDEN_INPUT),
293 InputData(CELL_INPUT),
299 OutputData(HIDDEN_OUTPUT),
301 OutputData(CELL_OUTPUT),
302 state->workspace().get(cudnnWsNbytes_),
304 OutputData(RNN_SCRATCH),
// Backward pass: cudnnRNNBackwardData for input/state gradients followed by
// cudnnRNNBackwardWeights for the weight gradient. Consumes the RNN_SCRATCH
// reserve buffer produced by the forward training pass (its size is
// re-derived and cross-checked below).
// NOTE(review): interior argument lines of the cuDNN calls are missing from
// this extraction; code preserved byte-for-byte.
312 template <
typename T>
313 bool RecurrentGradientOp<T>::RunOnDevice() {
314 const int seqLength = Input(INPUT).dim32(0);
// Same lazy-init-on-shape-change pattern as the forward op.
315 if (Input(INPUT).dims() != cachedInputDims_) {
316 initialize(Input(INPUT), Output(DROPOUT_STATES));
317 cachedInputDims_ = Input(INPUT).dims();
// The scratch blob handed to us must match the reserve size cuDNN reports
// for this configuration — catches forward/backward shape mismatches.
319 CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
320 cudnn_wrapper_.inline_cudnn_handle(),
325 CAFFE_ENFORCE_EQ(reserveNbytes_, Input(RNN_SCRATCH).nbytes());
326 Output(GRAD_INPUT)->ResizeLike(Input(INPUT));
327 Output(GRAD_HIDDEN_INPUT)->ResizeLike(Input(HIDDEN_INPUT));
328 Output(GRAD_CELL_INPUT)->ResizeLike(Input(CELL_INPUT));
// cudnnRNNBackwardWeights ACCUMULATES into the gradient, so it must be
// zeroed first (math::Set below).
330 Output(GRAD_WEIGHT)->ResizeLike(Input(WEIGHT));
331 math::Set<T, CUDAContext>(
332 Output(GRAD_WEIGHT)->size(),
334 Output(GRAD_WEIGHT)->template mutable_data<T>(),
// cuDNN 6+ may write to the reserve buffer during backward, hence the
// mutable vs const data pointer split.
337 #if CUDNN_VERSION_MIN(6,0,0) 338 auto * reserve = Output(RNN_SCRATCH_OUT)->template mutable_data<T>();
340 const auto * reserve = Output(RNN_SCRATCH_OUT)->template data<T>();
342 auto InputData = [
this](
int i) {
return this->Input(i).template data<T>(); };
343 auto OutputData = [
this](
int i) {
344 return this->Output(i)->template mutable_data<T>();
347 cudnn_wrapper_.with_cudnn_state(0, [&](CuDNNState* state) {
// Data gradients first...
348 CUDNN_ENFORCE(cudnnRNNBackwardData(
349 state->cudnn_handle(),
355 InputData(GRAD_OUTPUT),
365 InputData(HIDDEN_INPUT),
367 InputData(CELL_INPUT),
369 OutputData(GRAD_INPUT),
371 OutputData(GRAD_HIDDEN_INPUT),
373 OutputData(GRAD_CELL_INPUT),
374 state->workspace().get(cudnnWsNbytes_),
// ...then weight gradients, reusing the same cuDNN state/workspace.
378 CUDNN_ENFORCE(cudnnRNNBackwardWeights(
379 state->cudnn_handle(),
385 InputData(HIDDEN_INPUT),
388 state->workspace().get(cudnnWsNbytes_),
// Get/set an individual gate parameter (weight matrix or bias vector) inside
// cuDNN's packed weight blob, addressed by layer index, param_type
// ("*_gate_w" / "*_gate_b") and input_type ("input" vs "recurrent").
// Template mode selects SET_PARAM (copy Input(2) into the blob) or GET_PARAM
// (copy the slice out into Output(0)).
// NOTE(review): interior lines are missing from this extraction (e.g. the
// "cell_w"/"cell_b" map entries and several cuDNN call arguments); code is
// preserved byte-for-byte.
400 template <
typename T, RecurrentParamOpMode mode>
401 bool RecurrentParamAccessOp<T, mode>::RunOnDevice() {
// Input(0) is a shape-only dummy used to configure the descriptors.
402 initialize(Input(0));
404 if (mode == SET_PARAM) {
// When setting, the packed weight blob (Input(1)) must already have the
// full size cuDNN expects (paramsSize is in bytes; /4 because sizeof(T)==4).
406 CUDNN_ENFORCE(cudnnGetRNNParamsSize(
407 cudnn_wrapper_.inline_cudnn_handle(),
411 cudnnTypeWrapper<T>::type));
414 paramsSize / 4, Input(1).size(),
"Incorrect weight initialization");
417 int layer = OperatorBase::GetSingleArgument<int>(
"layer", 0);
418 std::string param_type =
419 OperatorBase::GetSingleArgument<string>(
"param_type",
"");
420 std::string input_type =
421 OperatorBase::GetSingleArgument<string>(
"input_type",
"");
// cuDNN linear-layer IDs: 0..3 for input-side gates, +4 for the recurrent
// side. (The "cell_*" -> 2 entries sit on dropped lines — TODO confirm.)
424 std::map<string, int> weight_constants = {{
"input_gate_w", 0},
425 {
"forget_gate_w", 1},
427 {
"output_gate_w", 3}};
428 std::map<string, int> bias_constants = {{
"input_gate_b", 0},
429 {
"forget_gate_b", 1},
431 {
"output_gate_b", 3}};
// --- Bias path ---
432 if (bias_constants.find(param_type) != bias_constants.end()) {
433 int param_id = bias_constants[param_type] + 4 * (input_type ==
"recurrent");
435 cudnnFilterDescriptor_t biasDesc;
436 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&biasDesc));
// Ask cuDNN for a pointer (`bias`, declared on a dropped line) into the
// packed blob plus a descriptor for the slice's shape.
439 CUDNN_ENFORCE(cudnnGetRNNLinLayerBiasParams(
440 cudnn_wrapper_.inline_cudnn_handle(),
445 Input(1).template data<T>(),
450 std::vector<int> biasDims(3);
452 cudnnTensorFormat_t tf;
454 CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
455 biasDesc, 3, &dt, &tf, &numBiasDims, biasDims.data()));
456 CAFFE_ENFORCE_EQ(numBiasDims, 3);
458 if (mode == SET_PARAM) {
// SET: device-to-device copy of Input(2) into the slice.
460 biasDims[0] * biasDims[1] * biasDims[2], Input(2).size());
461 context_.template Copy<T, CUDAContext, CUDAContext>(
462 biasDims[0] * biasDims[1] * biasDims[2],
463 Input(2).template data<T>(),
464 static_cast<T*>(bias));
// GET: copy the slice out into Output(0).
466 Output(0)->Resize(biasDims);
467 context_.template Copy<T, CUDAContext, CUDAContext>(
468 biasDims[0] * biasDims[1] * biasDims[2],
469 static_cast<T*
>(bias),
470 Output(0)->template mutable_data<T>());
472 }
// --- Weight-matrix path: same structure via GetRNNLinLayerMatrixParams ---
else if (weight_constants.find(param_type) != weight_constants.end()) {
474 weight_constants[param_type] + 4 * (input_type ==
"recurrent");
475 cudnnFilterDescriptor_t matrixParamDesc;
476 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&matrixParamDesc));
478 CUDNN_ENFORCE(cudnnGetRNNLinLayerMatrixParams(
479 cudnn_wrapper_.inline_cudnn_handle(),
484 Input(1).template data<T>(),
489 std::vector<int> matDims(3);
491 cudnnTensorFormat_t tf;
493 CUDNN_ENFORCE(cudnnGetFilterNdDescriptor(
494 matrixParamDesc, 3, &dt, &tf, &numDims, matDims.data()));
495 CAFFE_ENFORCE_EQ(numDims, 3);
496 if (mode == SET_PARAM) {
497 CAFFE_ENFORCE_EQ(matDims[0] * matDims[1] * matDims[2], Input(2).size());
498 context_.template Copy<T, CUDAContext, CUDAContext>(
499 matDims[0] * matDims[1] * matDims[2],
500 Input(2).template data<T>(),
501 static_cast<T*>(pmatrix));
503 Output(0)->Resize(matDims);
504 context_.template Copy<T, CUDAContext, CUDAContext>(
505 matDims[0] * matDims[1] * matDims[2],
506 static_cast<T*
>(pmatrix),
507 Output(0)->template mutable_data<T>());
// Unknown param_type: hard failure with the offending name.
510 CAFFE_ENFORCE(
false,
"Unknown param type:", param_type);
// Operator registrations and schemas. Comments are inserted ONLY at points
// provably outside the R"DOC(...)DOC" raw-string literals; everything else is
// byte-for-byte (several schema lines were dropped by the extraction).
// Recurrent: 4 inputs (input, hidden in, cell in, packed weights),
// 5 outputs — the doc string below describes the TxNxD layout.
516 REGISTER_CUDNN_OPERATOR(Recurrent, RecurrentOp<float>);
517 OPERATOR_SCHEMA(Recurrent).NumInputs(4).NumOutputs(5).SetDoc(R
"DOC( 519 Recurrent wraps the CuDNN R5 RNN implementation. See the CuDNN R5 520 documentation for more information. 522 In general, the implementation takes an input (TxNxD) tensor, the 523 hidden state input (NxD), the cell input (NxD), and a weight tensor 524 (effectively an opaque blob, where the size and layout is dictated by 527 The outputs are the output (again, TxNxD), the final hidden/cell 528 states (NxD). These can be reset (at sequence boundaries across 529 minibatches) by multiplying by zero. 531 The CuDNN arguments (hidden_size, bidirectional, num_layers, rnn_mode, 532 input_mode) are passed directly through to CuDNN. 535 REGISTER_CUDNN_OPERATOR(RecurrentGradient, RecurrentGradientOp<float>); 536 OPERATOR_SCHEMA(RecurrentGradient) 539 .AllowInplace({{4, 5}}); 541 REGISTER_CUDNN_OPERATOR( 543 RecurrentParamAccessOp<float, SET_PARAM>); 544 OPERATOR_SCHEMA(RecurrentParamSet) 547 .EnforceInplace({{1, 0}}) 548 .SetDoc("Set individual parameters of a recurrent net.")
549 .Arg(
"param_type", R
"DOC(Type of param to be set: 550 "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w" 551 "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b" 553 .Arg("input_type",
"'recurrent' or 'input'")
554 .Arg(
"layer",
"layer index (starting from 0)")
555 .Input(0,
"input", R
"DOC(Input blob. Needed for inferring the shapes. 556 A dummy tensor matching the input shape is ok.)DOC") 557 .Input(1, "all_params",
"Blob holding all the parameters")
558 .Input(2,
"param",
"Values for the specified parameter")
562 "Blob holding all the parameters (same as input(1))");
// RecurrentParamGet: read-only counterpart of RecurrentParamSet.
564 REGISTER_CUDNN_OPERATOR(
566 RecurrentParamAccessOp<float, GET_PARAM>);
567 OPERATOR_SCHEMA(RecurrentParamGet)
570 .SetDoc(
"Retrieve individual parameters of a recurrent net op.")
571 .Arg(
"param_type", R
"DOC(Type of param to be set: 572 "input_gate_w", "forget_gate_w", "cell_w", "output_gate_w" 573 "input_gate_b", "forget_gate_b", "cell_b", "output_gate_b" 575 .Arg("input_type",
"'recurrent' or 'input'")
576 .Arg(
"layer",
"layer index (starting from 0)")
577 .Input(0,
"input", R
"DOC(Input blob. Needed for inferring the shapes. 578 A dummy tensor matching the input shape is ok.)DOC") 579 .Input(1, "all_params",
"Blob holding all the parameters")
580 .Output(0,
"param",
"Blob holding the requested values");
583 using GradientMakerBase::GradientMakerBase;
584 vector<OperatorDef> GetGradientDefs()
override {
585 return SingleGradientDef(
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...