// NOTE(review): garbled extraction — original-file line numbers (1, 2, ..., 14,
// 15, ...) are fused into the text and many lines are elided. This fragment is
// the interior of the CuDNNLRNOp class; its `class` header and the constructor
// signature sit on elided lines and are not visible in this chunk.
// Constructor initializer list: wraps the CUDA context in a CuDNNWrapper and
// reads the LRN hyperparameters from operator arguments with defaults
// size=0, alpha=0, beta=0, bias=1.
1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/types.h" 14 cudnn_wrapper_(&context_),
15 size_(OperatorBase::GetSingleArgument<int>(
"size", 0)),
16 alpha_(OperatorBase::GetSingleArgument<float>(
"alpha", 0)),
17 beta_(OperatorBase::GetSingleArgument<float>(
"beta", 0)),
18 bias_(OperatorBase::GetSingleArgument<float>(
"bias", 1)) {
// Constructor body: create the cuDNN tensor descriptor and LRN descriptor
// once, then configure the LRN descriptor with the parsed hyperparameters.
19 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
21 CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
// NOTE(review): the matching `CUDNN_ENFORCE(` opener for the call below is on
// an elided line — only the call and its extra closing `));` are visible here.
23 cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
// Destructor fragment: release both cuDNN descriptors created above.
27 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
28 CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
// NOTE(review): lines between the template header and `RunOnDevice()` are
// elided; this `template <typename T, typename M>` header presumably belongs
// to a separate DoRunWithType<T, M>() declaration (T = data type, M =
// math/parameter type) — confirm against the full file.
31 template <
typename T,
typename M>
34 bool RunOnDevice()
override;
// Cached cuDNN descriptors, plus the input dims seen on the last run; used to
// skip descriptor reconfiguration when the input shape has not changed.
38 cudnnTensorDescriptor_t data_desc_;
39 cudnnLRNDescriptor_t norm_desc_;
41 vector<TIndex> cudnn_input_dims_;
// NOTE(review): garbled extraction — this fragment is the interior of the
// CuDNNLRNGradientOp class (gradient counterpart of CuDNNLRNOp above); its
// `class` header and constructor signature are on elided lines. The structure
// mirrors the forward op exactly.
// Constructor initializer list: same hyperparameter arguments and defaults as
// the forward op (size=0, alpha=0, beta=0, bias=1).
56 cudnn_wrapper_(&context_),
57 size_(OperatorBase::GetSingleArgument<int>(
"size", 0)),
58 alpha_(OperatorBase::GetSingleArgument<float>(
"alpha", 0)),
59 beta_(OperatorBase::GetSingleArgument<float>(
"beta", 0)),
60 bias_(OperatorBase::GetSingleArgument<float>(
"bias", 1)) {
// Constructor body: create and configure the cuDNN descriptors once.
61 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
63 CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
// NOTE(review): the matching `CUDNN_ENFORCE(` opener for the call below is on
// an elided line — only the call and its extra closing `));` are visible here.
65 cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
// Destructor fragment: release both cuDNN descriptors.
69 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
70 CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
// NOTE(review): lines between the template header and `RunOnDevice()` are
// elided; the `template <typename T, typename M>` header presumably belongs to
// a separate DoRunWithType<T, M>() declaration — confirm against the full
// file.
73 template <
typename T,
typename M>
76 bool RunOnDevice()
override;
// Cached cuDNN descriptors and last-seen input dims (here the dims of dY),
// used to skip descriptor reconfiguration when the shape is unchanged.
80 cudnnTensorDescriptor_t data_desc_;
81 cudnnLRNDescriptor_t norm_desc_;
83 vector<TIndex> cudnn_input_dims_;
// Typed forward pass for the cuDNN LRN op.
// NOTE(review): heavily elided fragment — the template-argument qualification
// on DoRunWithType, the Output(0) setup, the C/H/W computation, the full
// argument lists of both cuDNN calls, and the closing braces all sit on
// elided lines.
94 template <
typename T,
typename M>
95 bool CuDNNLRNOp::DoRunWithType() {
96 const auto& X = Input(0);
// Reconfigure the cached tensor descriptor only when the input shape differs
// from the one seen on the previous run.
100 if (X.dims() != cudnn_input_dims_) {
101 VLOG(1) <<
"Setting descriptors";
102 cudnn_input_dims_ = X.dims();
103 int C = 1, H = 1, W = 1;
// Elided: derivation of C/H/W from X's dims and the descriptor's full
// argument list (format, data type, N/C/H/W).
108 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
// Launch the cuDNN cross-channel LRN forward kernel; alpha/beta scaling
// arguments and the output descriptor are on elided lines. `Y` is presumably
// declared as Output(0) on an elided line — confirm against the full file.
119 CUDNN_ENFORCE(cudnnLRNCrossChannelForward(
120 cudnn_wrapper_.inline_cudnn_handle(),
122 CUDNN_LRN_CROSS_CHANNEL_DIM1,
125 X.template data<T>(),
128 Y->template mutable_data<T>()));
// Untyped entry point: dispatches to DoRunWithType based on the element type
// of the input tensor.
// NOTE(review): fragment — lines between (e.g. output setup) and the closing
// braces are elided.
133 bool CuDNNLRNOp::RunOnDevice() {
135 const auto& X = Input(0);
139 if (X.IsType<
float>()) {
140 return DoRunWithType<float, float>();
141 }
else if (X.IsType<float16>()) {
// float16 data still uses float for the math/parameter type M.
142 return DoRunWithType<float16, float>();
// Any other input element type is rejected.
144 CAFFE_THROW(
"Unsupported input type");
// Typed backward pass for the cuDNN LRN op. Inputs: X (forward input),
// Y (forward output), dY (gradient w.r.t. Y); output: dX (gradient w.r.t. X).
// NOTE(review): heavily elided fragment — the template-argument qualification,
// C/H/W computation, full cuDNN argument lists, and closing braces sit on
// elided lines.
149 template <
typename T,
typename M>
150 bool CuDNNLRNGradientOp::DoRunWithType() {
151 const auto& X = Input(0);
152 const auto& Y = Input(1);
153 const auto& dY = Input(2);
154 auto* dX = Output(0);
// Reconfigure the cached tensor descriptor only when dY's shape differs from
// the one seen on the previous run.
156 if (dY.dims() != cudnn_input_dims_) {
157 VLOG(1) <<
"Setting descriptors";
158 cudnn_input_dims_ = dY.dims();
159 int C = 1, H = 1, W = 1;
// Elided: derivation of C/H/W from dY's dims and the descriptor's full
// argument list.
164 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
// Launch the cuDNN cross-channel LRN backward kernel; cuDNN's LRN backward
// takes Y, dY, and X (plus alpha/beta scaling on elided lines) and writes dX.
175 CUDNN_ENFORCE(cudnnLRNCrossChannelBackward(
176 cudnn_wrapper_.inline_cudnn_handle(),
178 CUDNN_LRN_CROSS_CHANNEL_DIM1,
181 Y.template data<T>(),
183 dY.template data<T>(),
185 X.template data<T>(),
188 dX->template mutable_data<T>()));
// Untyped entry point for the gradient op: dispatches on the element type of
// dY. Mirrors CuDNNLRNOp::RunOnDevice.
// NOTE(review): fragment — intermediate lines and the closing braces are
// elided. X, Y, and dX are bound here but the dispatch below keys on dY only.
192 bool CuDNNLRNGradientOp::RunOnDevice() {
194 const auto& X = Input(0);
195 const auto& Y = Input(1);
196 const auto& dY = Input(2);
197 auto* dX = Output(0);
201 if (dY.IsType<
float>()) {
202 return DoRunWithType<float, float>();
203 }
else if (dY.IsType<float16>()) {
// float16 data still uses float for the math/parameter type M.
204 return DoRunWithType<float16, float>();
// Any other input element type is rejected.
206 CAFFE_THROW(
"Unsupported input type");
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to the cudnn storage order enum values.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...