Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_op_cudnn.cc
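The listing below implements Caffe2's ConvTranspose operator on top of cuDNN: the forward pass of a transposed convolution is computed with cudnnConvolutionBackwardData, and the gradient pass combines cudnnConvolutionBackwardFilter, cudnnConvolutionBackwardBias, and cudnnConvolutionForward. For orientation, here is a minimal sketch (not part of the file) of the spatial output-size arithmetic such an operator relies on; the helper name is hypothetical, and the real computation is delegated to ConvTransposeUnpoolBase<CUDAContext>::SetOutputSize with dilation fixed at 1.

// Sketch only: expected output height/width of a 2-D transposed convolution
// for the kernel/stride/pad arguments this operator accepts (no dilation).
// The helper name is hypothetical; "adj" is the optional output adjustment.
inline int ConvTransposeOutputSize(
    int in_size,   // input H or W
    int kernel,    // kernel_h() or kernel_w()
    int stride,    // stride_h() or stride_w()
    int pad_head,  // pad_t() or pad_l()
    int pad_tail,  // pad_b() or pad_r()
    int adj = 0) {
  return (in_size - 1) * stride + kernel - pad_head - pad_tail + adj;
}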
1 #include "caffe2/core/context_gpu.h"
2 #include "caffe2/core/cudnn_wrappers.h"
3 #include "caffe2/operators/conv_op_cache_cudnn.h"
4 #include "caffe2/operators/conv_transpose_op.h"
5 #include "caffe2/operators/op_utils_cudnn.h"
6 
7 namespace caffe2 {
8 
9 class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
10  public:
11  CudnnConvTransposeOpBase(const OperatorDef& operator_def, Workspace* ws)
12  : ConvTransposeUnpoolBase<CUDAContext>(operator_def, ws),
13  cudnn_wrapper_(&context_),
14  cudnn_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>(
15  "ws_nbytes_limit",
16  kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
17  exhaustive_search_(
18  OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)),
19  deterministic_(
20  OperatorBase::GetSingleArgument<int>("deterministic", 0)),
21  cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0)),
22  force_algo_(OperatorBase::GetRepeatedArgument<int>(
23  "force_algo",
24  vector<int>{-1, -1, -1})),
25  enable_tensor_core_(
26  OperatorBase::GetSingleArgument<bool>("enable_tensor_core", 1)) {
27  CAFFE_ENFORCE(!deterministic_ || !exhaustive_search_);
28 
29  bool individual_force_algo = OperatorBase::HasArgument("force_algo_fwd") ||
30  OperatorBase::HasArgument("force_algo_dgrad") ||
31  OperatorBase::HasArgument("force_algo_wgrad");
32  if (OperatorBase::HasArgument("force_algo")) {
33  CAFFE_ENFORCE(
34  !individual_force_algo,
35  "Cannot specify both force_algo and any of ",
36  "force_algo_fwd, force_algo_dgrad, force_algo_wgrad");
37  } else {
38  force_algo_ = std::vector<int>{-1, -1, -1};
39  force_algo_[ALGO_FWD] =
40  OperatorBase::GetSingleArgument<int>("force_algo_fwd", -1);
41  force_algo_[ALGO_DGRAD] =
42  OperatorBase::GetSingleArgument<int>("force_algo_dgrad", -1);
43  force_algo_[ALGO_WGRAD] =
44  OperatorBase::GetSingleArgument<int>("force_algo_wgrad", -1);
45  }
46 
47  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bottom_desc_));
48  CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
49  if (InputSize() == 3) {
50  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
51  }
52  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
53  CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
54  }
55 
56  ~CudnnConvTransposeOpBase() {
57  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bottom_desc_));
58  CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
59  if (InputSize() == 3) {
60  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
61  }
62  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
63  CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
64  }
65 
66  protected:
67  vector<TIndex> cudnn_input_dims_;
68  vector<TIndex> cudnn_filter_dims_;
69 
70  CuDNNWrapper cudnn_wrapper_;
71  cudnnTensorDescriptor_t bottom_desc_;
72  cudnnFilterDescriptor_t filter_desc_;
73  cudnnTensorDescriptor_t bias_desc_;
74  cudnnTensorDescriptor_t top_desc_;
75  cudnnConvolutionDescriptor_t conv_desc_;
76  const size_t cudnn_ws_nbytes_limit_;
77  size_t cudnn_ws_nbytes_;
78  bool exhaustive_search_;
79  bool deterministic_;
80  size_t cudnn_state_;
81  vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
82  bool enable_tensor_core_;
83 };
84 
85 template <typename T>
86 class CudnnConvTransposeOp final : public CudnnConvTransposeOpBase {
87  public:
88  CudnnConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
89  : CudnnConvTransposeOpBase(operator_def, ws) {}
90 
91  ~CudnnConvTransposeOp() {}
92 
93  bool RunOnDevice() override;
94 
95  private:
96  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> data_algo_cache_;
97  cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
98  // Input: X, W, b
99  // Output: Y
100  INPUT_TAGS(INPUT, FILTER, BIAS);
101 };
102 
103 template <typename T>
104 class CudnnConvTransposeGradientOp final : public CudnnConvTransposeOpBase {
105  public:
106  CudnnConvTransposeGradientOp(const OperatorDef& operator_def, Workspace* ws)
107  : CudnnConvTransposeOpBase(operator_def, ws),
108  no_bias_(OperatorBase::GetSingleArgument<bool>("no_bias", false)) {
109  CAFFE_ENFORCE(
110  !(no_bias_ && OutputSize() == 3),
111  "If bias is not present, you should not have 3 grad outputs.");
112  }
113 
114  ~CudnnConvTransposeGradientOp() {}
115 
116  bool RunOnDevice() override;
117 
118  private:
119  cudnnConvolutionFwdAlgo_t algo_;
120  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
121  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_algo_cache_;
122  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> filter_algo_cache_;
123  const bool no_bias_;
124  // input: X, W, dY
125  // output: dW, optionally db and dX
126  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
127  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
128 };
129 
130 ////////////////////////////////////////////////////////////////////////////////
131 // Implementations
132 ////////////////////////////////////////////////////////////////////////////////
133 
134 template <typename T>
135 bool CudnnConvTransposeOp<T>::RunOnDevice() {
136  auto& X = Input(INPUT);
137  auto& filter = Input(FILTER);
138  auto* Y = Output(0);
139  int C = 0;
140  switch (order_) {
141  case StorageOrder::NHWC:
142  C = filter.dim32(3);
143  break;
144  case StorageOrder::NCHW:
145  C = filter.dim32(1);
146  break;
147  default:
148  LOG(FATAL) << "Unknown storage order: " << order_;
149  }
150  ConvTransposeUnpoolBase<CUDAContext>::SetOutputSize(X, Y, C);
151 
152  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
153  switch (order_) {
154  case StorageOrder::NHWC:
155  N = X.dim32(0);
156  H = X.dim32(1);
157  W = X.dim32(2);
158  M = X.dim32(3);
159  H_out = Y->dim32(1);
160  W_out = Y->dim32(2);
161  CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
162  CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
163  CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
164  CAFFE_ENFORCE_EQ(filter.dim32(3), C);
165  break;
166  case StorageOrder::NCHW:
167  N = X.dim32(0);
168  M = X.dim32(1);
169  H = X.dim32(2);
170  W = X.dim32(3);
171  H_out = Y->dim32(2);
172  W_out = Y->dim32(3);
173  CAFFE_ENFORCE_EQ(filter.dim32(1), C);
174  CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
175  CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
176  break;
177  default:
178  LOG(FATAL) << "Unknown storage order: " << order_;
179  }
180 
181  if (InputSize() == 3) {
182  auto& bias = Input(BIAS);
183  CAFFE_ENFORCE_EQ(bias.ndim(), 1);
184  CAFFE_ENFORCE_EQ(bias.dim32(0), C);
185  }
186 
187  // Set up the cudnn algorithms & workspace if necessary
188  bool input_changed = (X.dims() != cudnn_input_dims_);
189  bool filter_changed = (filter.dims() != cudnn_filter_dims_);
190 
191  if (input_changed || filter_changed) {
192  VLOG(1) << "Changing the cudnn descriptor configurations.";
193  if (input_changed) {
194  cudnn_input_dims_ = X.dims();
195  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
196  bottom_desc_,
197  GetCudnnTensorFormat(order_),
198  cudnnTypeWrapper<T>::type,
199  N,
200  M,
201  H,
202  W));
203  }
204  if (filter_changed) {
205  cudnn_filter_dims_ = filter.dims();
206  CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
207  filter_desc_,
208  cudnnTypeWrapper<T>::type,
209  GetCudnnTensorFormat(order_),
210  M,
211  C,
212  kernel_h(),
213  kernel_w()));
214  if (InputSize() == 3) {
215  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
216  bias_desc_,
217  GetCudnnTensorFormat(order_),
218  cudnnTypeWrapper<T>::type,
219  1,
220  C,
221  1,
222  1));
223  }
224  }
225  // Set the output
226  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
227  top_desc_,
228  GetCudnnTensorFormat(order_),
229  cudnnTypeWrapper<T>::type,
230  N,
231  C,
232  H_out,
233  W_out));
234  // cudnn only supports symmetric padding; verify it before setting the descriptor
235  CAFFE_ENFORCE_EQ(
236  pad_t(),
237  pad_b(),
238  "The current padding scheme leads to unequal padding on the top and "
239  "bottom, which is not supported by cudnn.");
240  CAFFE_ENFORCE_EQ(
241  pad_l(),
242  pad_r(),
243  "The current padding scheme leads to unequal padding on the left "
244  "and right, which is not supported by cudnn.");
245  // Set the convolution descriptor
246 #if CUDNN_VERSION_MIN(6,0,0)
247  CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
248  conv_desc_,
249  pad_t(),
250  pad_l(),
251  stride_h(),
252  stride_w(),
253  1,
254  1,
255  CUDNN_CROSS_CORRELATION,
256  cudnnTypeWrapper<T>::type));
257 #else
258  CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
259  conv_desc_,
260  pad_t(),
261  pad_l(),
262  stride_h(),
263  stride_w(),
264  1,
265  1,
266  CUDNN_CROSS_CORRELATION));
267 #endif
268 #if CUDNN_VERSION_MIN(7, 0, 0)
269  // enable TensorCore math if desired
270  enable_tensor_core_ &= TensorCoreAvailable();
271  if (enable_tensor_core_) {
272  CUDNN_ENFORCE(
273  cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
274  }
275 #endif
276  if (force_algo_[ALGO_DGRAD] >= 0) {
277  bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
278  } else if (deterministic_) {
279  bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
280  } else if (exhaustive_search_) {
281  bwd_data_algo_ =
282  data_algo_cache_.getAlgorithm(X.dims(), filter.dims(), 0, [&]() {
283  int returned_algo_count;
284  std::array<
285  cudnnConvolutionBwdDataAlgoPerf_t,
286  kNUM_CUDNN_BWD_DATA_ALGS>
287  data_perf_stat;
288  cudnn_wrapper_.with_cudnn_state(
289  cudnn_state_, [&](CuDNNState* state) {
290  state->workspace().reset();
291  CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithm(
292  state->cudnn_handle(),
293  filter_desc_,
294  bottom_desc_,
295  conv_desc_,
296  top_desc_,
297  kNUM_CUDNN_BWD_DATA_ALGS,
298  &returned_algo_count,
299  data_perf_stat.data()));
300  });
301 
302  LogCuDNNPerfStats(data_perf_stat, returned_algo_count);
303  return data_perf_stat[0].algo;
304  });
305  } else {
306  CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm(
307  cudnn_wrapper_.inline_cudnn_handle(),
308  filter_desc_,
309  bottom_desc_,
310  conv_desc_,
311  top_desc_,
312  CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
313  cudnn_ws_nbytes_limit_,
314  &bwd_data_algo_));
315  }
316 
317  size_t bwd_data_ws_size;
318  CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataWorkspaceSize(
319  cudnn_wrapper_.inline_cudnn_handle(),
320  filter_desc_,
321  bottom_desc_,
322  conv_desc_,
323  top_desc_,
324  bwd_data_algo_,
325  &bwd_data_ws_size));
326  cudnn_ws_nbytes_ = bwd_data_ws_size;
327  VLOG(1) << "CuDNN algorithm: " << bwd_data_algo_;
328  VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
329  }
330 
331  // Now, actually run the computation.
332  // Filter
333  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
334  CUDNN_ENFORCE(cudnnConvolutionBackwardData(
335  state->cudnn_handle(),
336  cudnnTypeWrapper<T>::kOne(),
337  filter_desc_,
338  filter.template data<T>(),
339  bottom_desc_,
340  X.template data<T>(),
341  conv_desc_,
342  bwd_data_algo_,
343  state->workspace().get(cudnn_ws_nbytes_),
344  cudnn_ws_nbytes_,
345  cudnnTypeWrapper<T>::kZero(),
346  top_desc_,
347  Y->template mutable_data<T>()));
348  });
349  // Bias
350  if (InputSize() == 3) {
351  CUDNN_ENFORCE(cudnnAddTensor(
352  cudnn_wrapper_.inline_cudnn_handle(),
353  cudnnTypeWrapper<T>::kOne(),
354  bias_desc_,
355  Input(BIAS).template data<T>(),
356  cudnnTypeWrapper<T>::kOne(),
357  top_desc_,
358  Y->template mutable_data<T>()));
359  }
360  // Done.
361  return true;
362 }
363 
364 // TODO(Yangqing): a lot of the function contents are very similar. Consider
365 // consolidating them.
366 template <typename T>
367 bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
368  auto& X = Input(INPUT);
369  auto& filter = Input(FILTER);
370  auto& dY = Input(OUTPUT_GRAD);
371  auto* dfilter = Output(FILTER_GRAD);
372  CAFFE_ENFORCE_EQ(X.ndim(), 4);
373  CAFFE_ENFORCE_EQ(filter.ndim(), 4);
374  int C = 0;
375  switch (order_) {
376  case StorageOrder::NHWC:
377  C = filter.dim32(3);
378  break;
379  case StorageOrder::NCHW:
380  C = filter.dim32(1);
381  break;
382  default:
383  LOG(FATAL) << "Unknown storage order: " << order_;
384  }
385 
386  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
387  switch (order_) {
388  case StorageOrder::NHWC:
389  N = X.dim32(0);
390  H = X.dim32(1);
391  W = X.dim32(2);
392  M = X.dim32(3);
393  H_out = dY.dim32(1);
394  W_out = dY.dim32(2);
395  CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
396  CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
397  CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
398  CAFFE_ENFORCE_EQ(filter.dim32(3), C);
399  break;
400  case StorageOrder::NCHW:
401  N = X.dim32(0);
402  M = X.dim32(1);
403  H = X.dim32(2);
404  W = X.dim32(3);
405  H_out = dY.dim32(2);
406  W_out = dY.dim32(3);
407  CAFFE_ENFORCE_EQ(filter.dim32(1), C);
408  CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
409  CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
410  break;
411  default:
412  LOG(FATAL) << "Unknown storage order: " << order_;
413  }
414  // Since we only handle LegacyPadding::NOTSET, we don't need to
415  // compute padding.
416  dfilter->ResizeLike(filter);
417 
418  // Set up the cudnn algorithms & workspace if necessary
419  bool input_changed = (X.dims() != cudnn_input_dims_);
420  bool filter_changed = (filter.dims() != cudnn_filter_dims_);
421  if (input_changed || filter_changed) {
422  VLOG(1) << "Changing the cudnn descriptor configurations.";
423  if (input_changed) {
424  cudnn_input_dims_ = X.dims();
425  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
426  bottom_desc_,
427  GetCudnnTensorFormat(order_),
428  cudnnTypeWrapper<T>::type,
429  N,
430  M,
431  H,
432  W));
433  }
434  if (filter_changed) {
435  cudnn_filter_dims_ = filter.dims();
436  CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
437  filter_desc_,
438  cudnnTypeWrapper<T>::type,
439  GetCudnnTensorFormat(order_),
440  M,
441  C,
442  kernel_h(),
443  kernel_w()));
444  if (!no_bias_) {
445  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
446  bias_desc_,
447  GetCudnnTensorFormat(order_),
448  cudnnTypeWrapper<T>::type,
449  1,
450  C,
451  1,
452  1));
453  }
454  }
455  // Set the output
456  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
457  top_desc_,
458  GetCudnnTensorFormat(order_),
459  cudnnTypeWrapper<T>::type,
460  N,
461  C,
462  H_out,
463  W_out));
464  // Set the convolution descriptor
465  CAFFE_ENFORCE_EQ(
466  pad_t(),
467  pad_b(),
468  "The current padding scheme leads to unequal padding on the top and "
469  "bottom, which is not supported by cudnn.");
470  CAFFE_ENFORCE_EQ(
471  pad_l(),
472  pad_r(),
473  "The current padding scheme leads to unequal padding on the left "
474  "and right, which is not supported by cudnn.");
475 #if CUDNN_VERSION_MIN(6,0,0)
476  CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
477  conv_desc_,
478  pad_t(),
479  pad_l(),
480  stride_h(),
481  stride_w(),
482  1,
483  1,
484  CUDNN_CROSS_CORRELATION,
485  cudnnTypeWrapper<T>::type));
486 #else
487  CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
488  conv_desc_,
489  pad_t(),
490  pad_l(),
491  stride_h(),
492  stride_w(),
493  1,
494  1,
495  CUDNN_CROSS_CORRELATION));
496 #endif
497 #if CUDNN_VERSION_MIN(7, 0, 0)
498  // enable TensorCore math if desired
499  enable_tensor_core_ &= TensorCoreAvailable();
500  if (enable_tensor_core_) {
501  CUDNN_ENFORCE(
502  cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
503  }
504 #endif
505  if (force_algo_[ALGO_WGRAD] >= 0) {
506  bwd_filter_algo_ =
507  (cudnnConvolutionBwdFilterAlgo_t)force_algo_[ALGO_WGRAD];
508  } else if (deterministic_) {
509  algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
510  bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
511  } else if (exhaustive_search_) {
512  bwd_filter_algo_ =
513  filter_algo_cache_.getAlgorithm(X.dims(), filter.dims(), 0, [&]() {
514  LOG(INFO) << "CUDNN Convolution bwd: doing exhaustive search.";
515  // When we do an exhaustive search, we will ignore the workspace
516  // size
517  // limit and simply go for the fastest algorithm. If you happen to
518  // run
519  // out of memory later, you will be on your own...
520  int returned_algo_count;
521  // We clean up the current workspace memory so that the forward
522  // algorithm
523  // is free to allocate memory.
524  // Actually run the search.
525  std::array<
526  cudnnConvolutionBwdFilterAlgoPerf_t,
527  kNUM_CUDNN_BWD_FILTER_ALGS>
528  filter_perf_stat;
529 
530  cudnn_wrapper_.with_cudnn_state(
531  cudnn_state_, [&](CuDNNState* state) {
532  state->workspace().reset();
533  CUDNN_ENFORCE(cudnnFindConvolutionBackwardFilterAlgorithm(
534  state->cudnn_handle(),
535  top_desc_,
536  bottom_desc_,
537  conv_desc_,
538  filter_desc_,
539  kNUM_CUDNN_BWD_FILTER_ALGS,
540  &returned_algo_count,
541  filter_perf_stat.data()));
542  });
543  LogCuDNNPerfStats(filter_perf_stat, returned_algo_count);
544  return filter_perf_stat[0].algo;
545  });
546 
547  algo_ =
548  forward_algo_cache_.getAlgorithm(X.dims(), filter.dims(), 0, [&]() {
549  int returned_algo_count;
550  std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
551  fwd_perf_stat;
552  cudnn_wrapper_.with_cudnn_state(
553  cudnn_state_, [&](CuDNNState* state) {
554  state->workspace().reset();
555  CUDNN_ENFORCE(cudnnFindConvolutionForwardAlgorithm(
556  state->cudnn_handle(),
557  top_desc_,
558  filter_desc_,
559  conv_desc_,
560  bottom_desc_,
561  kNUM_CUDNN_FWD_ALGS, // request no more results than fwd_perf_stat can hold
562  &returned_algo_count,
563  fwd_perf_stat.data()));
564  });
565 
566  LogCuDNNPerfStats(fwd_perf_stat, returned_algo_count);
567  return fwd_perf_stat[0].algo;
568  });
569  } else {
570  // choose backward algorithm for filter
571  CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterAlgorithm(
572  cudnn_wrapper_.inline_cudnn_handle(),
573  top_desc_,
574  bottom_desc_,
575  conv_desc_,
576  filter_desc_,
577  CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
578  cudnn_ws_nbytes_limit_,
579  &bwd_filter_algo_));
580  // choose forward algo (it computes the data gradient of the transposed conv)
581  CUDNN_ENFORCE(cudnnGetConvolutionForwardAlgorithm(
582  cudnn_wrapper_.inline_cudnn_handle(),
583  top_desc_,
584  filter_desc_,
585  conv_desc_,
586  bottom_desc_,
587  CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
588  cudnn_ws_nbytes_limit_,
589  &algo_));
590  }
591  // get workspace for backwards filter algorithm
592  size_t bwd_filter_ws_size, fwd_ws_size;
593  CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterWorkspaceSize(
594  cudnn_wrapper_.inline_cudnn_handle(),
595  top_desc_,
596  bottom_desc_,
597  conv_desc_,
598  filter_desc_,
599  bwd_filter_algo_,
600  &bwd_filter_ws_size));
601  // get workspace for the forward algorithm used for the data gradient
602  CUDNN_ENFORCE(cudnnGetConvolutionForwardWorkspaceSize(
603  cudnn_wrapper_.inline_cudnn_handle(),
604  top_desc_,
605  filter_desc_,
606  conv_desc_,
607  bottom_desc_,
608  algo_,
609  &fwd_ws_size));
610  cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, fwd_ws_size);
611 
612  VLOG(1) << "CuDNN bwd algorithm: " << bwd_filter_algo_ << ", " << algo_;
613  VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
614  }
615 
616  // Now, actually run the computation.
617  if (!no_bias_) {
618  auto* dbias = Output(BIAS_OR_INPUT_GRAD);
619  dbias->Resize(C);
620  CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
621  cudnn_wrapper_.inline_cudnn_handle(),
622  cudnnTypeWrapper<T>::kOne(),
623  top_desc_,
624  dY.template data<T>(),
625  cudnnTypeWrapper<T>::kZero(),
626  bias_desc_,
627  dbias->template mutable_data<T>()));
628  }
629 
630  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
631  CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
632  state->cudnn_handle(),
633  cudnnTypeWrapper<T>::kOne(),
634  top_desc_,
635  dY.template data<T>(),
636  bottom_desc_,
637  X.template data<T>(),
638  conv_desc_,
639  bwd_filter_algo_,
640  state->workspace().get(cudnn_ws_nbytes_),
641  cudnn_ws_nbytes_,
642  cudnnTypeWrapper<T>::kZero(),
643  filter_desc_,
644  dfilter->template mutable_data<T>()));
645 
646  if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
647  // Compute the gradient w.r.t. the input.
648  auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
649  dX->ResizeLike(X);
650  CUDNN_ENFORCE(cudnnConvolutionForward(
651  state->cudnn_handle(),
652  cudnnTypeWrapper<T>::kOne(),
653  top_desc_,
654  dY.template data<T>(),
655  filter_desc_,
656  filter.template data<T>(),
657  conv_desc_,
658  algo_,
659  state->workspace().get(cudnn_ws_nbytes_),
660  cudnn_ws_nbytes_,
661  cudnnTypeWrapper<T>::kZero(),
662  bottom_desc_,
663  dX->template mutable_data<T>()));
664  }
665  });
666  return true;
667 }
668 
669 REGISTER_CUDNN_OPERATOR(ConvTranspose, CudnnConvTransposeOp<float>);
670 REGISTER_CUDNN_OPERATOR(
671  ConvTransposeGradient,
672  CudnnConvTransposeGradientOp<float>);
673 
674 } // namespace caffe2
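For context, a minimal sketch of how a net selects this implementation: because the operators above are registered with REGISTER_CUDNN_OPERATOR, an OperatorDef whose engine is "CUDNN" and whose device option is CUDA dispatches to CudnnConvTransposeOp<float>. The blob names and hyper-parameter values below are placeholders, not part of this file.

#include "caffe2/core/operator.h"

// Sketch only: build an OperatorDef that resolves to the cuDNN ConvTranspose
// operator defined above. Blob names ("X", "filter", "bias", "Y") and the
// kernel/stride values are illustrative.
caffe2::OperatorDef MakeCudnnConvTransposeDef() {
  caffe2::OperatorDef def;
  def.set_type("ConvTranspose");
  def.set_engine("CUDNN");  // picks the REGISTER_CUDNN_OPERATOR entry above
  def.add_input("X");       // input feature map
  def.add_input("filter");  // (M, C, kernel_h, kernel_w) for NCHW order
  def.add_input("bias");    // optional; shape (C)
  def.add_output("Y");
  def.mutable_device_option()->set_device_type(caffe2::CUDA);
  auto* kernel = def.add_arg();
  kernel->set_name("kernel");
  kernel->set_i(2);
  auto* stride = def.add_arg();
  stride->set_name("stride");
  stride->set_i(2);
  return def;
}

Passing such a def to caffe2::CreateOperator together with a Workspace that already holds the input blobs instantiates the operator; RunOnDevice then performs the computation shown in the listing.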