3 #include "caffe2/core/context_gpu.h" 4 #include "caffe2/core/cudnn_wrappers.h" 5 #include "caffe2/operators/spatial_batch_norm_op.h" 6 #include "caffe2/utils/math.h" 10 static_assert(CUDNN_VERSION >= 5000,
11 "CudnnSpatialBN requires cudnn version 5.0 or above.");
// CudnnSpatialBNOp constructor body (fragment: the constructor signature and
// enclosing class declaration fall outside this excerpt).
// Allocate the cuDNN descriptors for the input tensor and the derived
// batch-norm parameter tensor; both are released in the destructor.
20 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
21 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bn_param_desc_));
// cuDNN rejects an epsilon below CUDNN_BN_MIN_EPSILON. Warn (allowing
// FLT_EPSILON of slack for float rounding) and clamp rather than fail.
22 if (epsilon_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
23 LOG(ERROR) <<
"Provided epsilon is smaller than " 24 <<
"CUDNN_BN_MIN_EPSILON. Setting it to " 25 <<
"CUDNN_BN_MIN_EPSILON instead.";
27 epsilon_ = std::max(epsilon_, CUDNN_BN_MIN_EPSILON);
// cuDNN >= 7 offers the faster persistent spatial-BN kernel; older versions
// fall back to the standard spatial mode. NOTE(review): the #else/#endif
// lines appear to be missing from this excerpt.
28 #if CUDNN_VERSION_MIN(7,0,0) 29 mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
31 mode_ = CUDNN_BATCHNORM_SPATIAL;
// Destructor body (fragment): release the two cuDNN tensor descriptors
// created in the constructor. CUDNN_ENFORCE aborts on a cuDNN error.
36 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
37 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bn_param_desc_));
// Class-interior declarations (fragment). NOTE(review): the template header
// below presumably belonged to a DoRunWithType<T, M>() declaration whose
// line was dropped by the extraction; RunOnDevice() itself is not a template
// in the definition further down.
40 template <
typename T,
typename M>
42 bool RunOnDevice()
override;
// Cached cuDNN descriptors: data_desc_ describes the input/output tensor,
// bn_param_desc_ the per-channel scale/bias/mean/var parameters.
46 cudnnTensorDescriptor_t data_desc_;
47 cudnnTensorDescriptor_t bn_param_desc_;
// Last input shape seen; descriptors are rebuilt only when this changes.
48 vector<TIndex> cudnn_input_dims_;
// Batch-norm mode selected in the constructor (persistent on cuDNN >= 7).
50 cudnnBatchNormMode_t mode_;
// CudnnSpatialBNGradientOp constructor (fragment: the signature and the rest
// of the initializer list fall outside this excerpt). Mirrors the forward
// op's setup exactly.
58 cudnn_wrapper_(&context_) {
59 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
60 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bn_param_desc_));
// Clamp epsilon to cuDNN's minimum, warning first (same policy as the
// forward op's constructor).
61 if (epsilon_ <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
62 LOG(ERROR) <<
"Provided epsilon is smaller than " 63 <<
"CUDNN_BN_MIN_EPSILON. Setting it to " 64 <<
"CUDNN_BN_MIN_EPSILON instead.";
66 epsilon_ = std::max(epsilon_, CUDNN_BN_MIN_EPSILON);
// Persistent spatial BN on cuDNN >= 7, standard spatial mode otherwise.
// NOTE(review): #else/#endif lines missing from this excerpt.
67 #if CUDNN_VERSION_MIN(7,0,0) 68 mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
70 mode_ = CUDNN_BATCHNORM_SPATIAL;
// Destructor body (fragment): free the descriptors created above.
75 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
76 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bn_param_desc_));
// Class-interior declarations for the gradient op (fragment); parallels the
// forward op's members. NOTE(review): as above, the template header likely
// belonged to a dropped DoRunWithType<T, M>() declaration.
79 template <
typename T,
typename M>
82 bool RunOnDevice()
override;
// cuDNN descriptors for the data tensor and the BN parameter tensor.
86 cudnnTensorDescriptor_t data_desc_;
87 cudnnTensorDescriptor_t bn_param_desc_;
// Shape cache used to skip redundant descriptor updates.
88 vector<TIndex> cudnn_input_dims_;
90 cudnnBatchNormMode_t mode_;
// Forward spatial batch normalization via cuDNN.
// T is the data type (float / float16), M the parameter/statistics type
// (BNParamType). Handles both inference (uses estimated mean/var inputs)
// and training (computes batch statistics and updates running averages).
// NOTE(review): several original lines are missing from this excerpt
// (e.g. the is_test_ branch condition, some cuDNN call arguments, closing
// braces); comments below describe only what is visible.
98 template <
typename T,
typename M>
99 bool CudnnSpatialBNOp::DoRunWithType() {
104 const auto& X = Input(INPUT);
105 const auto& scale = Input(SCALE);
106 const auto& bias = Input(BIAS);
// Accept 3-D (N,C,H or N,H,C), 4-D, and 5-D inputs in either storage order.
108 CAFFE_ENFORCE_GE(X.ndim(), 3);
109 const int N = X.dim32(0);
// Channels: last axis for NHWC, axis 1 for NCHW; special-cased for 3-D.
110 const int C = X.ndim() > 3
111 ? (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(X.ndim() - 1))
112 : (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(2));
113 const int H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1));
// W and D default to 1 for lower-rank inputs — the ": 1" fallback lines
// (original 116/119) are missing from this excerpt.
114 const int W = X.ndim() > 3
115 ? (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2))
117 const int D = X.ndim() > 4
118 ? (order_ == StorageOrder::NCHW ? X.dim32(4) : X.dim32(3))
// Scale and bias are per-channel 1-D vectors of length C.
120 CAFFE_ENFORCE_EQ(scale.ndim(), 1);
121 CAFFE_ENFORCE_EQ(bias.ndim(), 1);
122 CAFFE_ENFORCE_EQ(scale.dim32(0), C);
123 CAFFE_ENFORCE_EQ(bias.dim32(0), C);
// Rebuild cuDNN descriptors only when the input shape changes.
125 if (X.dims() != cudnn_input_dims_) {
126 VLOG(1) <<
"Setting descriptors.";
127 cudnn_input_dims_ = X.dims();
128 if (order_ == StorageOrder::NCHW) {
// NCHW: fully contiguous row-major strides over (N,C,H,W,D).
129 vector<int> dims = {N, C, H, W, D};
130 vector<int> strides = {C * H * W * D, H * W * D, W * D, D, 1};
131 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
// cuDNN Nd descriptors require at least 4 dimensions, so pad up to 4.
134 X.ndim() > 3 ? X.ndim() : 4,
// NHWC branch: channel stride is 1; spatial strides carry the C factor.
138 vector<int> dims = {N, C, H, W, D};
139 vector<int> strides = {H * W * D * C, 1, W * D * C, D * C, C};
140 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
143 X.ndim() > 3 ? X.ndim() : 4,
// Derive the 1xCx1x1 parameter descriptor from the data descriptor.
147 CUDNN_ENFORCE(cudnnDeriveBNTensorDescriptor(
148 bn_param_desc_, data_desc_, mode_));
// --- Inference path (the guarding condition line is missing from this
// excerpt; presumably `if (is_test_)`): normalize with the precomputed
// estimated mean/variance inputs.
154 const auto& est_mean = Input(EST_MEAN);
155 const auto& est_var = Input(EST_VAR);
156 CAFFE_ENFORCE_EQ(est_mean.ndim(), 1);
157 CAFFE_ENFORCE_EQ(est_var.ndim(), 1);
158 CAFFE_ENFORCE_EQ(est_mean.dim32(0), C);
159 CAFFE_ENFORCE_EQ(est_var.dim32(0), C);
161 auto* Y = Output(OUTPUT);
163 CUDNN_ENFORCE(cudnnBatchNormalizationForwardInference(
164 cudnn_wrapper_.inline_cudnn_handle(),
// Plain SPATIAL mode here rather than mode_: per the cuDNN docs the
// PERSISTENT mode is training-only, so this looks intentional — confirm.
166 CUDNN_BATCHNORM_SPATIAL,
170 X.template data<T>(),
172 Y->template mutable_data<T>(),
174 scale.template data<BNParamType>(),
175 bias.template data<BNParamType>(),
176 est_mean.template data<BNParamType>(),
177 est_var.template data<BNParamType>(),
// --- Training path: compute batch statistics, update running averages.
181 auto* Y = Output(OUTPUT);
185 auto* running_mean = Output(RUNNING_MEAN);
186 auto* running_var = Output(RUNNING_VAR);
// cuDNN's exponentialAverageFactor is the weight of the NEW batch
// statistic, i.e. 1 - momentum.
187 double this_factor = 1. - momentum_;
188 BNParamType* running_mean_data =
nullptr;
189 BNParamType* running_var_data =
nullptr;
// First run: allocate the running-statistics blobs and zero-initialize
// them so the first exponential-average update starts from zero.
190 if (!running_mean->size()) {
193 VLOG(1) <<
"Initializing running mean and var.";
195 running_mean->Resize(C);
196 running_var->Resize(C);
197 running_mean_data = running_mean->template mutable_data<BNParamType>();
198 running_var_data = running_var->template mutable_data<BNParamType>();
203 math::Set<BNParamType, CUDAContext>(C, 0, running_mean_data, &context_);
204 math::Set<BNParamType, CUDAContext>(C, 0, running_var_data, &context_);
// Subsequent runs: validate existing running-statistics shapes and reuse.
207 CAFFE_ENFORCE_EQ(running_mean->ndim(), 1);
208 CAFFE_ENFORCE_EQ(running_var->ndim(), 1);
209 CAFFE_ENFORCE_EQ(running_mean->dim32(0), C);
210 CAFFE_ENFORCE_EQ(running_var->dim32(0), C);
211 running_mean_data = running_mean->template mutable_data<BNParamType>();
212 running_var_data = running_var->template mutable_data<BNParamType>();
// Saved batch mean and inverse variance, consumed by the gradient op.
// NOTE(review): the save_var->Resize(C) line (original 218) is missing
// from this excerpt.
215 auto* save_mean = Output(SAVED_MEAN);
216 auto* save_var = Output(SAVED_INV_VAR);
217 save_mean->Resize(C);
219 void* save_mean_data = save_mean->template mutable_data<BNParamType>();
220 void* save_var_data = save_var->template mutable_data<BNParamType>();
// Training-mode forward pass; also writes the running and saved stats
// (the argument lines carrying epsilon_, this_factor, and the saved/running
// pointers are missing from this excerpt).
222 CUDNN_ENFORCE(cudnnBatchNormalizationForwardTraining(
223 cudnn_wrapper_.inline_cudnn_handle(),
228 X.template data<T>(),
230 Y->template mutable_data<T>(),
232 scale.template data<BNParamType>(),
233 bias.template data<BNParamType>(),
// Dispatch on the runtime dtype of the first input: float data uses float
// math, float16 data accumulates in float (the M template argument).
// Any other dtype aborts via LOG(FATAL). NOTE(review): the closing braces
// and final return (original lines 249-253) are missing from this excerpt.
244 bool CudnnSpatialBNOp::RunOnDevice() {
245 if (Input(0).IsType<float>()) {
246 return DoRunWithType<float,float>();
247 }
else if (Input(0).IsType<float16>()) {
248 return DoRunWithType<float16,float>();
250 LOG(FATAL) <<
"Unsupported input types";
// Backward pass of spatial batch normalization via cuDNN: given X, the
// per-channel scale, the output gradient dY, and the batch mean / inverse
// variance saved by the forward op, computes dX, dScale, and dBias.
// NOTE(review): as in the forward op, several original lines (fallback
// branches, cuDNN call arguments, closing braces) are missing from this
// excerpt; comments describe only what is visible.
255 template <
typename T,
typename M>
256 bool CudnnSpatialBNGradientOp::DoRunWithType() {
260 const auto& X = Input(INPUT);
261 const auto& scale = Input(SCALE);
262 const auto& dY = Input(OUTPUT_GRAD);
// Same N/C/H/W/D extraction as the forward op, for 3-D to 5-D inputs in
// either storage order (the ": 1" fallbacks, originals 272/275, are missing
// from this excerpt).
264 CAFFE_ENFORCE_GE(X.ndim(), 3);
265 const int N = X.dim32(0);
266 const int C = X.ndim() > 3
267 ? (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(X.ndim() - 1))
268 : (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(2));
269 const int H = (order_ == StorageOrder::NCHW ? X.dim32(2) : X.dim32(1));
270 const int W = X.ndim() > 3
271 ? (order_ == StorageOrder::NCHW ? X.dim32(3) : X.dim32(2))
273 const int D = X.ndim() > 4
274 ? (order_ == StorageOrder::NCHW ? X.dim32(4) : X.dim32(3))
276 CAFFE_ENFORCE_EQ(scale.ndim(), 1);
277 CAFFE_ENFORCE_EQ(scale.dim32(0), C);
// Rebuild the cuDNN descriptors only on an input-shape change, mirroring
// the forward op's descriptor setup.
279 if (X.dims() != cudnn_input_dims_) {
280 if (order_ == StorageOrder::NCHW) {
281 vector<int> dims = {N, C, H, W, D};
282 vector<int> strides = {C * H * W * D, H * W * D, W * D, D, 1};
283 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
// Pad to cuDNN's 4-dimension minimum.
286 X.ndim() > 3 ? X.ndim() : 4,
// NHWC: channel stride 1 (H*W*C*D == H*W*D*C; same strides as forward).
290 vector<int> dims = {N, C, H, W, D};
291 vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C};
292 CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
295 X.ndim() > 3 ? X.ndim() : 4,
299 CUDNN_ENFORCE(cudnnDeriveBNTensorDescriptor(
300 bn_param_desc_, data_desc_, mode_));
303 auto* dX = Output(INPUT_GRAD);
304 auto* dScale = Output(SCALE_GRAD);
305 auto* dBias = Output(BIAS_GRAD);
// Parameter gradients are per-channel, same shape as scale.
// NOTE(review): the dX->ResizeLike(X) line (original 306) appears to be
// missing from this excerpt.
307 dScale->ResizeLike(scale);
308 dBias->ResizeLike(scale);
// Batch mean and inverse variance cached by the forward training pass;
// passed to cuDNN to avoid recomputing them.
310 const auto& saved_mean = Input(SAVED_MEAN);
311 const auto& saved_var = Input(SAVED_INV_VAR);
312 const void* saved_mean_data = saved_mean.template data<BNParamType>();
313 const void* saved_var_data = saved_var.template data<BNParamType>();
// Backward BN kernel (the argument lines carrying mode_, alpha/beta,
// epsilon_, and the saved-stats pointers are missing from this excerpt).
315 CUDNN_ENFORCE(cudnnBatchNormalizationBackward(
316 cudnn_wrapper_.inline_cudnn_handle(),
323 X.template data<T>(),
325 dY.template data<T>(),
327 dX->template mutable_data<T>(),
329 scale.template data<BNParamType>(),
330 dScale->template mutable_data<BNParamType>(),
331 dBias->template mutable_data<BNParamType>(),
// Dtype dispatch identical to the forward op's RunOnDevice: float -> float
// math, float16 -> float accumulation, anything else aborts via LOG(FATAL).
// NOTE(review): closing braces and final return (original lines 343-347)
// are missing from this excerpt.
338 bool CudnnSpatialBNGradientOp::RunOnDevice() {
339 if (Input(0).IsType<float>()) {
340 return DoRunWithType<float,float>();
341 }
else if (Input(0).IsType<float16>()) {
342 return DoRunWithType<float16,float>();
344 LOG(FATAL) <<
"Unsupported input types";
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function; it is specialized explicitly for the supported data types.