3 #ifndef CAFFE2_CORE_CUDNN_WRAPPERS_H_ 4 #define CAFFE2_CORE_CUDNN_WRAPPERS_H_ 6 #include "caffe2/core/common_cudnn.h" 7 #include "caffe2/core/context_gpu.h" 24 void*
get(
size_t nbytes) {
25 if (nbytes_ < nbytes) {
27 auto data_and_deleter = CUDAContext::New(nbytes);
28 data_ = {data_and_deleter.first, data_and_deleter.second};
31 CAFFE_ENFORCE_GE(nbytes_, nbytes);
41 std::unique_ptr<void, MemoryDeleter> data_{
nullptr, NoDelete};
51 explicit CuDNNState(
size_t gpu_id) : gpu_id_(gpu_id) {
53 CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
54 CUDA_ENFORCE(cudaEventCreate(&before_));
55 CUDA_ENFORCE(cudaEventCreate(&after_));
56 CUDA_ENFORCE(cudaStreamCreate(&stream_));
57 CUDNN_ENFORCE(cudnnSetStream(cudnn_handle_, stream_));
62 CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
63 CUDA_CHECK(cudaStreamDestroy(stream_));
64 CUDA_CHECK(cudaEventDestroy(after_));
65 CUDA_CHECK(cudaEventDestroy(before_));
68 cudnnHandle_t& cudnn_handle() {
77 void execute(cudaStream_t stream, F&& f) {
78 CUDA_ENFORCE(cudaEventRecord(before_, stream));
79 CUDA_ENFORCE(cudaStreamWaitEvent(stream_, before_, 0));
81 CUDA_ENFORCE(cudaEventRecord(after_, stream_));
82 CUDA_ENFORCE(cudaStreamWaitEvent(stream, after_, 0));
86 cudnnHandle_t cudnn_handle_{
nullptr};
87 cudaEvent_t before_{
nullptr};
88 cudaEvent_t after_{
nullptr};
89 cudaStream_t stream_{
nullptr};
117 return context_->cudnn_handle();
121 template <
typename F>
122 void with_cudnn_state(
size_t state_idx, F&& f) {
124 state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES,
"Invalid state_idx");
125 auto& sync_state = cudnn_states()[context_->cuda_gpu_id()][state_idx];
133 std::lock_guard<std::mutex> g(sync_state.mutex);
134 if (!sync_state.state.get()) {
135 sync_state.state.reset(
new CuDNNState(context_->cuda_gpu_id()));
137 CHECK_NOTNULL(sync_state.state.get())->execute(context_->cuda_stream(), f);
144 static constexpr
size_t CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES = 4;
148 std::unique_ptr<CuDNNState> state;
151 using PerGPUCuDNNStates = std::array<
152 std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
153 CAFFE2_COMPILE_TIME_MAX_GPUS>;
154 static PerGPUCuDNNStates& cudnn_states();
CuDNNWrapper(CUDAContext *context)
Creates a cuDNN wrapper associated with a CUDAContext object.
CuDNNWorkspace is a wrapper around a raw CUDA pointer that holds the cuDNN scratch space...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cuDNN handles and cuDNN workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cuDNN handle that executes on the current thread's cuda_stream.