// NOTE(review): this listing appears to come from a rendered source view. The
// leading numbers (5, 6, 7, 10, 13, ...) look like original line numbers, and
// several original lines (including the class declaration itself) are not
// present here.
5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 10 template <
// F: element type of the input tensor (RunOnDevice reads it via data<F>()).
typename F,
// T: element type of the output tensor (RunOnDevice writes it via
// mutable_data<T>() and zeroes it with math::Set<T, Context>).
typename T,
// Context: device context type, passed to math::Set together with context_.
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (fragment: its signature and some body lines are not visible).
// Reads three repeated int operator arguments -- "col_ids",
// "categorical_limits" and "vals" -- validates them, and builds one
// std::map<int, int> per column in ngram_maps_.
17 col_ids_(OperatorBase::GetRepeatedArgument<int>(
"col_ids")),
// The member initialised here is named on a line that is not visible;
// presumably categorical_limits_ -- TODO confirm.
19 OperatorBase::GetRepeatedArgument<int>(
"categorical_limits")),
20 vals_(OperatorBase::GetRepeatedArgument<int>(
"vals")) {
21 col_num_ = col_ids_.size();
// NOTE(review): std::max_element returns end() for an empty range, and
// dereferencing it is undefined behavior. No non-empty check on col_ids_ is
// visible before this line; an explicit CAFFE_ENFORCE on col_ids_ being
// non-empty would make an empty "col_ids" argument fail cleanly.
22 max_col_id_ = *std::max_element(col_ids_.begin(), col_ids_.end());
// Every listed column needs exactly one categorical limit.
23 CAFFE_ENFORCE_EQ(col_num_, categorical_limits_.size());
// vals_ must have exactly as many entries as the sum of all (strictly
// positive) categorical limits.
24 int expected_vals_size = 0;
25 for (
auto& l : categorical_limits_) {
26 CAFFE_ENFORCE_GT(l, 0);
27 expected_vals_size += l;
29 CAFFE_ENFORCE_EQ(expected_vals_size, vals_.size());
// Column ids must be non-negative; start with one empty map per column.
31 for (
auto& j : col_ids_) {
32 CAFFE_ENFORCE_GE(j, 0);
33 ngram_maps_.push_back(std::map<int, int>());
// Populate each column's map: the m-th category (0 <= m < limit) is stored
// with contribution m * base. `v` and `base` are defined on lines not visible
// here (presumably v is taken from vals_ and base depends on the preceding
// columns' limits -- TODO confirm).
37 for (
int k = 0; k < col_num_; k++) {
38 int l = categorical_limits_[k];
39 for (
int m = 0; m < l; m++) {
41 ngram_maps_[k][v] = m * base;
// RunOnDevice (fragment: several lines, including the end of the function, are
// not visible). Treats Input(0) as N rows of D values each (N = dim(0),
// D = product of the remaining dims). For every row i it accumulates into
// output_data[i] one contribution per configured column, looked up in that
// column's ngram_maps_ entry.
47 bool RunOnDevice()
override {
48 auto& floats = Input(0);
49 auto N = floats.dim(0);
50 auto D = floats.size_from_dim(1);
51 const F* floats_data = floats.template data<F>();
52 auto* output = Output(0);
// The output shape is set on a line that is not visible; presumably one value
// per row -- TODO confirm.
54 auto* output_data = output->template mutable_data<T>();
// Zero the output because the loop below accumulates with +=.
55 math::Set<T, Context>(output->size(), 0, output_data, &context_);
// Every configured column id must index inside a row of width D.
57 CAFFE_ENFORCE_GT(D, max_col_id_);
58 for (
int i = 0; i < N; i++) {
59 for (
int k = 0; k < col_num_; k++) {
// `j` is defined on a line that is not visible; presumably col_ids_[k] --
// TODO confirm. The input value is rounded to the nearest integer to form
// the category key.
// NOTE(review): converting the rounded float to int is undefined behavior if
// the value is outside int's range (or NaN).
61 int v = round(floats_data[i * D + j]);
// The rest of this expression is not visible; presumably it adds a default
// value when v is not in the map and the mapped value otherwise -- TODO
// confirm. If the hidden branch indexes the map again, reusing the find()
// iterator would avoid a second lookup.
66 output_data[i] += ngram_maps_[k].find(v) == ngram_maps_[k].end()
// Member state. col_num_ and max_col_id_ are also used but are declared on
// lines not visible here.
// Input column indices taken from the "col_ids" argument (checked >= 0).
75 std::vector<int> col_ids_;
// Per-column number of categories (checked > 0; one entry per column).
76 std::vector<int> categorical_limits_;
// Category values from the "vals" argument; its length must equal the sum of
// categorical_limits_.
77 std::vector<int> vals_;
// One map per column, from category key to that category's contribution to
// the output (filled in the constructor).
78 std::vector<std::map<int, int>> ngram_maps_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...