Caffe2 - C++ API
A deep learning, cross-platform ML framework
h_softmax_op.h
1 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
2 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/proto/hsm.pb.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, typename Context>
13 class HSoftmaxOpBase : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  HSoftmaxOpBase(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws) {
18  HierarchyProto hierarchy;
19  CAFFE_ENFORCE(hierarchy.ParseFromString(
20  OperatorBase::GetSingleArgument<string>("hierarchy", "")));
21  for (const auto& path : hierarchy.paths()) {
22  hierarchy_all_map_.emplace(path.word_id(), path);
23  }
24  }
25 
26  protected:
27  std::unordered_map<int, PathProto> hierarchy_all_map_;
28  Tensor<Context> scale_;
29  Tensor<Context> sum_multiplier_;
30  Tensor<Context> bias_multiplier_;
31  static constexpr T kLOG_THRESHOLD() {
32  return 1e-20f;
33  }
34  static std::unordered_map<int, PathProto> getHierarchyForLabels(
35  int M,
36  const int* labels,
37  const std::unordered_map<int, PathProto>& hierarchy_all_map) {
38  std::unordered_map<int, PathProto> hierarchy_map;
39  std::set<int> label_set = std::set<int>(labels, labels + M);
40  for (const auto& label : label_set) {
41  auto search = hierarchy_all_map.find(label);
42  CAFFE_ENFORCE(search != hierarchy_all_map.end(), "incorrect label.");
43  hierarchy_map.emplace(search->first, search->second);
44  }
45  return hierarchy_map;
46  }
47  int getIntermediateOutputSize(
48  const int* labels,
49  int M,
50  std::unordered_map<int, PathProto>& hierarchy) const {
51  int size = 0;
52  for (int label = 0; label < M; ++label) {
53  int word_id = labels[label];
54  const auto& path = hierarchy[word_id];
55  size += std::accumulate(
56  path.path_nodes().begin(),
57  path.path_nodes().end(),
58  0,
59  // Output of FC + Output of Softmax
60  [](int sz, PathNodeProto node) {
61  return sz + 2 * node.length();
62  });
63  }
64  return size;
65  }
66 };
67 
68 template <typename T, class Context>
69 class HSoftmaxOp : public HSoftmaxOpBase<T, Context> {
70  public:
71  USE_OPERATOR_CONTEXT_FUNCTIONS;
73 
74  bool RunOnDevice() override;
75 
76  protected:
77  float RunForwardSingle(
78  const float* X,
79  const float* W,
80  const float* b,
81  int target,
82  float* output,
83  const float* bias_multiplier,
84  int w_length,
85  int K,
86  int& output_offset);
87 };
88 
89 template <typename T, class Context>
90 class HSoftmaxGradientOp final : public HSoftmaxOpBase<T, Context> {
91  public:
92  USE_OPERATOR_CONTEXT_FUNCTIONS;
94  bool RunOnDevice() override;
95 
96  private:
97  void RunBackwardSingle(
98  const float* X,
99  const float* dY,
100  const float* W,
101  int target,
102  const float* int_output,
103  float* dX,
104  float* dW,
105  float* db,
106  float* dOutput,
107  int dim_in,
108  int w_length,
109  int& output_offset);
110 };
111 
112 template <typename T, class Context>
113 class HSoftmaxSearchOp final : public HSoftmaxOp<T, Context> {
114  public:
115  USE_OPERATOR_CONTEXT_FUNCTIONS;
116  HSoftmaxSearchOp(const OperatorDef& operator_def, Workspace* ws)
117  : HSoftmaxOp<T, Context>(operator_def, ws),
118  top_n_(OperatorBase::GetSingleArgument<int>("topN", 5)),
119  beam_(OperatorBase::GetSingleArgument<float>("beam", 0.01f)) {
120  CAFFE_ENFORCE(tree_.ParseFromString(
121  OperatorBase::GetSingleArgument<string>("tree", "")));
122  }
123  bool RunOnDevice() override;
124 
125  private:
126  int top_n_;
127  float beam_;
128  TreeProto tree_;
129  bool pruning(
130  const float* X,
131  int sample,
132  int K,
133  const float* W,
134  const float* b,
135  const NodeProto& src_node,
136  NodeProto& dst_node,
137  float parent_score,
138  float beam);
139  bool extractNodes(
140  const NodeProto& node,
141  std::vector<std::pair<string, float>>& info);
142 };
143 
144 template <typename T, class Context>
145 class HuffmanTreeHierarchyOp : public Operator<Context> {
146  public:
147  USE_OPERATOR_CONTEXT_FUNCTIONS;
148  HuffmanTreeHierarchyOp(const OperatorDef& operator_def, Workspace* ws)
149  : Operator<Context>(operator_def, ws),
150  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", -1)) {}
151  bool RunOnDevice() override;
152 
153  private:
154  // Internal huffman tree data.
155  struct Node {
156  Node(T l, int count)
157  : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
158  T label;
159  int count;
160  int left_ch_index;
161  int right_ch_index;
162  };
163 
164  struct NodeComparator {
165  bool operator()(const Node& node_a, const Node& node_b) {
166  return node_a.count > node_b.count;
167  }
168  };
169 
170  int num_classes_;
171 };
172 
173 } // namespace caffe2
174 
175 #endif // CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...