Caffe2 - C++ API
A deep learning, cross platform ML framework
gather_ranges_to_dense_op.h
1 #ifndef CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
2 #define CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
3 
4 #include <math.h>
5 
6 #include "caffe2/core/common_omp.h"
7 #include "caffe2/core/context.h"
8 #include "caffe2/core/logging.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/types.h"
11 #include "caffe2/utils/math.h"
12 
13 #include <map>
14 #include <utility>
15 
16 namespace caffe2 {
17 template <class Context>
18 class GatherRangesToDenseOp final : public Operator<Context> {
19  public:
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21  GatherRangesToDenseOp(const OperatorDef& operator_def, Workspace* ws)
22  : Operator<Context>(operator_def, ws),
23  lengths_(OperatorBase::GetRepeatedArgument<int>("lengths")) {
24  CAFFE_ENFORCE_GT(lengths_.size(), 0, "There has to be at least one length");
25  for (auto length : lengths_) {
26  CAFFE_ENFORCE_GT(length, 0, "Each length should be positive");
27  }
28  }
29 
30  bool RunOnDevice() override {
32  this, OperatorBase::Input<TensorCPU>(RANGES));
33  }
34 
35  template <typename Index>
36  bool DoRunWithType() {
37  auto& data = Input(DATA);
38  auto& ranges = Input(RANGES);
39  CAFFE_ENFORCE_EQ(data.ndim(), 1, "Data has to be 1-D");
40  CAFFE_ENFORCE_EQ(ranges.ndim(), 3, "Ranges has to be 3-D");
41  if (InputSize() == 3) {
42  auto& key = Input(KEY);
43  CAFFE_ENFORCE_EQ(key.ndim(), 1, "Key has to be 1-D");
44  CAFFE_ENFORCE(
45  key.meta().template Match<int64_t>(), "Key has to be type int64_t");
46  }
47  CAFFE_ENFORCE_EQ(
48  ranges.dim(1),
49  lengths_.size(),
50  "Nummber of ranges should match number of lengths");
51  CAFFE_ENFORCE_EQ(
52  ranges.dim(1),
53  OutputSize(),
54  "Nummber of ranges should match number of outputs");
55  CAFFE_ENFORCE_EQ(
56  ranges.dim(2), 2, "Ranges last dimension should be of size 2");
57 
58  auto* rawData = static_cast<const char*>(data.raw_data());
59  auto* rangesData = ranges.template data<Index>();
60  int rangesDataOffset = 0;
61  auto itemsize = data.meta().itemsize();
62 
63  auto batchSize = ranges.dim(0);
64  vector<TIndex> outputDims{batchSize, 0};
65  vector<char*> outputRawData;
66  for (int i = 0; i < OutputSize(); ++i) {
67  auto* output = Output(i);
68  outputDims[1] = lengths_[i];
69  output->Resize(outputDims);
70  char* ptr = static_cast<char*>(output->raw_mutable_data(data.meta()));
71  memset(ptr, 0, output->nbytes());
72  outputRawData.push_back(ptr);
73  }
74 
75  for (int i = 0; i < batchSize; ++i) {
76  for (int j = 0; j < OutputSize(); ++j) {
77  auto rangeStart = rangesData[rangesDataOffset++];
78  auto rangeLength = rangesData[rangesDataOffset++];
79  if (rangeLength == 0) {
80  // empty range, will be filled with zeros
81  continue;
82  }
83  CAFFE_ENFORCE_EQ(
84  rangeLength,
85  lengths_[j],
86  "Range lengths missmatch for output #",
87  j);
88 
89  if (InputSize() == 2) {
90  context_.template CopyItems<Context, Context>(
91  data.meta(),
92  rangeLength,
93  rawData + rangeStart * itemsize,
94  outputRawData[j] + i * itemsize * lengths_[j]);
95  } else {
96  auto& key = Input(KEY);
97  auto* key_data = key.template data<int64_t>();
98  vector<std::pair<int64_t, const char*>> buffer;
99  for (int b_i = 0; b_i < rangeLength; ++b_i) {
100  int64_t one_key_item = key_data[rangeStart + b_i];
101  auto* one_data_item = rawData + (rangeStart + b_i) * itemsize;
102  buffer.emplace_back(one_key_item, one_data_item);
103  }
104  std::sort(
105  buffer.begin(),
106  buffer.end(),
107  [](const std::pair<int64_t, const char*>& left,
108  const std::pair<int64_t, const char*>& right) {
109  return left.first < right.first;
110  });
111  for (int b_i = 0; b_i < rangeLength; ++b_i) {
112  // Since this CPU only, directly copy to the destination.
113  std::memcpy(
114  outputRawData[j] + (i * lengths_[j] + b_i) * itemsize,
115  buffer[b_i].second,
116  itemsize);
117  }
118  }
119  }
120  }
121  CAFFE_ENFORCE_EQ(rangesDataOffset, ranges.size());
122 
123  return true;
124  }
125 
126  INPUT_TAGS(DATA, RANGES, KEY);
127 
128  private:
129  vector<int> lengths_;
130 };
131 
132 } // namespace caffe2
133 
134 #endif // CAFFE2_OPERATORS_GATHER_RANGES_TO_DENSE_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...