Caffe2 - C++ API
A deep learning, cross platform ML framework
expand_squeeze_dims_op.cc
1 #include "caffe2/operators/expand_squeeze_dims_op.h"
2 
3 namespace caffe2 {
4 REGISTER_CPU_OPERATOR(ExpandDims, ExpandDimsOp<CPUContext>);
5 REGISTER_CPU_OPERATOR(Squeeze, SqueezeOp<CPUContext>);
6 
7 OPERATOR_SCHEMA(ExpandDims)
8  .NumInputs(1)
9  .NumOutputs(1)
10  .AllowInplace({{0, 0}})
11  .TensorInferenceFunction([](const OperatorDef& def,
12  const vector<TensorShape>& in) {
13  ArgumentHelper helper(def);
14  auto dims = helper.template GetRepeatedArgument<int>("dims");
15  auto originalSize = dims.size();
16  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
17 
18  std::sort(dims.begin(), dims.end());
19  dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
20  if (dims.size() < originalSize) {
21  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
22  }
23 
24  CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
25  CAFFE_ENFORCE_GE(
26  in[0].dims_size() + dims.size(),
27  dims.back() + 1,
28  "Input needs at least ",
29  (1 + dims.back() - dims.size()),
30  " dimensions given `dims`.");
31 
32  vector<TensorShape> out(1);
33 
34  int cur_pos = 0;
35  int idx = 0;
36  for (const auto new_dim : dims) {
37  for (int i = cur_pos; i < new_dim; i++) {
38  out[0].add_dims(in[0].dims(idx++));
39  }
40  out[0].add_dims(1);
41  cur_pos = new_dim + 1;
42  }
43  for (; idx < in[0].dims_size(); idx++) {
44  out[0].add_dims(in[0].dims(idx));
45  }
46  out[0].set_data_type(in[0].data_type());
47  return out;
48  })
49  .SetDoc(R"DOC(
50 Insert single-dimensional entries to the shape of a tensor.
51 Takes one required argument `dims`, a list of dimensions that will be inserted.
52 Dimension indices in `dims` are as seen in the output tensor. For example:
53 
54  Given a tensor such that tensor.Shape() = [3, 4, 5], then
55  ExpandDims(tensor, dims=[0, 4]).Shape() == [1, 3, 4, 5, 1])
56 
57 If the same blob is provided in input and output, the operation is copy-free.
58 )DOC")
59  .Input(0, "data", "Original tensor")
60  .Output(0, "expanded", "Reshaped tensor with same data as input.");
61 
62 OPERATOR_SCHEMA(Squeeze)
63  .NumInputs(1)
64  .NumOutputs(1)
65  .AllowInplace({{0, 0}})
66  .SetDoc(R"DOC(
67 Remove single-dimensional entries from the shape of a tensor.
68 Takes a parameter `dims` with a list of dimension to squeeze.
69 If the same blob is provided in input and output, the operation is copy-free.
70 This is the exact inverse operation of ExpandDims given the same `dims` arg.
71 )DOC")
72  .Input(0, "data", "Tensors with at least max(dims) dimensions.")
73  .Output(0, "squeezed", "Reshaped tensor with same data as input.")
74  .TensorInferenceFunction([](const OperatorDef& def,
75  const vector<TensorShape>& in) {
76  ArgumentHelper helper(def);
77  auto dims = helper.template GetRepeatedArgument<int>("dims");
78  auto originalSize = dims.size();
79  std::sort(dims.begin(), dims.end());
80  dims.erase(std::unique(dims.begin(), dims.end()), dims.end());
81  if (dims.size() < originalSize) {
82  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
83  }
84  CAFFE_ENFORCE(dims.front() >= 0, "Dimension ids must be non-negative.");
85 
86  vector<TensorShape> out(1);
87  std::vector<int> newDims =
88  SqueezeOp<CPUContext>::ComputeDims(GetDimsVector(in[0]), dims);
89  out[0] = CreateTensorShape(newDims, in[0].data_type());
90  return out;
91  })
92  .InheritOnnxSchema("Squeeze");
93 
94 class GetSqueezeGradient : public GradientMakerBase {
95  using GradientMakerBase::GradientMakerBase;
96  vector<OperatorDef> GetGradientDefs() override {
97  return SingleGradientDef(
98  "ExpandDims", "", vector<string>{GO(0)}, vector<string>{GI(0)});
99  }
100 };
101 REGISTER_GRADIENT(Squeeze, GetSqueezeGradient);
102 
103 class GetExpandDimsGradient : public GradientMakerBase {
104  using GradientMakerBase::GradientMakerBase;
105  vector<OperatorDef> GetGradientDefs() override {
106  return SingleGradientDef(
107  "Squeeze", "", vector<string>{GO(0)}, vector<string>{GI(0)});
108  }
109 };
110 REGISTER_GRADIENT(ExpandDims, GetExpandDimsGradient);
111 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...