1 #ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ 2 #define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 auto& first_lengths = Input(0);
20 CAFFE_ENFORCE_EQ(first_lengths.ndim(), 1,
"LENGTHS should be 1-D");
21 const auto batch_size = first_lengths.size();
23 auto* out_lengths = Output(0);
24 out_lengths->ResizeLike(first_lengths);
26 auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();
33 for (
size_t i = 0; i < InputSize(); i += 2) {
34 auto& lengths = Input(i);
35 CAFFE_ENFORCE_EQ(lengths.ndim(), 1,
"LENGTHS should be 1-D");
36 CAFFE_ENFORCE_EQ(lengths.size(), batch_size,
"LENGTHS should be equal");
37 auto& values = Input(i + 1);
38 CAFFE_ENFORCE_EQ(values.ndim(), 1,
"VALUES should be 1-D");
42 auto* out_values = Output(1);
43 out_values->Resize(M);
45 T* out_values_data = out_values->template mutable_data<T>();
50 std::vector<int> offsets(InputSize(), 0);
51 for (
auto sample = 0; sample < batch_size; sample++) {
52 for (
size_t i = 0; i < InputSize(); i += 2) {
53 auto& lengths = Input(i);
54 const auto* lengths_data = lengths.template data<int32_t>();
56 auto& values = Input(i + 1);
57 const T* values_data = values.template data<T>();
58 const auto length = lengths_data[sample];
60 for (
auto j = offsets[i]; j < offsets[i] + length; j++) {
61 deduped.insert(values_data[j]);
65 for (
auto val : deduped) {
66 out_values_data[pos++] = val;
68 out_lengths_data[sample] = deduped.size();
71 out_values->Resize(pos);
75 bool RunOnDevice()
override {
82 #endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_ A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...