Caffe2 - C++ API
A deep learning, cross platform ML framework
slice_op.cc
1 #include "caffe2/operators/slice_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 REGISTER_CPU_OPERATOR(Slice, SliceOp<int, CPUContext>);
7 REGISTER_CPU_OPERATOR(SliceGradient, SliceGradientOp<int, CPUContext>);
8 
9 OPERATOR_SCHEMA(Slice)
10  .NumInputs(1, 3)
11  .NumOutputs(1)
12  .SetDoc(R"DOC(
13 Produces a slice of the input tensor. Currently, only slicing in a single
14 dimension is supported.
15 Slices are passed as 2 1D vectors or as two keyword argument lists with starting
16 and end indices for each dimension of the input `data` tensor. If a negative
17 value is passed for any of the start or end indices, it represents the number of
18 elements before the end of that dimension. End indices are non-inclusive unless
19 negative (end index -1 means up to and including the last element).
20 
21 
22 Example:
23 
24  data = [
25  [1, 2, 3, 4],
26  [5, 6, 7, 8],
27  ]
28  starts = [0, 1]
29  ends = [-1, 3]
30 
31  result = [
32  [2, 3],
33  [6, 7],
34  ]
35 )DOC")
36  .Input(0, "data", "Tensor of data to extract slices from.")
37  .Input(1, "starts", "1D tensor: start-indices for each dimension of data.")
38  .Input(2, "ends", "1D tensor: end-indices for each dimension of data.")
39  .Arg("starts", "List of starting indices")
40  .Arg("ends", "List of ending indices")
41  .TensorInferenceFunction([](const OperatorDef& def,
42  const vector<TensorShape>& in) {
43  if (in.size() > 1) {
44  // Cannot compute shape inference when the splits are defined
45  // in data.
46  return vector<TensorShape>();
47  }
48  auto const& data = in[0];
49 
50  ArgumentHelper helper(def);
51  auto starts = helper.GetRepeatedArgument<int>("starts", vector<int>());
52  auto ends = helper.GetRepeatedArgument<int>("ends", vector<int>());
53  vector<int> dst_sizes(data.dims_size());
54 
55  for (int i = 0; i < data.dims_size(); ++i) {
56  if (i >= starts.size()) {
57  continue;
58  }
59  if (data.dims_size() > 0) {
60  auto start = starts[i];
61  auto end = ends[i];
62  if (start < 0) {
63  start = data.dims(i) + 1 + start;
64  }
65  if (end < 0) {
66  end = data.dims(i) + 1 + end;
67  }
68  dst_sizes[i] = end - start;
69  } else {
70  dst_sizes[i] = 0;
71  }
72  }
73  return vector<TensorShape>{
74  CreateTensorShape(dst_sizes, data.data_type())};
75  })
76  .Output(0, "output", "Sliced data tensor.")
77  .InheritOnnxSchema("Slice");
78 
79 OPERATOR_SCHEMA(SliceGradient);
80 
81 namespace {
82 struct GetSliceGradient : public GradientMakerBase {
83  using GradientMakerBase::GradientMakerBase;
84  vector<OperatorDef> GetGradientDefs() override {
85  if (def_.input_size() > 1) {
86  return vector<OperatorDef>{CreateOperatorDef(
87  "SliceGradient",
88  "",
89  std::vector<string>{I(0), I(1), I(2), GO(0)},
90  std::vector<string>{GI(0)})};
91  } else {
92  return vector<OperatorDef>{CreateOperatorDef(
93  "SliceGradient",
94  "",
95  std::vector<string>{I(0), GO(0)},
96  std::vector<string>{GI(0)})};
97  }
98  }
99 };
100 }
101 REGISTER_GRADIENT(Slice, GetSliceGradient);
102 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...