Caffe2 — C++ API
A deep-learning, cross-platform ML framework.
File: slice_op.h
1 
2 #pragma once
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 namespace {
11 
12 template <class SIndex, class Context>
13 bool SliceImpl(
14  Tensor<Context>* output,
15  const Tensor<Context>& data,
16  const Tensor<Context>& starts,
17  const Tensor<Context>& ends,
18  Context* context,
19  Tensor<Context>* gdata = nullptr,
20  const Tensor<Context>* go = nullptr) {
21  bool backward = output == nullptr;
22 
23  auto* starts_data = starts.template data<SIndex>();
24  auto* ends_data = ends.template data<SIndex>();
25 
26  CAFFE_ENFORCE_EQ(starts.ndim(), 1);
27  CAFFE_ENFORCE_EQ(ends.ndim(), 1);
28  CAFFE_ENFORCE_GE(data.ndim(), starts.size());
29  CAFFE_ENFORCE_EQ(starts.size(), ends.size());
30 
31  std::vector<SIndex> starts_idx(data.ndim());
32  std::vector<SIndex> ends_idx(data.ndim());
33  std::vector<SIndex> dst_sizes(data.ndim());
34 
35  for (int i = 0; i < data.ndim(); ++i) {
36  if (i >= starts.size()) {
37  starts_idx[i] = 0;
38  ends_idx[i] = data.dims()[i];
39  continue;
40  }
41  if (data.dims()[i] > 0) {
42  auto start = starts_data[i];
43  auto end = ends_data[i];
44  if (start < 0) {
45  start = data.dims()[i] + 1 + start;
46  }
47  if (end < 0) {
48  end = data.dims()[i] + 1 + end;
49  }
50  if (start > data.dims()[i]) {
51  start = data.dims()[i];
52  }
53  if (end > data.dims()[i]) {
54  end = data.dims()[i];
55  }
56  CAFFE_ENFORCE_GE(start, 0);
57  CAFFE_ENFORCE_GE(end, 0);
58  CAFFE_ENFORCE_GE(end, start);
59  starts_idx[i] = start;
60  ends_idx[i] = end;
61  dst_sizes[i] = end - start;
62  } else {
63  starts_idx[i] = 0;
64  ends_idx[i] = 0;
65  dst_sizes[i] = 0;
66  }
67  }
68 
69  if (data.size() <= 0) {
70  // When the input is empty, we do not need to do copy.
71  if (!backward) {
72  output->Resize(dst_sizes);
73  output->raw_mutable_data(data.meta());
74  }
75  return true;
76  }
77  // for now only supports slicing in 1 dimension
78  int dim = -1;
79  for (int i = 0; i < data.ndim(); ++i) {
80  if (starts_idx[i] > 0 || ends_idx[i] < data.dims()[i]) {
81  CAFFE_ENFORCE_EQ(
82  dim, -1, "Currently only possible to slice in 1 dimension.");
83  dim = i;
84  }
85  }
86  if (dim == -1) {
87  if (!backward) {
88  output->CopyFrom(data, context);
89  } else {
90  gdata->CopyFrom(*go, context);
91  }
92  return true;
93  }
94  size_t unit = std::accumulate(
95  data.dims().begin() + dim + 1,
96  data.dims().end(),
97  1,
98  std::multiplies<SIndex>());
99  size_t num_blocks = std::accumulate(
100  data.dims().begin(),
101  data.dims().begin() + dim,
102  1,
103  std::multiplies<SIndex>());
104  if (!backward) {
105  output->Resize(dst_sizes);
106  } else {
107  gdata->ResizeLike(data);
108  }
109 
110  size_t itemsize = data.meta().itemsize();
111 
112  if (!backward) {
113  char* src_bytes = (char*)data.raw_data();
114  char* dst_bytes = (char*)output->raw_mutable_data(data.meta());
115 
116  size_t src_nbytes = data.nbytes();
117  size_t dst_nbytes = output->nbytes();
118 
119  size_t src_block_size = unit * data.dims()[dim];
120  size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
121  size_t src_offset = unit * starts_idx[dim];
122 
123  if (num_blocks == 0 || dst_block_size == 0) {
124  return true;
125  }
126 
127  size_t src_block_size_bytes = itemsize * src_block_size;
128  size_t dst_block_size_bytes = itemsize * dst_block_size;
129 
130  char* src_offset_bytes = src_bytes + itemsize * src_offset;
131  char* dst_offset_bytes = dst_bytes;
132  for (int i = 0; i < num_blocks; ++i) {
133  char* local_src_offset_bytes =
134  src_offset_bytes + i * src_block_size_bytes;
135  char* local_dst_offset_bytes =
136  dst_offset_bytes + i * dst_block_size_bytes;
137  DCHECK_LE(
138  static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
139  static_cast<void*>(src_bytes + src_nbytes));
140  DCHECK_LE(
141  static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
142  static_cast<void*>(dst_bytes + dst_nbytes));
143  context->template CopyItems<Context, Context>(
144  data.meta(),
145  dst_block_size,
146  (void*)local_src_offset_bytes,
147  (void*)local_dst_offset_bytes);
148  }
149  } else {
150  char* src_bytes = (char*)go->raw_data();
151  char* dst_bytes = (char*)gdata->raw_mutable_data(go->meta());
152 
153  size_t src_nbytes = go->nbytes();
154  size_t dst_nbytes = gdata->nbytes();
155 
156  size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
157  size_t dst_block_size = unit * data.dims()[dim];
158  size_t dst_offset = unit * starts_idx[dim];
159 
160  if (num_blocks == 0 || dst_block_size == 0) {
161  return true;
162  }
163 
164  size_t src_block_size_bytes = itemsize * src_block_size;
165  size_t dst_block_size_bytes = itemsize * dst_block_size;
166 
167  char* src_offset_bytes = src_bytes;
168  char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
169  // Zero out gradient blob before copy since we copy in fewer items than
170  // there is space for
171  math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);
172 
173  // If output tensor is empty, just return zeroed gradient tensor
174  if (!src_bytes) {
175  return true;
176  }
177 
178  for (int i = 0; i < num_blocks; ++i) {
179  char* local_src_offset_bytes =
180  src_offset_bytes + i * src_block_size_bytes;
181  char* local_dst_offset_bytes =
182  dst_offset_bytes + i * dst_block_size_bytes;
183  DCHECK_LE(
184  local_src_offset_bytes + src_block_size_bytes,
185  src_bytes + src_nbytes);
186  DCHECK_LE(
187  local_dst_offset_bytes + src_block_size_bytes,
188  dst_bytes + dst_nbytes);
189  context->template CopyItems<Context, Context>(
190  go->meta(),
191  src_block_size,
192  (void*)local_src_offset_bytes,
193  (void*)local_dst_offset_bytes);
194  }
195  }
196  return true;
197 }
198 
199 } // namespace
200 
201 template <class SIndex, class Context>
202 class SliceOp : public Operator<Context> {
203  public:
204  USE_OPERATOR_CONTEXT_FUNCTIONS;
205  SliceOp(const OperatorDef& operator_def, Workspace* ws)
206  : Operator<Context>(operator_def, ws),
207  starts_(OperatorBase::GetRepeatedArgument<SIndex>("starts")),
208  ends_(OperatorBase::GetRepeatedArgument<SIndex>("ends")),
209  statically_inited_(false) {}
210 
211  bool RunOnDevice() override {
212  auto* output = Output(0);
213  auto& data = Input(0);
214 
215  if (InputSize() > 1) {
216  starts_host_.template CopyFrom<Context>(Input(1));
217  ends_host_.template CopyFrom<Context>(Input(2));
218  } else {
219  if (!statically_inited_) {
220  CAFFE_ENFORCE(HasArgument("starts"));
221  CAFFE_ENFORCE(HasArgument("ends"));
222  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
223 
224  starts_host_.Resize(starts_.size());
225  ends_host_.Resize(ends_.size());
226 
227  memcpy(
228  starts_host_.template mutable_data<SIndex>(),
229  starts_.data(),
230  sizeof(SIndex) * starts_.size());
231  memcpy(
232  ends_host_.template mutable_data<SIndex>(),
233  ends_.data(),
234  sizeof(SIndex) * ends_.size());
235  statically_inited_ = true;
236  }
237  }
238 
239  return SliceImpl<SIndex, Context>(
240  output, data, starts_host_, ends_host_, &context_);
241  }
242 
243  DISABLE_COPY_AND_ASSIGN(SliceOp);
244 
245  private:
246  std::vector<SIndex> starts_;
247  std::vector<SIndex> ends_;
248  bool statically_inited_;
249  TensorCPU starts_host_;
250  TensorCPU ends_host_;
251 };
252 
253 template <class SIndex, class Context>
254 class SliceGradientOp : public Operator<Context> {
255  public:
256  USE_OPERATOR_CONTEXT_FUNCTIONS;
257  SliceGradientOp(const OperatorDef& operator_def, Workspace* ws)
258  : Operator<Context>(operator_def, ws),
259  starts_(OperatorBase::GetRepeatedArgument<SIndex>("starts")),
260  ends_(OperatorBase::GetRepeatedArgument<SIndex>("ends")),
261  statically_inited_(false) {}
262 
263  bool RunOnDevice() override {
264  auto* gdata = Output(0);
265  auto& data = Input(0);
266 
267  if (InputSize() == 4) {
268  starts_host_.template CopyFrom<Context>(Input(1));
269  ends_host_.template CopyFrom<Context>(Input(2));
270 
271  auto& go = Input(3);
272 
273  return SliceImpl<SIndex, Context>(
274  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
275  } else {
276  if (!statically_inited_) {
277  CAFFE_ENFORCE(HasArgument("starts"));
278  CAFFE_ENFORCE(HasArgument("ends"));
279  CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
280 
281  starts_host_.Resize(starts_.size());
282  ends_host_.Resize(ends_.size());
283 
284  memcpy(
285  starts_host_.template mutable_data<SIndex>(),
286  starts_.data(),
287  sizeof(SIndex) * starts_.size());
288  memcpy(
289  ends_host_.template mutable_data<SIndex>(),
290  ends_.data(),
291  sizeof(SIndex) * ends_.size());
292 
293  statically_inited_ = true;
294  }
295  auto& go = Input(1);
296 
297  return SliceImpl<SIndex, Context>(
298  nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
299  }
300  }
301 
302  DISABLE_COPY_AND_ASSIGN(SliceGradientOp);
303 
304  private:
305  std::vector<SIndex> starts_;
306  std::vector<SIndex> ends_;
307  bool statically_inited_;
308  TensorCPU starts_host_;
309  TensorCPU ends_host_;
310 };
311 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
void Resize(Ts... dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37