Caffe2 - C++ API
A deep-learning, cross-platform ML framework
softmax_op.h
1 #ifndef CAFFE2_OPERATORS_SOFTMAX_OP_H_
2 #define CAFFE2_OPERATORS_SOFTMAX_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class SoftmaxOp final : public Operator<Context> {
13  public:
14  SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18  bool RunOnDevice() override;
19 
20  protected:
21  int axis_;
22  Tensor<Context> scale_;
23  Tensor<Context> rowmax_;
24  Tensor<Context> sum_multiplier_;
25 };
26 
27 template <typename T, class Context>
28 class SoftmaxGradientOp final : public Operator<Context> {
29  public:
30  SoftmaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
33  USE_OPERATOR_CONTEXT_FUNCTIONS;
34  bool RunOnDevice() override;
35 
36  protected:
37  int axis_;
38  Tensor<Context> scale_;
39  Tensor<Context> sum_multiplier_;
40 };
41 
42 } // namespace caffe2
43 
44 #endif // CAFFE2_OPERATORS_SOFTMAX_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...