// cuDNN-backed forward Dropout operator for CUDAContext. The class is
// guarded by CUDNN_VERSION_MIN(7,0,0), so it is only compiled against
// cuDNN 7.0.0 or newer.
// NOTE(review): this copy of the file appears to be a lossy extraction --
// the embedded original line numbers skip values (e.g. 5-9, 20, 45-46) and
// several declarations and closing braces are absent. Verify against the
// upstream file before editing code here.
1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/types.h" 10 #if CUDNN_VERSION_MIN(7,0,0) 12 class CuDNNDropoutOp final :
public Operator<CUDAContext> {
14 USE_OPERATOR_FUNCTIONS(CUDAContext);
// Constructor:
//  - reads the "ratio" argument (default 0.5) and the is-test argument
//    (OpSchema::Arg_IsTest, default 0);
//  - takes the random seed from the operator's device option;
//  - validates 0 <= ratio < 1;
//  - creates the cuDNN tensor and dropout descriptors;
//  - queries the size of the dropout state buffer;
//  - creates the workspace blob that the state buffer is stored under.
16 CuDNNDropoutOp(
const OperatorDef& operator_def, Workspace* ws)
17 : Operator<CUDAContext>(operator_def, ws),
18 cudnn_wrapper_(&context_),
19 ratio_(OperatorBase::GetSingleArgument<float>(
"ratio", 0.5)),
// NOTE(review): original line 20 (presumably the `is_test_(` member
// initializer head) is missing from this copy; the next line is its argument.
21 OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
22 states_initialized_(false),
23 random_seed_(operator_def.device_option().random_seed()) {
24 CAFFE_ENFORCE_GE(ratio_, 0);
25 CAFFE_ENFORCE_LT(ratio_, 1);
26 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
28 CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropout_desc_));
// NOTE(review): states_size_in_bytes_ is already declared as size_t, so the
// reinterpret_cast below is a no-op and could be dropped.
29 CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
30 cudnn_wrapper_.inline_cudnn_handle(),
31 reinterpret_cast<size_t*
>(&states_size_in_bytes_)));
// The blob name is derived from output(1) (the mask output), so that the
// gradient operator can find the same blob via scratch_blob_name(input(1)).
34 scratch_blob_ = ws->CreateBlob(scratch_blob_name(operator_def.output(1)));
35 CAFFE_ENFORCE(scratch_blob_);
// Releases the descriptors created in the constructor.
// NOTE(review): CUDNN_ENFORCE is used inside a noexcept destructor. If it
// throws on a cuDNN error (presumably it does, like CAFFE_ENFORCE), the
// program ends via std::terminate. Confirm this is intended.
39 ~CuDNNDropoutOp() noexcept {
40 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
41 CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropout_desc_));
// NOTE(review): original lines 45-46 are missing here. Upstream presumably
// declares `bool DoRunWithType();` under this template header (it is defined
// out of line as CuDNNDropoutOp::DoRunWithType). As shown, the template
// header would attach to RunOnDevice instead.
44 template <
typename T,
typename M>
47 bool RunOnDevice()
override;
// Name of the workspace blob associated with the given mask blob name.
// Static so that CuDNNDropoutGradientOp can compute the same name.
49 static string scratch_blob_name(
string mask_blob_name) {
50 return "cudnn_dropout_scratch_" + mask_blob_name;
54 CuDNNWrapper cudnn_wrapper_;
55 cudnnTensorDescriptor_t data_desc_;
56 cudnnDropoutDescriptor_t dropout_desc_;
// Input dims that data_desc_ was last configured for. DoRunWithType compares
// against this to skip reconfiguration when the shape is unchanged.
58 vector<TIndex> cudnn_input_dims_;
// Workspace blob created in the constructor. Presumably it backs the
// `states` tensor used in DoRunWithType; that declaration is on a line
// missing from this copy -- confirm.
63 Blob* scratch_blob_ =
nullptr;
// Sizes reported by cudnnDropoutGetStatesSize (constructor) and by
// cudnnDropoutGetReserveSpaceSize (DoRunWithType).
65 size_t states_size_in_bytes_, reserve_space_size_in_bytes_;
// Set once the dropout descriptor has been configured in DoRunWithType.
69 bool states_initialized_;
// Seed taken from the operator's device option in the constructor.
72 unsigned long long random_seed_;
// cuDNN-backed backward operator for Dropout (registered as DropoutGrad).
// Its constructor mirrors CuDNNDropoutOp's. Instead of creating the scratch
// blob, it looks up the one that the forward operator created.
// NOTE(review): original lines 76, 82, 89, 94-96, 99-100, 104-105, 108,
// 110-111, 115 and 117-122 are missing from this copy (including access
// specifiers, closing braces and some member declarations). Verify against
// upstream before editing.
75 class CuDNNDropoutGradientOp final :
public Operator<CUDAContext> {
77 USE_OPERATOR_FUNCTIONS(CUDAContext);
// Reads "ratio" (default 0.5), the is-test argument and the device-option
// seed; validates 0 <= ratio < 1; creates the descriptors; queries the
// state buffer size; and fetches the forward op's scratch blob.
78 CuDNNDropoutGradientOp(
const OperatorDef& operator_def, Workspace* ws)
79 : Operator<CUDAContext>(operator_def, ws),
80 cudnn_wrapper_(&context_),
81 ratio_(OperatorBase::GetSingleArgument<float>(
"ratio", 0.5)),
// NOTE(review): original line 82 (presumably `is_test_(`) is missing here.
83 OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
84 states_initialized_(false),
85 random_seed_(operator_def.device_option().random_seed()) {
86 CAFFE_ENFORCE_GE(ratio_, 0);
87 CAFFE_ENFORCE_LT(ratio_, 1);
88 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
90 CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropout_desc_));
// NOTE(review): the reinterpret_cast is a no-op; states_size_in_bytes_ is
// already size_t.
91 CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
92 cudnn_wrapper_.inline_cudnn_handle(),
93 reinterpret_cast<size_t*
>(&states_size_in_bytes_)));
// Looks up the blob that CuDNNDropoutOp created, keyed by the mask blob
// (input 1 here, output 1 of the forward op).
// NOTE(review): the assignment target (presumably `scratch_blob_ =`, original
// line 96) is missing from this copy. Assuming GetBlob returns nullptr for an
// absent blob, the enforce below fails unless the forward op was constructed
// in this workspace first -- confirm.
97 ws->GetBlob(CuDNNDropoutOp::scratch_blob_name(operator_def.input(1)));
98 CAFFE_ENFORCE(scratch_blob_);
// Releases the descriptors created in the constructor.
// NOTE(review): CUDNN_ENFORCE inside a noexcept destructor -- a throw here
// would call std::terminate.
101 ~CuDNNDropoutGradientOp() noexcept {
102 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
103 CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropout_desc_));
// T is the element type of dY/dX. M is not referenced in the visible lines.
106 template <
typename T,
typename M>
107 bool DoRunWithType();
109 bool RunOnDevice()
override;
112 CuDNNWrapper cudnn_wrapper_;
113 cudnnTensorDescriptor_t data_desc_;
114 cudnnDropoutDescriptor_t dropout_desc_;
// Dims that data_desc_ was last configured for (see DoRunWithType).
116 vector<TIndex> cudnn_input_dims_;
// NOTE(review): the scratch_blob_ member declaration (presumably in original
// lines 117-122) is not visible in this copy.
123 size_t states_size_in_bytes_, reserve_space_size_in_bytes_;
// Set once the dropout descriptor has been restored in DoRunWithType.
127 bool states_initialized_;
// Seed taken from the operator's device option in the constructor.
129 unsigned long long random_seed_;
// Forward Dropout for element type T. M is not referenced in the visible
// lines.
//  - Lines 144-145 copy X into Y.
//  - Outside test mode, the tensor descriptor is reconfigured whenever the
//    input shape changes.
//  - The dropout descriptor is set up once with the state buffer.
//  - cudnnDropoutForward writes Y, plus cuDNN's reserve space into output 1.
// NOTE(review): many original lines (135-137, 139-143, 146-148, 150,
// 153-154, ...) are missing from this copy. This includes the declarations
// of Y and `states` and the test-mode guard.
132 template <
typename T,
typename M>
133 bool CuDNNDropoutOp::DoRunWithType() {
134 const auto& X = Input(0);
// Iterates over the input dims. The loop body (original 139-143) is missing.
138 for (
auto dim : X.dims()) {
// Copies X into Y unchanged. NOTE(review): presumably this is the is_test_
// path; its guard is on a missing line -- confirm.
144 context_.Copy<T, CUDAContext, CUDAContext>(
145 X.size(), X.template data<T>(), Y->template mutable_data<T>());
// Output 1 ("mask") holds cuDNN's dropout reserve space, not a per-element
// boolean mask. It is resized to reserve_space_size_in_bytes_ below and is
// passed to cudnnDropoutForward as raw bytes.
149 auto* mask = Output(1);
// Reconfigure only when training and the input shape differs from the cached
// shape.
151 if (X.dims() != cudnn_input_dims_ && !is_test_) {
152 CAFFE_ENFORCE(scratch_blob_);
155 cudnn_input_dims_ = X.dims();
156 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
159 cudnnTypeWrapper<T>::type,
// The reserve space size is queried from data_desc_, so it is re-queried
// after every reshape.
166 CUDNN_ENFORCE(cudnnDropoutGetReserveSpaceSize(
167 data_desc_, &reserve_space_size_in_bytes_));
169 mask->Resize(reserve_space_size_in_bytes_);
170 states->Resize(states_size_in_bytes_);
// Configure the dropout descriptor with its state buffer only once per
// operator instance.
172 if (!states_initialized_) {
175 uint8_t* states_data = states->mutable_data<uint8_t>();
// CUDAContext::mutex() is held while the dropout descriptor is set.
// NOTE(review): the reason for this lock is not visible here (presumably to
// serialize with other work on the device) -- confirm.
178 std::lock_guard<std::mutex> lk(CUDAContext::mutex());
179 CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
181 cudnn_wrapper_.inline_cudnn_handle(),
184 states_size_in_bytes_,
188 states_initialized_ =
true;
// Applies dropout: reads X, writes Y, and records the reserve space into mask.
191 CUDNN_ENFORCE(cudnnDropoutForward(
192 cudnn_wrapper_.inline_cudnn_handle(),
195 X.template data<T>(),
197 Y->template mutable_data<T>(),
198 mask->mutable_data<uint8_t>(),
199 reserve_space_size_in_bytes_));
// Dispatches on the dtype of input 0:
//  - float runs DoRunWithType<float, float>;
//  - float16 runs DoRunWithType<float16, float>.
// NOTE(review): the handling of other dtypes (original lines 214-217) and
// lines 205 and 207-209 are missing from this copy.
204 bool CuDNNDropoutOp::RunOnDevice() {
206 const auto& X = Input(0);
210 if (X.IsType<
float>()) {
211 return DoRunWithType<float, float>();
212 }
else if (X.IsType<float16>()) {
213 return DoRunWithType<float16, float>();
// Backward Dropout for element type T. M is not referenced in the visible
// lines. Inputs are dY (input 0) and the forward op's reserve space (input
// 1, "mask"); the output is dX.
//  - On first use, the dropout descriptor is restored from the forward op's
//    state buffer.
//  - The tensor descriptor is reconfigured when the shape of dY changes.
//  - cudnnDropoutBackward computes dX.
// NOTE(review): many original lines (222, 224-225, 227-229, 231-233, 236,
// 238, 241-243, ...) are missing from this copy. This includes the
// declaration of `states` (presumably read from scratch_blob_ -- confirm).
218 template <
typename T,
typename M>
219 bool CuDNNDropoutGradientOp::DoRunWithType() {
220 const auto& dY = Input(0);
221 const auto& mask = Input(1);
223 auto* dX = Output(0);
// Iterates over the dims of dY. The loop body is on lines missing here.
226 for (
auto dim : dY.dims()) {
// Restore the dropout descriptor once, under the same CUDAContext::mutex()
// that the forward op holds while setting it.
230 if (!states_initialized_) {
234 std::lock_guard<std::mutex> lk(CUDAContext::mutex());
235 CUDNN_ENFORCE(cudnnRestoreDropoutDescriptor(
237 cudnn_wrapper_.inline_cudnn_handle(),
// cudnnRestoreDropoutDescriptor takes a non-const states pointer, hence the
// const_cast on the read-only state data.
239 const_cast<uint8_t*
>(states.data<uint8_t>()),
240 states_size_in_bytes_,
244 states_initialized_ =
true;
// Reconfigure the tensor descriptor and the reserve space size only when
// the shape of dY changes.
247 if (dY.dims() != cudnn_input_dims_) {
248 cudnn_input_dims_ = dY.dims();
249 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
252 cudnnTypeWrapper<T>::type,
259 CUDNN_ENFORCE(cudnnDropoutGetReserveSpaceSize(
260 data_desc_, &reserve_space_size_in_bytes_));
// cudnnDropoutBackward takes the reserve space as a non-const void*, while
// the mask input is only exposed read-only; hence the const_cast.
265 void* mask_data =
const_cast<void*
>(mask.raw_data());
266 CUDNN_ENFORCE(cudnnDropoutBackward(
267 cudnn_wrapper_.inline_cudnn_handle(),
272 dX->template mutable_data<T>(),
274 reserve_space_size_in_bytes_));
// Dispatches on the dtype of dY (input 0):
//  - float runs DoRunWithType<float, float>;
//  - float16 runs DoRunWithType<float16, float>.
// NOTE(review): dX is fetched here, but its use (original lines 282-284)
// and the handling of other dtypes are on lines missing from this copy.
278 bool CuDNNDropoutGradientOp::RunOnDevice() {
280 const auto& dY = Input(0);
281 auto* dX = Output(0);
285 if (dY.IsType<
float>()) {
286 return DoRunWithType<float, float>();
287 }
else if (dY.IsType<float16>()) {
288 return DoRunWithType<float16, float>();
// Registers the forward and backward operators for the Dropout and
// DropoutGrad op types via REGISTER_CUDNN_OPERATOR (presumably under the
// CUDNN engine).
// NOTE(review): the closing `#endif` for the CUDNN_VERSION_MIN(7,0,0) guard
// opened at the top of the file is not visible in this copy.
294 REGISTER_CUDNN_OPERATOR(Dropout, CuDNNDropoutOp);
295 REGISTER_CUDNN_OPERATOR(DropoutGrad, CuDNNDropoutGradientOp);
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...