Caffe2 - C++ API
A deep learning, cross-platform ML framework
ngram_ops.h
1 #pragma once
2 
3 #include <vector>
4 
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 template <typename F, typename T, class Context>
11 class NGramFromCategoricalOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14 
15  NGramFromCategoricalOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  col_ids_(OperatorBase::GetRepeatedArgument<int>("col_ids")),
18  categorical_limits_(
19  OperatorBase::GetRepeatedArgument<int>("categorical_limits")),
20  vals_(OperatorBase::GetRepeatedArgument<int>("vals")) {
21  col_num_ = col_ids_.size();
22  max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
23  CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
24  int expected_vals_size = 0;
25  for (auto& l : categorical_limits_) {
26  CAFFE_ENFORCE_GT(l, 0);
27  expected_vals_size += l;
28  }
29  CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
30  // compute ngram maps with small end
31  for (auto& j : col_ids_) {
32  CAFFE_ENFORCE_GE(j, 0);
33  ngram_maps_.push_back(std::map<int, int>());
34  }
35  int base = 1;
36  int idx = 0;
37  for (int k = 0; k < col_num_; k++) {
38  int l = categorical_limits_[k];
39  for (int m = 0; m < l; m++) {
40  int v = vals_[idx++];
41  ngram_maps_[k][v] = m * base;
42  }
43  base *= l;
44  }
45  }
46 
47  bool RunOnDevice() override {
48  auto& floats = Input(0);
49  auto N = floats.dim(0);
50  auto D = floats.size_from_dim(1);
51  const F* floats_data = floats.template data<F>();
52  auto* output = Output(0);
53  output->Resize(N);
54  auto* output_data = output->template mutable_data<T>();
55  math::Set<T, Context>(output->size(), 0, output_data, &context_);
56 
57  CAFFE_ENFORCE_GT(D, max_col_id_);
58  for (int i = 0; i < N; i++) {
59  for (int k = 0; k < col_num_; k++) {
60  int j = col_ids_[k];
61  int v = round(floats_data[i * D + j]);
62  // for out-of-vocabulary values, we always treat them the same as the
63  // first value specified in vals; if we want to mimic the behavior as
64  // sigrid NGram transform, just push front a random/impossible value at
65  // each segments of vals
66  output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
67  ? 0
68  : ngram_maps_[k][v];
69  }
70  }
71  return true;
72  }
73 
74  private:
75  std::vector<int> col_ids_;
76  std::vector<int> categorical_limits_;
77  std::vector<int> vals_;
78  std::vector<std::map<int, int>> ngram_maps_;
79  int col_num_;
80  int max_col_id_;
81 };
82 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...