1 #ifndef CAFFE2_CORE_COMMON_CUDNN_H_ 2 #define CAFFE2_CORE_COMMON_CUDNN_H_ 9 #include "caffe2/core/common.h" 10 #include "caffe2/core/context.h" 11 #include "caffe2/core/logging.h" 12 #include "caffe2/core/types.h" 13 #include "caffe2/proto/caffe2.pb.h" 16 CUDNN_VERSION >= 5000,
17 "Caffe2 requires cudnn version 5.0 or above.");
// This line carries three things: a build-time notice (three #pragma message
// lines) emitted when CUDNN_VERSION < 6000; the CUDNN_VERSION_MIN(major, minor,
// patch) macro, which tests CUDNN_VERSION >= major * 1000 + minor * 100 + patch;
// and the opening of cudnnGetErrorString(), which takes a cudnnStatus_t and
// returns a C string.
// NOTE(review): this listing is fragmentary -- the statement that encloses the
// case labels below is not visible here.
19 #if CUDNN_VERSION < 6000 20 #pragma message "CUDNN version under 6.0 is supported at best effort." 21 #pragma message "We strongly encourage you to move to 6.0 and above." 22 #pragma message "This message is intended to annoy you enough to update." 23 #endif // CUDNN_VERSION < 6000 25 #define CUDNN_VERSION_MIN(major, minor, patch) \ 26 (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) 34 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
// Each listed status value is returned as the text of its own enumerator name.
36 case CUDNN_STATUS_SUCCESS:
37 return "CUDNN_STATUS_SUCCESS";
38 case CUDNN_STATUS_NOT_INITIALIZED:
39 return "CUDNN_STATUS_NOT_INITIALIZED";
40 case CUDNN_STATUS_ALLOC_FAILED:
41 return "CUDNN_STATUS_ALLOC_FAILED";
42 case CUDNN_STATUS_BAD_PARAM:
43 return "CUDNN_STATUS_BAD_PARAM";
44 case CUDNN_STATUS_INTERNAL_ERROR:
45 return "CUDNN_STATUS_INTERNAL_ERROR";
46 case CUDNN_STATUS_INVALID_VALUE:
47 return "CUDNN_STATUS_INVALID_VALUE";
48 case CUDNN_STATUS_ARCH_MISMATCH:
49 return "CUDNN_STATUS_ARCH_MISMATCH";
50 case CUDNN_STATUS_MAPPING_ERROR:
51 return "CUDNN_STATUS_MAPPING_ERROR";
52 case CUDNN_STATUS_EXECUTION_FAILED:
53 return "CUDNN_STATUS_EXECUTION_FAILED";
54 case CUDNN_STATUS_NOT_SUPPORTED:
55 return "CUDNN_STATUS_NOT_SUPPORTED";
56 case CUDNN_STATUS_LICENSE_ERROR:
57 return "CUDNN_STATUS_LICENSE_ERROR";
// Fallback text for a status not matched above.
// NOTE(review): the label or brace that precedes this return is not visible in
// this listing -- confirm against the real header.
59 return "Unknown cudnn error number";
// CUDNN_ENFORCE(condition) and CUDNN_CHECK(condition): both store the value of
// `condition` in a local cudnnStatus_t named status and use
// ::caffe2::internal::cudnnGetErrorString(status) for the failure text.
// CUDNN_CHECK reports through CHECK(status == CUDNN_STATUS_SUCCESS).
// NOTE(review): the enforcing call used by CUDNN_ENFORCE is not visible in this
// listing (only its CUDNN_STATUS_SUCCESS and error-string arguments are), and
// neither is the wrapper that scopes the local `status` -- confirm both.
// The line ends with the opening of cudnnCompiledVersion(); its body is not
// visible here.
66 #define CUDNN_ENFORCE(condition) \ 68 cudnnStatus_t status = condition; \ 71 CUDNN_STATUS_SUCCESS, \ 77 ::caffe2::internal::cudnnGetErrorString(status)); \ 79 #define CUDNN_CHECK(condition) \ 81 cudnnStatus_t status = condition; \ 82 CHECK(status == CUDNN_STATUS_SUCCESS) \ 83 << ::caffe2::internal::cudnnGetErrorString(status); \ 87 inline size_t cudnnCompiledVersion() {
// Returns the value of cudnnGetVersion() as a size_t.
// NOTE(review): presumably this is the version of the cuDNN library in use at
// run time, as the function name suggests -- confirm against the cuDNN API
// reference. The closing brace is not visible in this listing.
91 inline size_t cudnnRuntimeVersion() {
92 return cudnnGetVersion();
// Enforces that cudnnCompiledVersion() and cudnnRuntimeVersion() are equal;
// otherwise CAFFE_ENFORCE fails with a message that reports both values.
96 inline void CheckCuDNNVersions() {
// Exact equality: any difference between the two numbers counts as a mismatch.
99 bool version_match = cudnnCompiledVersion() == cudnnRuntimeVersion();
// The two adjacent string literals below concatenate into one message segment.
100 CAFFE_ENFORCE(version_match,
101 "cuDNN compiled (", cudnnCompiledVersion(),
") and " 102 "runtime (", cudnnRuntimeVersion(),
") versions mismatch");
// A template header over T, followed by the members of the CUDNN_DATA_FLOAT
// variant.
// NOTE(review): the class headers are not visible in this listing; presumably
// these are the primary cudnnTypeWrapper template and its float
// specialization -- confirm.
110 template <
typename T>
// cuDNN data type tag for this variant.
116 static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
// ScalingParamType is already const-qualified; BNParamType is plain float.
117 typedef const float ScalingParamType;
118 typedef float BNParamType;
// kOne() and kZero() each define a function-local static constant (1.0 and
// 0.0). Their return statements are not visible in this listing.
119 static ScalingParamType* kOne() {
120 static ScalingParamType v = 1.0;
// NOTE(review): the leading const here is redundant because ScalingParamType
// is already const float; kOne() above is declared without it.
123 static const ScalingParamType* kZero() {
124 static ScalingParamType v = 0.0;
// Members of the CUDNN_DATA_INT32 variant, compiled only when
// CUDNN_VERSION_MIN(6, 0, 0) holds.
// NOTE(review): the class header is not visible in this listing; presumably an
// integer specialization of cudnnTypeWrapper -- confirm.
129 #if CUDNN_VERSION_MIN(6, 0, 0) 133 static const cudnnDataType_t type = CUDNN_DATA_INT32;
// Scaling parameters are const int; BNParamType is plain int.
134 typedef const int ScalingParamType;
135 typedef int BNParamType;
// Function-local static constants 1 and 0; the return statements are not
// visible in this listing.
136 static ScalingParamType* kOne() {
137 static ScalingParamType v = 1;
// NOTE(review): ScalingParamType is already const int, so the extra leading
// const on kZero() is redundant; kOne() above omits it.
140 static const ScalingParamType* kZero() {
141 static ScalingParamType v = 0;
// Closes the CUDNN_VERSION_MIN(6, 0, 0) guard, then begins the members of the
// CUDNN_DATA_DOUBLE variant.
// NOTE(review): the class header is not visible in this listing; presumably a
// double specialization of cudnnTypeWrapper -- confirm.
145 #endif // CUDNN_VERSION_MIN(6, 0, 0) 150 static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
// Scaling parameters are const double; BNParamType is plain double.
151 typedef const double ScalingParamType;
152 typedef double BNParamType;
// Function-local static constants 1.0 and 0.0; the return statements are not
// visible in this listing.
153 static ScalingParamType* kOne() {
154 static ScalingParamType v = 1.0;
157 static ScalingParamType* kZero() {
158 static ScalingParamType v = 0.0;
// Members of the CUDNN_DATA_HALF variant.
// NOTE(review): the class header is not visible in this listing; presumably a
// half-precision specialization of cudnnTypeWrapper -- confirm.
166 static const cudnnDataType_t type = CUDNN_DATA_HALF;
// Although the tensor data type is half, both the scaling parameter type and
// BNParamType are float-based here.
167 typedef const float ScalingParamType;
168 typedef float BNParamType;
// Function-local static float constants 1.0 and 0.0; the return statements are
// not visible in this listing.
169 static ScalingParamType* kOne() {
170 static ScalingParamType v = 1.0;
173 static ScalingParamType* kZero() {
174 static ScalingParamType v = 0.0;
// Body of the StorageOrder to cudnnTensorFormat_t conversion. The signature
// line is missing at this point; the one shown at the end of this listing names
// it GetCudnnTensorFormat(const StorageOrder &order).
// NHWC and NCHW map to the cuDNN enumerators of the same layout name.
185 case StorageOrder::NHWC:
186 return CUDNN_TENSOR_NHWC;
187 case StorageOrder::NCHW:
188 return CUDNN_TENSOR_NCHW;
// Any other order is reported through LOG(FATAL).
190 LOG(FATAL) <<
"Unknown cudnn equivalent for order: " << order;
// NOTE(review): presumably present only so that every path has a return value
// after LOG(FATAL) -- confirm.
193 return CUDNN_TENSOR_NCHW;
// Fragments of the tensor-descriptor wrapper. The class header is not visible;
// the description at the end of this listing names it cudnnTensorDescWrapper.
// The descriptor handle desc_ is created with cudnnCreateTensorDescriptor and
// released with cudnnDestroyTensorDescriptor.
204 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
// NOTE(review): destruction is checked with CUDNN_CHECK rather than
// CUDNN_ENFORCE -- presumably so that the destructor does not throw; confirm.
207 CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
// Descriptor(format, type, dims, ...): returns a cudnnTensorDescriptor_t for
// the requested format, data type and dims.
// NOTE(review): the final parameter line and parts of the body are missing
// from this listing.
210 inline cudnnTensorDescriptor_t Descriptor(
211 const cudnnTensorFormat_t format,
212 const cudnnDataType_t type,
213 const vector<int>& dims,
// Compares the cached type_, format_ and dims_ with the request; what happens
// on a match is not visible here (presumably the cached desc_ is reused --
// confirm).
215 if (type_ == type && format_ == format && dims_ == dims) {
// dims.size() is compared with 4; the enforcing macro itself is not visible.
222 dims.size(), 4,
"Currently only 4-dimensional descriptor supported.");
226 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
// For NCHW the three visible arguments are dims_[1], dims_[2], dims_[3];
// otherwise they are dims_[3], dims_[1], dims_[2].
// NOTE(review): presumably these are the c, h, w arguments of
// cudnnSetTensor4dDescriptor -- confirm against the cuDNN API reference.
231 (format == CUDNN_TENSOR_NCHW ? dims_[1] : dims_[3]),
232 (format == CUDNN_TENSOR_NCHW ? dims_[2] : dims_[1]),
233 (format == CUDNN_TENSOR_NCHW ? dims_[3] : dims_[2])));
// Overload templated on T that takes a StorageOrder and dims; its body is not
// visible in this listing.
239 template <
typename T>
240 inline cudnnTensorDescriptor_t Descriptor(
241 const StorageOrder& order,
242 const vector<int>& dims) {
// Cached state: the descriptor handle plus the format and data type it was
// last configured with (dims_ is used above but its declaration is not
// visible here).
248 cudnnTensorDescriptor_t desc_;
249 cudnnTensorFormat_t format_;
250 cudnnDataType_t type_;
// Fragments of the filter-descriptor wrapper (class header and name are not
// visible in this listing).
// The descriptor handle desc_ is created with cudnnCreateFilterDescriptor and
// released with cudnnDestroyFilterDescriptor.
258 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&desc_));
// NOTE(review): destruction is checked with CUDNN_CHECK rather than
// CUDNN_ENFORCE -- presumably so that the destructor does not throw; confirm.
261 CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc_));
// Descriptor(order, type, dims, ...): returns a cudnnFilterDescriptor_t for
// the requested storage order, data type and dims.
// NOTE(review): the final parameter line and parts of the body are missing
// from this listing.
264 inline cudnnFilterDescriptor_t Descriptor(
265 const StorageOrder& order,
266 const cudnnDataType_t type,
267 const vector<int>& dims,
// Compares the cached type_, order_ and dims_ with the request; what happens
// on a match is not visible here (presumably the cached desc_ is reused --
// confirm).
269 if (type_ == type && order_ == order && dims_ == dims) {
// dims.size() is compared with 4; the enforcing macro itself is not visible.
276 dims.size(), 4,
"Currently only 4-dimensional descriptor supported.");
280 CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
// For NCHW the three visible arguments are dims_[1], dims_[2], dims_[3];
// otherwise they are dims_[3], dims_[1], dims_[2].
// NOTE(review): presumably these are the c, h, w arguments of
// cudnnSetFilter4dDescriptor -- confirm against the cuDNN API reference.
286 (order == StorageOrder::NCHW ? dims_[1] : dims_[3]),
287 (order == StorageOrder::NCHW ? dims_[2] : dims_[1]),
288 (order == StorageOrder::NCHW ? dims_[3] : dims_[2])));
// Overload templated on T that takes a StorageOrder and dims; its body is not
// visible in this listing.
294 template <
typename T>
295 inline cudnnFilterDescriptor_t Descriptor(
296 const StorageOrder& order,
297 const vector<int>& dims) {
// Cached state: the descriptor handle and the data type it was last
// configured with (order_ and dims_ are used above but their declarations are
// not visible here).
302 cudnnFilterDescriptor_t desc_;
304 cudnnDataType_t type_;
312 #endif // CAFFE2_CORE_COMMON_CUDNN_H_ cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
cudnnTensorDescWrapper is a placeholder that wraps a cudnnTensorDescriptor_t, allowing the descriptor to be changed as needed at runtime.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...