Caffe2 - C++ API
A deep learning, cross platform ML framework
merge_id_lists_op.h
1 #ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
2 #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
3 
4 #include <set>
5 #include <vector>
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class MergeIdListsOp : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);
16 
17  template <typename T>
18  bool DoRunWithType() {
19  auto& first_lengths = Input(0);
20  CAFFE_ENFORCE_EQ(first_lengths.ndim(), 1, "LENGTHS should be 1-D");
21  const auto batch_size = first_lengths.size();
22 
23  auto* out_lengths = Output(0);
24  out_lengths->ResizeLike(first_lengths);
25 
26  auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
27 
32  auto M = 0;
33  for (size_t i = 0; i < InputSize(); i += 2) {
34  auto& lengths = Input(i);
35  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS should be 1-D");
36  CAFFE_ENFORCE_EQ(lengths.size(), batch_size, "LENGTHS should be equal");
37  auto& values = Input(i + 1);
38  CAFFE_ENFORCE_EQ(values.ndim(), 1, "VALUES should be 1-D");
39  M += values.size();
40  }
41 
42  auto* out_values = Output(1);
43  out_values->Resize(M);
44 
45  T* out_values_data = out_values->template mutable_data<T>();
46  auto pos = 0;
47 
48  // TODO(badri): Use unordered_set if performance is an issue
49  std::set<T> deduped;
50  std::vector<int> offsets(InputSize(), 0);
51  for (auto sample = 0; sample < batch_size; sample++) {
52  for (size_t i = 0; i < InputSize(); i += 2) {
53  auto& lengths = Input(i);
54  const auto* lengths_data = lengths.template data<int32_t>();
55 
56  auto& values = Input(i + 1);
57  const T* values_data = values.template data<T>();
58  const auto length = lengths_data[sample];
59 
60  for (auto j = offsets[i]; j < offsets[i] + length; j++) {
61  deduped.insert(values_data[j]);
62  }
63  offsets[i] += length;
64  }
65  for (auto val : deduped) {
66  out_values_data[pos++] = val;
67  }
68  out_lengths_data[sample] = deduped.size();
69  deduped.clear();
70  }
71  out_values->Resize(pos);
72  return true;
73  }
74 
75  bool RunOnDevice() override {
76  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
77  }
78 };
79 
80 } // namespace caffe2
81 
82 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...