Caffe2 - C++ API
A deep learning, cross platform ML framework
spatial_softmax_with_loss_op.h
1 #ifndef SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
2 #define SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class SpatialSoftmaxWithLossOp final : public Operator<Context> {
13  public:
14  SpatialSoftmaxWithLossOp(const OperatorDef& operator_def, Workspace* ws)
15  : Operator<Context>(operator_def, ws),
16  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
17  order_(StringToStorageOrder(
18  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
19  CAFFE_ENFORCE(scale_ >= 0);
20  CAFFE_ENFORCE_EQ(
21  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
22  }
23  USE_OPERATOR_CONTEXT_FUNCTIONS;
24 
25  bool RunOnDevice() override;
26 
27  protected:
28  float scale_;
29  StorageOrder order_;
30 
31  Tensor<Context> losses_; // Per example loss
32  Tensor<Context> rowmax_; // per example row max
33  Tensor<Context> weights_; // unignored weights
34  Tensor<Context> sum_multiplier_; // Vector of ones for summing via dot prod
35  Tensor<Context> total_weight_ptr_;
36  Tensor<Context> scratch_;
37 };
38 
39 template <typename T, class Context>
40 class SpatialSoftmaxWithLossGradientOp final : public Operator<Context> {
41  public:
42  SpatialSoftmaxWithLossGradientOp(const OperatorDef& def, Workspace* ws)
43  : Operator<Context>(def, ws),
44  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
45  order_(StringToStorageOrder(
46  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
47  only_loss_(OperatorBase::GetSingleArgument<bool>("only_loss", false)) {
48  CAFFE_ENFORCE(scale_ >= 0);
49  CAFFE_ENFORCE_EQ(
50  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
51  }
52  USE_OPERATOR_CONTEXT_FUNCTIONS;
53 
54  bool RunOnDevice() override;
55 
56  protected:
57  float scale_;
58  Tensor<Context> sum_multiplier_;
59  Tensor<Context> weights_; // unignored weights
60  Tensor<Context> total_weight_ptr_;
61  StorageOrder order_;
62  bool only_loss_;
63  Tensor<Context> scratch_;
64 };
65 
66 } // namespace caffe2
67 
68 #endif // SOFTMAX_WITH_LOSS_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...