// Registers the CPU implementations of the forward Slice operator and its
// gradient operator, both instantiated with `int` as the first template
// argument (presumably the index type used for starts/ends — TODO confirm
// against caffe2/operators/slice_op.h).
// NOTE(review): this excerpt has upstream line numbers fused into the code
// ("1 #include", "2 #include", "6 REGISTER_...") and both includes share one
// line, so the preprocessor will not treat them as directives. Upstream lines
// 3-5 are also absent. Restore this text from the upstream file before building.
1 #include "caffe2/operators/slice_op.h" 2 #include "caffe2/utils/math.h" 6 REGISTER_CPU_OPERATOR(Slice, SliceOp<int, CPUContext>);
7 REGISTER_CPU_OPERATOR(SliceGradient, SliceGradientOp<int, CPUContext>);
13 Produces a slice of the input tensor. Currently, only slicing in a single 14 dimension is supported. 15 Slices are passed as 2 1D vectors or as two keyword argument lists with starting 16 and end indices for each dimension of the input `data` tensor. If a negative 17 value is passed for any of the start or end indices, it represents the number of 18 elements before the end of that dimension. End indices are non-inclusive unless 19 negative (end index -1 means up to and including the last element). 36 .Input(0, "data",
"Tensor of data to extract slices from.")
// Schema for the Slice operator. The OPERATOR_SCHEMA(Slice) header and the
// start of the doc string are on upstream lines not included in this excerpt.
// Slice bounds are declared two ways: as optional inputs 1-2 ("starts" /
// "ends" 1D tensors), or as the "starts" / "ends" list arguments.
37 .Input(1,
"starts",
"1D tensor: start-indices for each dimension of data.")
38 .Input(2,
"ends",
"1D tensor: end-indices for each dimension of data.")
39 .Arg(
"starts",
"List of starting indices")
40 .Arg(
"ends",
"List of ending indices")
// Static shape inference. The output keeps the input's data type. Its size
// along dimension i is computed from the "starts"/"ends" *arguments* as
// end - start.
41 .TensorInferenceFunction([](
const OperatorDef& def,
42 const vector<TensorShape>& in) {
// Upstream lines 43-45 (the guard for this early return) are missing from
// this excerpt. Returning an empty vector yields no inferred shapes.
// Presumably this happens when bounds arrive as input tensors, because their
// values are not known statically — confirm against upstream.
46 return vector<TensorShape>();
48 auto const& data = in[0];
50 ArgumentHelper helper(def);
// Bounds come from the op's repeated int arguments; both default to empty.
51 auto starts = helper.GetRepeatedArgument<
int>(
"starts", vector<int>());
52 auto ends = helper.GetRepeatedArgument<
int>(
"ends", vector<int>());
// One inferred size per input dimension, value-initialized to 0.
53 vector<int> dst_sizes(data.dims_size());
55 for (
int i = 0; i < data.dims_size(); ++i) {
// NOTE(review): `i` is int while starts.size() is size_t, so this is a
// signed/unsigned comparison (-Wsign-compare).
56 if (i >= starts.size()) {
// NOTE(review): the body of this branch (upstream lines 57-58) is not
// shown. If it only skips the dimension, dst_sizes[i] stays 0 instead of
// data.dims(i), so a dimension without an explicit start is inferred as
// empty. Confirm, and consider assigning data.dims(i) there.
// NOTE(review): inside this loop i < data.dims_size() always holds, so the
// condition below is always true and its else-branch (upstream lines 69-72,
// not shown) looks unreachable.
59 if (data.dims_size() > 0) {
60 auto start = starts[i];
// The declaration of `end` (upstream line 61) is not shown. If it reads
// ends[i], there is no visible check that ends.size() > i.
// The shifts below presumably apply only to negative values; their guarding
// conditions (upstream lines 62 and 65) are not shown. The "+ 1" maps an end
// of -1 to dims(i), i.e. the full extent as the doc describes. It also maps a
// start of -1 to dims(i), which is past the last element rather than onto it.
// Confirm that this matches the runtime kernel.
63 start = data.dims(i) + 1 + start;
66 end = data.dims(i) + 1 + end;
// Half-open extent [start, end). No clamping to [0, dims(i)] is visible in
// this excerpt, so out-of-range or inverted bounds would yield sizes outside
// that range (possibly negative).
68 dst_sizes[i] = end - start;
// A single output shape with the input's data type.
73 return vector<TensorShape>{
74 CreateTensorShape(dst_sizes, data.data_type())};
// The lambda's closing `})` (upstream line 75) is missing from this excerpt.
76 .Output(0,
"output",
"Sliced data tensor.")
// Associates this operator with the ONNX schema named "Slice".
77 .InheritOnnxSchema(
"Slice");
// Schema for the SliceGradient operator. It is declared here with no input or
// output-count constraints and no documentation attached.
79 OPERATOR_SCHEMA(SliceGradient);
// Gradient maker for Slice. It emits a single gradient operator whose only
// output is the gradient of input 0 (data); the starts/ends inputs receive
// no gradient.
// NOTE(review): several upstream lines are missing from this excerpt:
//   - 87-88 and 93-94, the op-type/name arguments to CreateOperatorDef
//     (presumably "SliceGradient");
//   - 91, between the two returns;
//   - 97-100, the closing braces.
// Confirm these against upstream.
82 struct GetSliceGradient :
public GradientMakerBase {
83 using GradientMakerBase::GradientMakerBase;
84 vector<OperatorDef> GetGradientDefs()
override {
// Bounds were supplied as input tensors, so forward them (I(1), I(2)) to the
// gradient op alongside the data and the output gradient GO(0).
// NOTE(review): this branch also triggers for exactly two inputs, in which
// case I(2) would not exist. The schema's input-count constraint is not
// visible here — confirm that it forbids 2 inputs.
85 if (def_.input_size() > 1) {
86 return vector<OperatorDef>{CreateOperatorDef(
89 std::vector<string>{I(0), I(1), I(2), GO(0)},
90 std::vector<string>{GI(0)})};
// Single-input form: only data and the output gradient are passed. The
// bounds are presumably carried by the op's "starts"/"ends" arguments.
92 return vector<OperatorDef>{CreateOperatorDef(
95 std::vector<string>{I(0), GO(0)},
96 std::vector<string>{GI(0)})};
// Registers GetSliceGradient as the gradient maker for the Slice operator.
101 REGISTER_GRADIENT(Slice, GetSliceGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...