Caffe2 - C++ API
A deep learning, cross platform ML framework
softmax_op_cudnn.cc
1 #include "caffe2/core/context_gpu.h"
2 #include "caffe2/core/cudnn_wrappers.h"
3 #include "caffe2/core/types.h"
4 #include "caffe2/operators/softmax_op.h"
5 
6 namespace caffe2 {
7 
namespace {
// File-local descriptor bookkeeping constants.
// NOTE(review): none of these is referenced by the operators visible in this
// file section; the names suggest slots for input ("bottom"), output ("top")
// and output-gradient descriptors -- confirm before relying on or removing.

// Descriptor slot indices.
constexpr int BOTTOM_DESC_ID = 0;
constexpr int TOP_DESC_ID = 1;
constexpr int TOP_GRADIENT_DESC_ID = 2;

// Descriptor counts.
constexpr int NUM_DESCRIPTORS = 2;
constexpr int GRADIENT_NUM_DESCRIPTORS = 3;
} // namespace
15 
16 class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
17  public:
18  explicit CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
19  : Operator<CUDAContext>(def, ws),
20  cudnn_wrapper_(&context_),
21  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
22  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
23  }
24 
25  ~CuDNNSoftmaxOp() {
26  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
27  }
28 
29  template <typename T>
30  bool DoRunWithType() {
31  auto& X = Input(0);
32  auto* Y = Output(0);
33  const auto canonical_axis = X.canonical_axis_index(axis_);
34  const int N = X.size_to_dim(canonical_axis);
35  const int D = X.size_from_dim(canonical_axis);
36 
37  Y->ResizeLike(X);
38  if (dims_ != X.dims()) {
39  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
40  desc_,
41  GetCudnnTensorFormat(StorageOrder::NCHW),
43  N,
44  D,
45  1,
46  1));
47  dims_ = X.dims();
48  }
49  CUDNN_ENFORCE(cudnnSoftmaxForward(
50  cudnn_wrapper_.inline_cudnn_handle(),
51  CUDNN_SOFTMAX_ACCURATE,
52  CUDNN_SOFTMAX_MODE_INSTANCE,
54  desc_,
55  X.template data<T>(),
57  desc_,
58  Y->template mutable_data<T>()));
59  return true;
60  }
61 
62  bool RunOnDevice() override {
63  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
64  }
65 
66  protected:
67  CuDNNWrapper cudnn_wrapper_;
68  int axis_;
69  cudnnTensorDescriptor_t desc_;
70  vector<TIndex> dims_;
71 };
72 
73 
74 class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
75  public:
76  explicit CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
77  : Operator<CUDAContext>(def, ws),
78  cudnn_wrapper_(&context_),
79  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
80  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
81  }
82 
84  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
85  }
86 
87  template <typename T>
88  bool DoRunWithType() {
89  auto& Y = Input(0);
90  auto& dY = Input(1);
91  auto* dX = Output(0);
92  const auto canonical_axis = Y.canonical_axis_index(axis_);
93  const int N = Y.size_to_dim(canonical_axis);
94  const int D = Y.size_from_dim(canonical_axis);
95 
96  CHECK_EQ(Y.dims(), dY.dims());
97  dX->ResizeLike(Y);
98  if (dims_ != Y.dims()) {
99  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
100  desc_,
101  GetCudnnTensorFormat(StorageOrder::NCHW),
103  N,
104  D,
105  1,
106  1));
107  dims_ = Y.dims();
108  }
109  CUDNN_ENFORCE(cudnnSoftmaxBackward(
110  cudnn_wrapper_.inline_cudnn_handle(),
111  CUDNN_SOFTMAX_ACCURATE,
112  CUDNN_SOFTMAX_MODE_INSTANCE,
114  desc_,
115  Y.template data<T>(),
116  desc_,
117  dY.template data<T>(),
119  desc_,
120  dX->template mutable_data<T>()));
121  return true;
122  }
123 
124  bool RunOnDevice() override {
125  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
126  }
127 
128  protected:
129  CuDNNWrapper cudnn_wrapper_;
130  int axis_;
131  cudnnTensorDescriptor_t desc_;
132  vector<TIndex> dims_;
133 };
134 
135 namespace {
136 REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp);
137 REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp);
138 } // namespace
139 } // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:183
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:111