Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_op_cudnn.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
2 #define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/context_gpu.h"
6 #include "caffe2/core/cudnn_wrappers.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 namespace detail {
12 
13 template <typename T>
15  public:
17  size_t n,
18  const std::vector<int>& dim,
19  const std::vector<int>& stride);
21  const cudnnTensorDescriptor_t* descs() const {
22  return descs_.data();
23  }
24 
25  private:
26  std::vector<cudnnTensorDescriptor_t> descs_;
27 };
28 
29 } // namespace detail
30 
31 template <typename T>
32 class RecurrentBaseOp : public Operator<CUDAContext> {
33  public:
34  USE_OPERATOR_FUNCTIONS(CUDAContext);
35  RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
36  virtual ~RecurrentBaseOp();
37 
38  protected:
39  void initialize(
40  const Tensor<CUDAContext>& input,
41  Tensor<CUDAContext>* dropoutStates = nullptr,
42  // If passed, reshapes to the appropriate size
43  Tensor<CUDAContext>* output = nullptr,
44  Tensor<CUDAContext>* hiddenOutput = nullptr,
45  Tensor<CUDAContext>* cellOutput = nullptr);
46 
47  CuDNNWrapper cudnn_wrapper_;
48  cudnnDropoutDescriptor_t dropoutDesc_;
49  cudnnRNNDescriptor_t rnnDesc_;
50  cudnnFilterDescriptor_t wDesc_;
51  cudnnTensorDescriptor_t hxDesc_;
52  cudnnTensorDescriptor_t cxDesc_;
53  cudnnTensorDescriptor_t hyDesc_;
54  cudnnTensorDescriptor_t cyDesc_;
55 
56  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
57  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
58 
59  std::vector<TIndex> cachedInputDims_;
60  size_t reserveNbytes_;
61  size_t cudnnWsNbytes_;
62 
63  private:
64 };
65 
// Pulls the Operator helper names and the protected members of
// RecurrentBaseOp<T> into the scope of a derived class template. Members of a
// dependent base class are not found by unqualified name lookup, so derived
// operators need these using-declarations to refer to them directly.
// Must be expanded inside a class template whose type parameter is named T.
#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(CUDAContext);        \
  using RecurrentBaseOp<T>::cudnn_wrapper_;   \
  using RecurrentBaseOp<T>::dropoutDesc_;     \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::cudnnWsNbytes_;   \
  using RecurrentBaseOp<T>::initialize;
82 
83 template <typename T>
84 class RecurrentOp : public RecurrentBaseOp<T> {
85  public:
86  USE_RECURRENT_BASE_FUNCTIONS
87  RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
88  : RecurrentBaseOp<T>(operator_def, ws) {}
89 
90  bool RunOnDevice() override;
91 
92  protected:
93  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
94  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
95 };
96 
// Compile-time selector for RecurrentParamAccessOp: SET_PARAM or GET_PARAM.
// Kept as an unscoped enum so existing unqualified uses of the enumerators
// keep compiling.
enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };
98 
99 template <typename T, RecurrentParamOpMode mode>
101  public:
102  USE_RECURRENT_BASE_FUNCTIONS
103  RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
104  : RecurrentBaseOp<T>(operator_def, ws) {}
105 
106  bool RunOnDevice() override;
107 };
108 
109 template <typename T>
111  public:
112  USE_RECURRENT_BASE_FUNCTIONS
113  RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
114  : RecurrentBaseOp<T>(operator_def, ws) {}
115 
116  bool RunOnDevice() override;
117 
118  protected:
119  INPUT_TAGS(
120  INPUT,
121  HIDDEN_INPUT,
122  CELL_INPUT,
123  WEIGHT,
124  RNN_SCRATCH,
125  OUTPUT,
126  GRAD_OUTPUT,
127  GRAD_HIDDEN_OUTPUT,
128  GRAD_CELL_OUTPUT);
129  OUTPUT_TAGS(
130  GRAD_INPUT,
131  GRAD_HIDDEN_INPUT,
132  GRAD_CELL_INPUT,
133  GRAD_WEIGHT,
134  DROPOUT_STATES,
135  RNN_SCRATCH_OUT);
136 };
137 
138 
139 } // namespace caffe2
140 
141 #endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.