Caffe2 - C++ API
A deep learning, cross platform ML framework
h_softmax_op.cc
1 #include "caffe2/operators/h_softmax_op.h"
2 
3 #include <queue>
4 #include <stack>
5 
6 namespace caffe2 {
7 
8 template <>
9 float HSoftmaxOp<float, CPUContext>::RunForwardSingle(const float* X,
10  const float* W, const float* b, int target, float* int_output,
11  const float* bias_multiplier, int dim_out, int dim_in,
12  int& int_output_offset) {
13 
14  // W * x
15  float* fc_output_data = int_output + int_output_offset;
16 
17  math::Gemm<float, CPUContext>(CblasNoTrans, CblasTrans, 1, dim_out, dim_in, 1,
18  X, W, 0, fc_output_data, &context_);
19  math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, 1,
20  b, bias_multiplier, 1, fc_output_data, &context_);
21 
22  int_output_offset += dim_out;
23 
24  //Softmax
25  float* softmax_output_data = int_output + int_output_offset;
26 
27  if (scale_.size() != 1) {
28  scale_.Resize(1);
29  }
30  if (sum_multiplier_.size() != dim_out) {
31  sum_multiplier_.Resize(dim_out);
32  math::Set<float, CPUContext>(dim_out, 1.f,
33  sum_multiplier_.mutable_data<float>(), &context_);
34  }
35  math::RowwiseMax<float, CPUContext>(1, dim_out, fc_output_data,
36  scale_.mutable_data<float>(), &context_);
37 
38  // Put the intermediate result X - max(X) into Y
39  context_.template Copy<float, CPUContext, CPUContext>(dim_out, fc_output_data,
40  softmax_output_data);
41  // Subtract the scale
42  math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, -1,
43  sum_multiplier_.data<float>(), scale_.data<float>(), 1, softmax_output_data,
44  &context_);
45 
46  // Exponentiation
47  math::Exp<float, CPUContext>(dim_out, softmax_output_data,
48  softmax_output_data, &context_);
49  math::Gemv<float, CPUContext>(CblasNoTrans, 1, dim_out, 1,
50  softmax_output_data, sum_multiplier_.data<float>(), 0,
51  scale_.mutable_data<float>(), &context_);
52 
53  // Do division
54  const float scale = *scale_.data<float>();
55  for (int j = 0; j < dim_out; ++j) {
56  softmax_output_data[j] /= scale;
57  }
58 
59  int_output_offset += dim_out;
60 
61  if (target < 0) {
62  return -1;
63  }
64  //Return cross entropy loss
65  return -log(std::max(softmax_output_data[target], kLOG_THRESHOLD()));
66 }
67 
68 // Implementation for the CPU context.
69 template <>
70 bool HSoftmaxOp<float, CPUContext>::RunOnDevice() {
71  auto& X = Input(0);
72  const auto& W = Input(1);
73  const auto& b = Input(2);
74  auto& label = Input(3);
75  auto* Y = Output(0);
76  auto* intermediate_output = Output(1);
77 
78  // Batch size
79  int M = X.ndim() > 1 ? X.dim32(0) : 1;
80  // Input feature dimension
81  int K = X.size() / M;
82  CAFFE_ENFORCE_GE(W.ndim(), 2); // N*K
83  CAFFE_ENFORCE_EQ(b.ndim(), 1); // N
84  CAFFE_ENFORCE_EQ(K, W.size() / (W.dim32(0)));
85  // Sum of output dimensions of all hierarchy nodes
86  int N = W.dim32(0);
87  CAFFE_ENFORCE_EQ(N, b.dim32(0));
88  Y->Resize(M);
89  auto* Ydata = Y->mutable_data<float>();
90  math::Set<float, CPUContext>(M, 0.f, Ydata, &context_);
91  const auto* labeldata = label.data<int>();
92 
93  auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_);
94  int int_output_size = getIntermediateOutputSize(labeldata, M, hierarchy);
95  intermediate_output->Resize(int_output_size);
96  float * int_output_data = intermediate_output->mutable_data<float>();
97  int int_output_offset = 0;
98 
99  if (bias_multiplier_.size() != M) {
100  bias_multiplier_.Resize(M);
101  math::Set<float, CPUContext>(M, static_cast<float>(1),
102  bias_multiplier_.mutable_data<float>(), &context_);
103  }
104 
105  for (int sample = 0; sample < M; ++sample) {
106  int word_id = labeldata[sample];
107  const PathProto& path = hierarchy[word_id];
108  for (const PathNodeProto& node : path.path_nodes()) {
109  //Offset of node's weight matrix in W
110  int w_offset = node.index();
111  //Number of output dimensions in node's weight matrix
112  int w_length = node.length();
113  int target = node.target();
114  //Adding log probabilities
115  Ydata[sample] += RunForwardSingle(X.data<float>() + sample*K,
116  W.data<float>() + w_offset*K, b.data<float>() + w_offset, target,
117  int_output_data, bias_multiplier_.data<float>()+sample, w_length, K,
118  int_output_offset);
119  }
120  }
121  return true;
122 }
123 
124 template <>
125 void HSoftmaxGradientOp<float, CPUContext>::RunBackwardSingle(const float* X,
126  const float* dY, const float* W, int target,
127  const float* int_output, float* dX, float* dW, float* db, float* dint_output,
128  int dim_in, int dim_out, int& int_output_offset) {
129 
130  //Cross entropy
131  // dX_entropy is the dX for the cross entropy layer
132  float* dX_entropy = dint_output + int_output_offset - dim_out;
133  // X_entropy is the X for the cross entropy layer and Y for the softmax layer
134  const float* X_entropy = int_output + int_output_offset - dim_out;
135 
136  math::Set<float, CPUContext>(dim_out, 0.f, dX_entropy, &context_);
137  dX_entropy[target] = - (*dY) / std::max(X_entropy[target], kLOG_THRESHOLD());
138 
139  int_output_offset -= dim_out;
140 
141  //Softmax
142  if (scale_.size() != 1) {
143  scale_.Resize(1);
144  }
145  float* scaledata = scale_.mutable_data<float>();
146 
147  if (sum_multiplier_.size() != dim_out) {
148  sum_multiplier_.Resize(dim_out);
149  math::Set<float, CPUContext>(dim_out, 1.f,
150  sum_multiplier_.mutable_data<float>(), &context_);
151  }
152 
153  float* dX_softmax = dint_output + int_output_offset - dim_out;
154  context_.Copy<float, CPUContext, CPUContext>(dim_out, dX_entropy, dX_softmax);
155 
156  math::Dot<float, CPUContext>(dim_out, X_entropy, dX_entropy, scaledata,
157  &context_);
158  math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, -1,
159  sum_multiplier_.data<float>(), scaledata , 1, dX_softmax, &context_);
160  math::Mul<float, CPUContext>(dim_out, dX_softmax, X_entropy, dX_softmax,
161  &context_);
162 
163  int_output_offset -= dim_out;
164 
165  //FC
166  if (bias_multiplier_.size() != 1) {
167  // If the helper bias multiplier has not been created, reshape and fill
168  // it with 1
169  bias_multiplier_.Resize(1);
170  math::Set<float, CPUContext>(1, static_cast<float>(1),
171  bias_multiplier_.template mutable_data<float>(), &context_);
172  }
173 
174  // Compute dW and add incrementally
175  // dW = dW + dX_softmax'*X
176  math::Gemm<float, CPUContext>(CblasTrans, CblasNoTrans, dim_out, dim_in, 1, 1,
177  dX_softmax, X, 1, dW, &context_);
178 
179  // Compute dB and add incrementally
180  // db = db + dX_softmax*bias_multiplier_
181  math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, 1, dX_softmax,
182  bias_multiplier_.template data<float>(), 1, db, &context_);
183 
184  // Compute dX and add incrementally
185  // dX = dX + W'dX_softmax
186  math::Gemv<float, CPUContext>(CblasTrans, dim_out, dim_in,
187  1, W, dX_softmax, 1, dX, &context_);
188 }
189 
190 // Implementation for the CPU context.
191 template <>
192 bool HSoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
193  auto& X = Input(0);
194  const auto& W = Input(1);
195  const auto& b = Input(2);
196  auto& label = Input(3);
197  auto& intermediate_output = Input(4);
198  auto& dY = Input(5);
199  auto* dX = Output(0);
200  auto* dW = Output(1);
201  auto* db = Output(2);
202  auto* dX_intermediate_output = Output(3);
203  dX->ResizeLike(X);
204  dW->ResizeLike(W);
205  db->ResizeLike(b);
206  dX_intermediate_output->ResizeLike(intermediate_output);
207 
208  float* dX_data = dX->mutable_data<float>();
209  float* dW_data = dW->mutable_data<float>();
210  float* db_data = db->mutable_data<float>();
211  float* dOutput_data = dX_intermediate_output->mutable_data<float>();
212 
213  math::Set<float, CPUContext>(X.size(), 0.f, dX_data, &context_);
214  math::Set<float, CPUContext>(W.size(), 0.f, dW_data, &context_);
215  math::Set<float, CPUContext>(b.size(), 0.f, db_data, &context_);
216  math::Set<float, CPUContext>(intermediate_output.size(), 0.f, dOutput_data,
217  &context_);
218 
219  // Batch size
220  int M = X.ndim() > 1 ? X.dim32(0) : 1;
221  // Input feature dimension
222  int K = X.size() / M;
223  const auto* labeldata = label.data<int>();
224 
225  auto hierarchy = getHierarchyForLabels(M, labeldata, hierarchy_all_map_);
226  int output_offset = getIntermediateOutputSize(labeldata, M, hierarchy);
227 
228  //Traverse backward to access intermediate_output generated by HSoftmaxOp
229  // sequentially in reverse order
230  for (int sample = M-1; sample >= 0; sample--) {
231  int word_id = labeldata[sample];
232  PathProto path = hierarchy[word_id];
233  for (auto node = path.path_nodes().rbegin();
234  node != path.path_nodes().rend(); node++) {
235  int w_offset = node->index();
236  int w_length = node->length();
237  int target = node->target();
238  RunBackwardSingle(X.data<float>() + sample*K, dY.data<float>() + sample,
239  W.data<float>() + w_offset*K, target, intermediate_output.data<float>(),
240  dX_data + sample*K, dW_data + w_offset*K, db_data + w_offset,
241  dOutput_data, K, w_length, output_offset);
242  }
243  }
244  return true;
245 }
246 
247 // Implementation for the CPU context.
248 template <>
249 bool HSoftmaxSearchOp<float, CPUContext>::pruning(
250  const float* X,
251  int sample,
252  int K,
253  const float* W,
254  const float* b,
255  const NodeProto& src_node,
256  NodeProto& dst_node,
257  float parent_score,
258  float beam) {
259  int w_length = src_node.children_size() + src_node.word_ids_size();
260  Tensor<CPUContext> intermediate_data;
261  intermediate_data.Resize(2 * w_length);
262  float* int_output_data = intermediate_data.template mutable_data<float>();
263  int int_output_offset = 0;
264  int w_offset = src_node.offset();
265 
266  RunForwardSingle(
267  X + K * sample,
268  W + w_offset * K,
269  b + w_offset,
270  -1,
271  int_output_data,
272  bias_multiplier_.template data<float>() + sample,
273  w_length,
274  K,
275  int_output_offset);
276 
277  float* softmax_output_data = int_output_data + w_length;
278  // real probabilities
279  for (int i = 0; i < w_length; i++) {
280  softmax_output_data[i] =
281  -log(std::max(softmax_output_data[i], kLOG_THRESHOLD())) + parent_score;
282  }
283  for (int i = 0; i < src_node.children_size(); i++) {
284  if (softmax_output_data[i] < parent_score + beam) {
285  dst_node.add_children();
286  int idx = dst_node.children_size() - 1;
287  CAFFE_ENFORCE(
288  src_node.children(i).has_offset(),
289  "HSM Search require the field offset in NodeProte");
290  dst_node.mutable_children(idx)->set_offset(src_node.children(i).offset());
291  CAFFE_ENFORCE(
292  src_node.children(i).has_name(),
293  "HSM Search require the field name in NodeProte");
294  dst_node.mutable_children(idx)->set_name(src_node.children(i).name());
295  dst_node.add_scores(softmax_output_data[i]);
296  pruning(
297  X,
298  sample,
299  K,
300  W,
301  b,
302  src_node.children(i),
303  *dst_node.mutable_children(idx),
304  softmax_output_data[i],
305  beam);
306  }
307  }
308 
309  for (int i = src_node.children_size(); i < w_length; i++) {
310  if (softmax_output_data[i] < parent_score + beam) {
311  dst_node.add_word_ids(src_node.word_ids(i - src_node.children_size()));
312  dst_node.add_scores(softmax_output_data[i]);
313  }
314  }
315 
316  return true;
317 }
318 
319 template <>
320 bool HSoftmaxSearchOp<float, CPUContext>::extractNodes(
321  const NodeProto& node,
322  std::vector<std::pair<string, float>>& info) {
323  int i = 0;
324 
325  for (const auto& n : node.children()) {
326  info.emplace_back(std::make_pair(n.name(), node.scores(i++)));
327  }
328  for (const int n : node.word_ids()) {
329  info.emplace_back(std::make_pair(caffe2::to_string(n), node.scores(i++)));
330  }
331 
332  for (const auto& n : node.children()) {
333  extractNodes(n, info);
334  }
335  return true;
336 }
337 
338 // Implementation for the CPU context.
339 template <>
340 bool HSoftmaxSearchOp<float, CPUContext>::RunOnDevice() {
341  auto& X = Input(0);
342  const auto& W = Input(1);
343  const auto& b = Input(2);
344  auto* Y_names = Output(0);
345  auto* Y_scores = Output(1);
346  // Batch size
347  int M = X.ndim() > 1 ? X.dim32(0) : 1;
348  // Input feature dimension
349  int K = X.size() / M;
350  CAFFE_ENFORCE(W.ndim() == 2, "Weight must be a matrix."); // N*K
351  CAFFE_ENFORCE(b.ndim() == 1, "Bias must be a vector."); // N
352  CAFFE_ENFORCE(K == W.size() / (W.dim32(0)), "feature dimension mismatch.");
353  // Sum of output dimensions of all hierarchy nodes
354  int N = W.dim32(0);
355  CAFFE_ENFORCE(N == b.dim32(0), "mismatch between Weight and Bias.");
356  Y_names->Resize(M, top_n_);
357  Y_scores->Resize(M, top_n_);
358 
359  if (bias_multiplier_.size() != M) {
360  bias_multiplier_.Resize(M);
361  math::Set<float, CPUContext>(
362  M,
363  static_cast<float>(1),
364  bias_multiplier_.mutable_data<float>(),
365  &context_);
366  }
367 
368  for (int sample = 0; sample < M; ++sample) {
369  CAFFE_ENFORCE(
370  tree_.root_node().has_offset(),
371  "HSM Search require the field offset in NodeProte");
372  CAFFE_ENFORCE(
373  tree_.root_node().has_name(),
374  "HSM Search require the field name in NodeProte");
375 
376  NodeProto dst_node;
377  dst_node.set_offset(tree_.root_node().offset());
378  dst_node.set_name(tree_.root_node().name());
379 
380  pruning(
381  X.data<float>(),
382  sample,
383  K,
384  W.data<float>(),
385  b.data<float>(),
386  tree_.root_node(),
387  dst_node,
388  0,
389  beam_);
390 
391  std::vector<std::pair<string, float>> info;
392  extractNodes(dst_node, info);
393  // saving the results for each sample.
394  std::partial_sort(
395  info.begin(),
396  info.begin() + (top_n_ < info.size() ? top_n_ : info.size() - 1),
397  info.end(),
398  [&](std::pair<string, float> a, std::pair<string, float> b) {
399  return a.second < b.second;
400  });
401  auto* y_name_data = Y_names->mutable_data<string>() + sample * top_n_;
402  auto* y_score_data = Y_scores->mutable_data<float>() + sample * top_n_;
403  for (int i = 0; i < top_n_; i++) {
404  if (i < info.size()) {
405  y_name_data[i] = info[i].first;
406  y_score_data[i] = info[i].second;
407  } else {
408  y_score_data[i] = 0;
409  }
410  }
411  }
412 
413  return true;
414 }
415 
416 template <typename T, class Context>
417 bool HuffmanTreeHierarchyOp<T, Context>::RunOnDevice() {
418  const auto& Y = Input(0);
419  auto treeOutput = Output(0);
420  CAFFE_ENFORCE_EQ(Y.ndim(), 1, "Input labels must be a vector.");
421  const auto y_data = Y.template data<T>();
422  treeOutput->Resize(1);
423  std::vector<int> labelCounts;
424  labelCounts.resize(num_classes_, 0);
425  for (int i = 0; i < Y.dim32(0); ++i) {
426  // Labels are in range [0, num_classes]
427  const int label_index = y_data[i];
428  CAFFE_ENFORCE_LT(
429  label_index,
430  num_classes_,
431  "Found an input label ",
432  label_index,
433  " not in range [",
434  0,
435  ",",
436  num_classes_,
437  "]");
438  labelCounts[label_index]++;
439  }
440 
441  std::priority_queue<Node, std::vector<Node>, NodeComparator> nodes;
442  std::vector<Node> huffmanTree;
443  std::vector<int> labelIndices;
444  labelIndices.resize(num_classes_);
445 
446  int current_node_index = 0;
447  for (int i = 0; i < num_classes_; ++i) {
448  Node node(i, labelCounts[i]);
449  nodes.push(node);
450  }
451 
452  // Extract node with minimum count and insert it in the tree array.
453  auto get_next_node = [&nodes, &huffmanTree, &labelIndices]() {
454  auto node = nodes.top();
455  int node_index = huffmanTree.size();
456  if (node.label != -1) {
457  labelIndices[node.label] = node_index;
458  }
459  nodes.pop();
460  huffmanTree.push_back(node);
461  return std::pair<int, Node>(node_index, node);
462  };
463 
464  // Merge two nodes and insert the results in the queue.
465  auto merge_nodes = [&nodes](
466  const std::pair<int, Node>& node_l, const std::pair<int, Node>& node_r) {
467  Node node(-1, node_l.second.count + node_r.second.count);
468  node.left_ch_index = node_l.first;
469  node.right_ch_index = node_r.first;
470  nodes.push(node);
471  };
472 
473  // Main loop for buttom up huffman tree construction.
474  while (!nodes.empty()) {
475  auto lNode = get_next_node();
476  if (!nodes.empty()) {
477  auto rNode = get_next_node();
478  merge_nodes(lNode, rNode);
479  }
480  }
481 
482  auto is_leaf_node = [&huffmanTree](const int node_index) {
483  return huffmanTree[node_index].left_ch_index == -1 &&
484  huffmanTree[node_index].right_ch_index == -1;
485  };
486 
487  auto get_node_label = [&huffmanTree](const int node_index) {
488  return huffmanTree[node_index].label;
489  };
490 
491  // Build huffman tree.
492  int current_offset = 0;
493  std::function<void(int, NodeProto*)> build_tree = [&](
494  const int node_index, NodeProto* node) {
495  if (is_leaf_node(node_index) || node_index == -1) {
496  return;
497  }
498  const int left_ch_index = huffmanTree[node_index].left_ch_index;
499  const int right_ch_index = huffmanTree[node_index].right_ch_index;
500  if (left_ch_index != -1) {
501  if (is_leaf_node(left_ch_index)) {
502  node->add_word_ids(get_node_label(left_ch_index));
503  } else {
504  auto* ch_node = node->add_children();
505  ch_node->set_offset(current_offset);
506  current_offset += 2;
507  build_tree(left_ch_index, ch_node);
508  }
509  }
510  if (right_ch_index != -1) {
511  if (is_leaf_node(right_ch_index)) {
512  node->add_word_ids(get_node_label(right_ch_index));
513  current_offset++;
514  } else {
515  auto* ch_node = node->add_children();
516  ch_node->set_offset(current_offset);
517  current_offset += 2;
518  build_tree(right_ch_index, ch_node);
519  }
520  }
521  };
522 
523  // The last element inserted in the tree is the root.
524  const int rootNodeIndex = huffmanTree.size() - 1;
525  NodeProto rootNode;
526  rootNode.set_offset(current_offset);
527  current_offset += 2;
528  build_tree(rootNodeIndex, &rootNode);
529  TreeProto treeProto;
530  *treeProto.mutable_root_node() = rootNode;
531 
532  treeProto.SerializeToString(treeOutput->template mutable_data<string>());
533  return true;
534 }
535 
536 namespace {
537 REGISTER_CPU_OPERATOR(HSoftmax, HSoftmaxOp<float, CPUContext>);
538 REGISTER_CPU_OPERATOR(HSoftmaxGradient,
539  HSoftmaxGradientOp<float, CPUContext>);
540 REGISTER_CPU_OPERATOR(HSoftmaxSearch, HSoftmaxSearchOp<float, CPUContext>);
541 REGISTER_CPU_OPERATOR(
542  HuffmanTreeHierarchy,
543  HuffmanTreeHierarchyOp<int64_t, CPUContext>);
544 
545 OPERATOR_SCHEMA(HSoftmax)
546  .NumInputs(4)
547  .NumOutputs(2)
548  .SetDoc(R"DOC(
549 Hierarchical softmax is an operator which approximates the softmax operator
550 while giving significant training speed gains and reasonably comparable
551 performance. In this operator, instead of calculating the probabilities of all
552 the classes, we calculate the probability of each step in the path from root to
553 the target word in the hierarchy.
554 
555 The operator takes a 2-D tensor (Tensor<float>) containing a batch of layers, a
556 set of parameters represented by the weight matrix and bias terms, and a 1-D
557 tensor (Tensor<int>) holding labels, or the indices of the target class. The
558 hierarchy has to be specified as an argument to the operator.
559 
560 The operator returns a 1-D tensor holding the computed log probability of the
561 target class and a 2-D tensor of intermediate outputs (from the weight matrix
562 and softmax from each step in the path from root to target class) which will be
563 used by the gradient operator to compute gradients for all samples in the batch.
564 )DOC")
565  .Arg("hierarchy", "Serialized HierarchyProto string containing list of "
566  "vocabulary words and their paths from root of hierarchy to the leaf")
567  .Input(0, "X", "Input data from previous layer")
568  .Input(1, "W", "2D blob containing 'stacked' fully connected weight "
569  "matrices. Each node in the hierarchy contributes one FC weight matrix if "
570  "it has children nodes. Dimension is N*D, D is input dimension of data (X), "
571  "N is sum of all output dimensions, or total number of nodes (excl root)")
572  .Input(2, "b", "1D blob with N parameters")
573  .Input(3, "labels", "int word_id of the target word")
574  .Output(0, "Y", "1-D of log probability outputs, one per sample")
575  .Output(1, "intermediate_output", "Extra blob to store the intermediate "
576  "FC and softmax outputs for each node in the hierarchical path of a word. "
577  "The outputs from samples are stored in consecutive blocks in the forward "
578  "pass and are used in reverse order in the backward gradientOp pass");
579 
580 OPERATOR_SCHEMA(HSoftmaxGradient).NumInputs(6).NumOutputs(4);
581 
582 class GetHSoftmaxGradient : public GradientMakerBase {
583  using GradientMakerBase::GradientMakerBase;
584  vector<OperatorDef> GetGradientDefs() override {
585  return SingleGradientDef(
586  "HSoftmaxGradient", "",
587  //X, W, b, label, intermediate output, dY
588  vector<string>{I(0), I(1), I(2), I(3), O(1), GO(0)},
589  //dX, dW, db, dintermediate_output
590  vector<string>{GI(0), GI(1), GI(2), GO(1)});
591  }
592 };
593 REGISTER_GRADIENT(HSoftmax, GetHSoftmaxGradient);
594 
595 OPERATOR_SCHEMA(HSoftmaxSearch)
596  .NumInputs(3)
597  .NumOutputs(2)
598  .SetDoc(R"DOC(
599 HSoftmaxSearch is an operator to generate the most possible paths given a
600 well-trained model and input vector. Greedy algorithm is used for pruning the
601 search tree.
602 )DOC")
603  .Arg(
604  "tree",
605  "Serialized TreeProto string containing a tree "
606  "including all intermidate nodes and leafs. All nodes must have names "
607  "for correct outputs")
608  .Arg(
609  "beam",
610  "beam used for pruning tree. The pruning algorithm is that "
611  "only children, whose score is smaller than parent's score puls beam, "
612  "will be propagated. ")
613  .Arg("topN", "Number of nodes in outputs")
614  .Input(0, "X", "Input data from previous layer")
615  .Input(1, "W", "The matrix trained from Softmax Ops")
616  .Input(2, "b", "The bias traiend from Softmax Ops")
617  .Output(
618  0,
619  "Y_names",
620  "The name of selected nodes and leafs. "
621  "For nodes, it will be the name defined in the tree. "
622  "For leafs, it will be the index of the word in the tree.")
623  .Output(1, "Y_scores", "The corresponding scores of Y_names");
624 SHOULD_NOT_DO_GRADIENT(HSoftmaxSearch);
625 
626 OPERATOR_SCHEMA(HuffmanTreeHierarchy)
627  .NumInputs(1)
628  .NumOutputs(1)
629  .SetDoc(R"DOC(
630 HuffmanTreeHierarchy is an operator to generate huffman tree hierarchy given
631 the input labels. It returns the tree as seralized HierarchyProto
632 )DOC")
633  .Arg("num_classes", "The number of classes used to build the hierarchy.")
634  .Input(0, "Labels", "The labels vector")
635  .Output(0, "Hierarch", "Huffman coding hierarchy of the labels");
636 
637 SHOULD_NOT_DO_GRADIENT(HuffmanTreeHierarchyOp);
638 } // namespace
639 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...