1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/types.h" 4 #include "caffe2/operators/softmax_op.h" 9 constexpr
int NUM_DESCRIPTORS = 2;  // NOTE(review): presumably the descriptor count for the forward op; not referenced in the visible code -- confirm before relying on it.
10 constexpr
int GRADIENT_NUM_DESCRIPTORS = 3;  // NOTE(review): presumably the descriptor count for the gradient op; not referenced in the visible code -- confirm.
11 constexpr
int BOTTOM_DESC_ID = 0;  // NOTE(review): looks like an index for the input ("bottom") descriptor; the visible ops use a single desc_ instead -- confirm.
12 constexpr
int TOP_DESC_ID = 1;  // NOTE(review): looks like an index for the output ("top") descriptor; unused in the visible code -- confirm.
13 constexpr
int TOP_GRADIENT_DESC_ID = 2;  // NOTE(review): looks like an index for the output-gradient descriptor; unused in the visible code -- confirm.
20 cudnn_wrapper_(&context_),
21 axis_(OperatorBase::GetSingleArgument<int>(
"axis", 1)) {
22 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
26 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
30 bool DoRunWithType() {
33 const auto canonical_axis = X.canonical_axis_index(axis_);
34 const int N = X.size_to_dim(canonical_axis);
35 const int D = X.size_from_dim(canonical_axis);
38 if (dims_ != X.dims()) {
39 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
49 CUDNN_ENFORCE(cudnnSoftmaxForward(
51 CUDNN_SOFTMAX_ACCURATE,
52 CUDNN_SOFTMAX_MODE_INSTANCE,
58 Y->template mutable_data<T>()));
62 bool RunOnDevice()
override {
69 cudnnTensorDescriptor_t desc_;
78 cudnn_wrapper_(&context_),
79 axis_(OperatorBase::GetSingleArgument<int>(
"axis", 1)) {
80 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
84 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
88 bool DoRunWithType() {
92 const auto canonical_axis = Y.canonical_axis_index(axis_);
93 const int N = Y.size_to_dim(canonical_axis);
94 const int D = Y.size_from_dim(canonical_axis);
96 CHECK_EQ(Y.dims(), dY.dims());
98 if (dims_ != Y.dims()) {
99 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
109 CUDNN_ENFORCE(cudnnSoftmaxBackward(
110 cudnn_wrapper_.inline_cudnn_handle(),
111 CUDNN_SOFTMAX_ACCURATE,
112 CUDNN_SOFTMAX_MODE_INSTANCE,
115 Y.template data<T>(),
117 dY.template data<T>(),
120 dX->template mutable_data<T>()));
124 bool RunOnDevice()
override {
131 cudnnTensorDescriptor_t desc_;
132 vector<TIndex> dims_;
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function that converts a Caffe2 storage order (StorageOrder) to the corresponding cuDNN tensor-format enum value (cudnnTensorFormat_t).
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cuDNN handles and cuDNN workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...