Caffe2 - C++ API
A deep learning, cross platform ML framework
local_response_normalization_op_cudnn.cc
1 #include "caffe2/core/context_gpu.h"
2 #include "caffe2/core/cudnn_wrappers.h"
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/types.h"
5 
6 namespace caffe2 {
7 
8 class CuDNNLRNOp final : public Operator<CUDAContext> {
9  public:
10  USE_OPERATOR_FUNCTIONS(CUDAContext);
11 
12  CuDNNLRNOp(const OperatorDef& operator_def, Workspace* ws)
13  : Operator<CUDAContext>(operator_def, ws),
14  cudnn_wrapper_(&context_),
15  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
16  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
17  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
18  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
19  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
20 
21  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
22  CUDNN_ENFORCE(
23  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
24  }
25 
26  ~CuDNNLRNOp() {
27  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
28  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
29  }
30 
31  template <typename T, typename M>
32  bool DoRunWithType();
33 
34  bool RunOnDevice() override;
35 
36  protected:
37  CuDNNWrapper cudnn_wrapper_;
38  cudnnTensorDescriptor_t data_desc_;
39  cudnnLRNDescriptor_t norm_desc_;
40 
41  vector<TIndex> cudnn_input_dims_;
42 
43  const int size_;
44  const float alpha_;
45  const float beta_;
46  const float bias_;
47 
48  // Input: X, Output: Y
49 };
50 
51 class CuDNNLRNGradientOp final : public Operator<CUDAContext> {
52  public:
53  USE_OPERATOR_FUNCTIONS(CUDAContext);
54  CuDNNLRNGradientOp(const OperatorDef& operator_def, Workspace* ws)
55  : Operator<CUDAContext>(operator_def, ws),
56  cudnn_wrapper_(&context_),
57  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
58  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
59  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
60  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
61  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
62 
63  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
64  CUDNN_ENFORCE(
65  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
66  }
67 
69  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
70  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
71  }
72 
73  template <typename T, typename M>
74  bool DoRunWithType();
75 
76  bool RunOnDevice() override;
77 
78  protected:
79  CuDNNWrapper cudnn_wrapper_;
80  cudnnTensorDescriptor_t data_desc_;
81  cudnnLRNDescriptor_t norm_desc_;
82 
83  vector<TIndex> cudnn_input_dims_;
84 
85  const int size_;
86  const float alpha_;
87  const float beta_;
88  const float bias_;
89 
90  // Input: X, Y, dY
91  // Output: dX
92 };
93 
94 template <typename T, typename M>
95 bool CuDNNLRNOp::DoRunWithType() {
96  const auto& X = Input(0);
97  auto* Y = Output(0);
98 
99  // Reshape tensor descriptors if necessary
100  if (X.dims() != cudnn_input_dims_) {
101  VLOG(1) << "Setting descriptors";
102  cudnn_input_dims_ = X.dims();
103  int C = 1, H = 1, W = 1;
104  // Normal 4-dimensional tensors for images.
105  C = X.dim32(1);
106  H = X.dim32(2);
107  W = X.dim32(3);
108  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
109  data_desc_,
110  GetCudnnTensorFormat(StorageOrder::NCHW),
112  X.dim32(0),
113  C,
114  H,
115  W));
116  }
117 
118  // now actually run the computation
119  CUDNN_ENFORCE(cudnnLRNCrossChannelForward(
120  cudnn_wrapper_.inline_cudnn_handle(),
121  norm_desc_,
122  CUDNN_LRN_CROSS_CHANNEL_DIM1,
124  data_desc_,
125  X.template data<T>(),
127  data_desc_,
128  Y->template mutable_data<T>()));
129 
130  return true;
131 }
132 
133 bool CuDNNLRNOp::RunOnDevice() {
134  // dispatch based on contents of tensor(s)
135  const auto& X = Input(0);
136  auto* Y = Output(0);
137  Y->ResizeLike(X);
138 
139  if (X.IsType<float>()) {
140  return DoRunWithType<float, float>();
141  } else if (X.IsType<float16>()) {
142  return DoRunWithType<float16, float>();
143  } else {
144  CAFFE_THROW("Unsupported input type");
145  }
146  return false;
147 }
148 
149 template <typename T, typename M>
150 bool CuDNNLRNGradientOp::DoRunWithType() {
151  const auto& X = Input(0);
152  const auto& Y = Input(1);
153  const auto& dY = Input(2);
154  auto* dX = Output(0);
155 
156  if (dY.dims() != cudnn_input_dims_) {
157  VLOG(1) << "Setting descriptors";
158  cudnn_input_dims_ = dY.dims();
159  int C = 1, H = 1, W = 1;
160  // Normal 4-dimensional tensors for images.
161  C = dY.dim32(1);
162  H = dY.dim32(2);
163  W = dY.dim32(3);
164  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
165  data_desc_,
166  GetCudnnTensorFormat(StorageOrder::NCHW),
168  dY.dim32(0),
169  C,
170  H,
171  W));
172  }
173 
174  // run the computation
175  CUDNN_ENFORCE(cudnnLRNCrossChannelBackward(
176  cudnn_wrapper_.inline_cudnn_handle(),
177  norm_desc_,
178  CUDNN_LRN_CROSS_CHANNEL_DIM1,
180  data_desc_,
181  Y.template data<T>(),
182  data_desc_,
183  dY.template data<T>(),
184  data_desc_,
185  X.template data<T>(),
187  data_desc_,
188  dX->template mutable_data<T>()));
189  return true;
190 }
191 
192 bool CuDNNLRNGradientOp::RunOnDevice() {
193  // dispatch based on contents of tensor(s)
194  const auto& X = Input(0);
195  const auto& Y = Input(1);
196  const auto& dY = Input(2);
197  auto* dX = Output(0);
198 
199  dX->ResizeLike(dY);
200 
201  if (dY.IsType<float>()) {
202  return DoRunWithType<float, float>();
203  } else if (dY.IsType<float16>()) {
204  return DoRunWithType<float16, float>();
205  } else {
206  CAFFE_THROW("Unsupported input type");
207  }
208 
209  return false;
210 }
211 
212 namespace {
213 REGISTER_CUDNN_OPERATOR(LRN, CuDNNLRNOp);
214 REGISTER_CUDNN_OPERATOR(LRNGradient, CuDNNLRNGradientOp);
215 }
216 
217 }; // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:183
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:111