1 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/context_gpu.h" 6 #include "caffe2/core/cudnn_wrappers.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 18 const std::vector<int>& dim,
19 const std::vector<int>& stride);
21 const cudnnTensorDescriptor_t* descs()
const {
26 std::vector<cudnnTensorDescriptor_t> descs_;
48 cudnnDropoutDescriptor_t dropoutDesc_;
49 cudnnRNNDescriptor_t rnnDesc_;
50 cudnnFilterDescriptor_t wDesc_;
51 cudnnTensorDescriptor_t hxDesc_;
52 cudnnTensorDescriptor_t cxDesc_;
53 cudnnTensorDescriptor_t hyDesc_;
54 cudnnTensorDescriptor_t cyDesc_;
56 std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
57 std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
59 std::vector<TIndex> cachedInputDims_;
60 size_t reserveNbytes_;
61 size_t cudnnWsNbytes_;
66 #define USE_RECURRENT_BASE_FUNCTIONS \ 67 USE_OPERATOR_FUNCTIONS(CUDAContext); \ 68 using RecurrentBaseOp<T>::cudnn_wrapper_; \ 69 using RecurrentBaseOp<T>::dropoutDesc_; \ 70 using RecurrentBaseOp<T>::rnnDesc_; \ 71 using RecurrentBaseOp<T>::wDesc_; \ 72 using RecurrentBaseOp<T>::hxDesc_; \ 73 using RecurrentBaseOp<T>::cxDesc_; \ 74 using RecurrentBaseOp<T>::hyDesc_; \ 75 using RecurrentBaseOp<T>::cyDesc_; \ 76 using RecurrentBaseOp<T>::xDesc_; \ 77 using RecurrentBaseOp<T>::yDesc_; \ 78 using RecurrentBaseOp<T>::cachedInputDims_; \ 79 using RecurrentBaseOp<T>::reserveNbytes_; \ 80 using RecurrentBaseOp<T>::cudnnWsNbytes_; \ 81 using RecurrentBaseOp<T>::initialize; 86 USE_RECURRENT_BASE_FUNCTIONS
90 bool RunOnDevice()
override;
93 INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
94 OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
// Direction of a recurrent-parameter access operation. Used as the
// non-type template parameter `mode` of the templated operator declared
// right after this enum. The enumerators keep the same values the compiler
// would assign implicitly (0 and 1), and the enum stays unscoped so that
// unqualified uses such as `SET_PARAM` keep compiling.
// NOTE(review): presumably SET_PARAM writes a parameter into the packed
// cuDNN weight blob and GET_PARAM reads one out -- confirm in the .cc file.
enum RecurrentParamOpMode {
  SET_PARAM = 0, // write a parameter (assumed; see note above)
  GET_PARAM = 1, // read a parameter (assumed; see note above)
};
99 template <
typename T, RecurrentParamOpMode mode>
102 USE_RECURRENT_BASE_FUNCTIONS
106 bool RunOnDevice()
override;
109 template <
typename T>
112 USE_RECURRENT_BASE_FUNCTIONS
116 bool RunOnDevice()
override;
141 #endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cuDNN handles and cuDNN workspaces.