1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/types.h" 12 cudnn_wrapper_(&context_),
13 order_(StringToStorageOrder(
14 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) {
// Create the cuDNN descriptors owned by this op; the destructor releases them
// with cudnnDestroyTensorDescriptor / cudnnDestroyActivationDescriptor.
15 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
16 CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_));
// Configure the activation once, up front: ReLU mode with NaN propagation.
// 0.0 is passed as the coef argument.
17 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
18 activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));
// Release the cuDNN descriptors created in the constructor.
22 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
23 CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_));
27 bool DoRunWithType() {
28 const auto& X = Input(0);
38 if (X.dims() != cudnn_input_dims_) {
39 VLOG(1) <<
"Setting descriptors.";
40 cudnn_input_dims_ = X.dims();
41 int C = 1, H = 1, W = 1;
44 C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(3));
45 H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1));
46 W = (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2));
50 C = X.size() / X.dim32(0);
52 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
61 CUDNN_ENFORCE(cudnnActivationForward(
69 Y->template mutable_data<T>()));
73 bool RunOnDevice()
override {
75 const auto& X = Input(0);
79 if (X.IsType<
float>()) {
80 return DoRunWithType<float>();
81 }
else if (X.IsType<float16>()) {
82 return DoRunWithType<float16>();
84 LOG(FATAL) <<
"Unsupported input types";
// Tensor descriptor created in the constructor; presumably the target of the
// cudnnSetTensor4dDescriptor call in DoRunWithType (arguments not visible here —
// NOTE(review): confirm).
91 cudnnTensorDescriptor_t data_desc_;
// Activation descriptor configured in the constructor as CUDNN_ACTIVATION_RELU
// with CUDNN_PROPAGATE_NAN.
92 cudnnActivationDescriptor_t activ_desc_;
// Dims of the input from the last descriptor setup; DoRunWithType only
// re-sets descriptors when X.dims() differs from this cached value.
93 vector<TIndex> cudnn_input_dims_;
108 cudnn_wrapper_(&context_),
109 order_(StringToStorageOrder(
110 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) {
// Create the cuDNN descriptors owned by this gradient op; the destructor releases
// them with cudnnDestroyTensorDescriptor / cudnnDestroyActivationDescriptor.
111 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
112 CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&activ_desc_));
// Same activation configuration as the forward op: ReLU mode with NaN
// propagation; 0.0 is passed as the coef argument.
113 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
114 activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0));
// Release the cuDNN descriptors created in the constructor.
118 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
119 CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(activ_desc_));
122 template <
typename T>
123 bool DoRunWithType() {
124 const auto& Y = Input(0);
125 const auto& dY = Input(1);
126 auto* dX = Output(0);
130 dX->mutable_data<T>();
135 if (Y.dims() != cudnn_input_dims_) {
136 VLOG(1) <<
"Setting descriptors.";
137 cudnn_input_dims_ = Y.dims();
138 int C = 1, H = 1, W = 1;
141 C = (order_ == StorageOrder::NCHW ? Y.dim32(1) : Y.dim32(3));
142 H = (order_ == StorageOrder::NCHW ? Y.dim32(2) : Y.dim32(1));
143 W = (order_ == StorageOrder::NCHW ? Y.dim32(3) : Y.dim32(2));
147 C = Y.size() / Y.dim32(0);
149 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
158 CUDNN_ENFORCE(cudnnActivationBackward(
159 cudnn_wrapper_.inline_cudnn_handle(),
163 Y.template data<T>(),
165 dY.template data<T>(),
173 Y.template data<T>(),
176 dX->template mutable_data<T>()));
180 bool RunOnDevice()
override {
181 const auto& Y = Input(0);
182 auto* dX = Output(0);
185 if (Y.IsType<
float>()) {
186 return DoRunWithType<float>();
187 }
else if (Y.IsType<float16>()) {
188 return DoRunWithType<float16>();
190 LOG(FATAL) <<
"Unsupported input types";
// Tensor descriptor created in the constructor; presumably the target of the
// cudnnSetTensor4dDescriptor call in DoRunWithType (arguments not visible here —
// NOTE(review): confirm).
197 cudnnTensorDescriptor_t data_desc_;
// Activation descriptor configured in the constructor as CUDNN_ACTIVATION_RELU
// with CUDNN_PROPAGATE_NAN.
198 cudnnActivationDescriptor_t activ_desc_;
// Dims of the forward output Y from the last descriptor setup; DoRunWithType
// only re-sets descriptors when Y.dims() differs from this cached value.
199 vector<TIndex> cudnn_input_dims_;
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function that converts a Caffe2 storage order (StorageOrder) to the corresponding cuDNN tensor-format enum value (cudnnTensorFormat_t).
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows a template function to refer to the corresponding cuDNN type.