1 #ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_ 2 #define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/proto/hsm.pb.h" 8 #include "caffe2/utils/math.h" 12 template <
typename T,
typename Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 HierarchyProto hierarchy;
19 CAFFE_ENFORCE(hierarchy.ParseFromString(
20 OperatorBase::GetSingleArgument<string>(
"hierarchy",
"")));
21 for (
const auto& path : hierarchy.paths()) {
22 hierarchy_all_map_.emplace(path.word_id(), path);
27 std::unordered_map<int, PathProto> hierarchy_all_map_;
31 static constexpr T kLOG_THRESHOLD() {
34 static std::unordered_map<int, PathProto> getHierarchyForLabels(
37 const std::unordered_map<int, PathProto>& hierarchy_all_map) {
38 std::unordered_map<int, PathProto> hierarchy_map;
39 std::set<int> label_set = std::set<int>(labels, labels + M);
40 for (
const auto& label : label_set) {
41 auto search = hierarchy_all_map.find(label);
42 CAFFE_ENFORCE(search != hierarchy_all_map.end(),
"incorrect label.");
43 hierarchy_map.emplace(search->first, search->second);
47 int getIntermediateOutputSize(
50 std::unordered_map<int, PathProto>& hierarchy)
const {
52 for (
int label = 0; label < M; ++label) {
53 int word_id = labels[label];
54 const auto& path = hierarchy[word_id];
55 size += std::accumulate(
56 path.path_nodes().begin(),
57 path.path_nodes().end(),
60 [](
int sz, PathNodeProto node) {
61 return sz + 2 * node.length();
68 template <
typename T,
class Context>
71 USE_OPERATOR_CONTEXT_FUNCTIONS;
74 bool RunOnDevice()
override;
77 float RunForwardSingle(
83 const float* bias_multiplier,
89 template <
typename T,
class Context>
92 USE_OPERATOR_CONTEXT_FUNCTIONS;
94 bool RunOnDevice()
override;
97 void RunBackwardSingle(
102 const float* int_output,
112 template <
typename T,
class Context>
115 USE_OPERATOR_CONTEXT_FUNCTIONS;
118 top_n_(OperatorBase::GetSingleArgument<int>(
"topN", 5)),
119 beam_(OperatorBase::GetSingleArgument<float>(
"beam", 0.01f)) {
120 CAFFE_ENFORCE(tree_.ParseFromString(
121 OperatorBase::GetSingleArgument<string>(
"tree",
"")));
123 bool RunOnDevice()
override;
135 const NodeProto& src_node,
140 const NodeProto& node,
141 std::vector<std::pair<string, float>>& info);
144 template <
typename T,
class Context>
147 USE_OPERATOR_CONTEXT_FUNCTIONS;
150 num_classes_(OperatorBase::GetSingleArgument<int>(
"num_classes", -1)) {}
151 bool RunOnDevice()
override;
157 : label(l), count(count), left_ch_index(-1), right_ch_index(-1) {}
164 struct NodeComparator {
165 bool operator()(
const Node& node_a,
const Node& node_b) {
166 return node_a.count > node_b.count;
175 #endif // CAFFE2_OPERATORS_SOFTMAX_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about which Caffe2 modules have been loaded in the current ...