Caffe2 - C++ API
A deep learning, cross platform ML framework
reshape_op.cc
1 #include "caffe2/operators/reshape_op.h"
2 #include "caffe2/utils/math.h"
3 
4 namespace caffe2 {
5 
6 REGISTER_CPU_OPERATOR(Reshape, ReshapeOp<float, CPUContext>);
7 
8 OPERATOR_SCHEMA(Reshape)
9  .NumInputs(1, 2)
10  .NumOutputs(2)
11  .TensorInferenceFunction(
12  [](const OperatorDef& def, const vector<TensorShape>& in) {
13  vector<TensorShape> out(2);
14 
15  // Do shape inference for old_shape
16  out[1].set_data_type(TensorProto::INT64);
17  out[1].add_dims(in[0].dims_size());
18 
19  ArgumentHelper helper(def);
20  if (!helper.HasArgument("shape")) {
21  // Cannot do shape inference for reshaped tensor from runtime data.
22  CAFFE_ENFORCE_EQ(
23  in.size(),
24  2,
25  "New shape must be specified by either the input blob or the "
26  "argument `shape`.");
27  out[0].set_unknown_shape(true);
28  return out;
29  }
30  CAFFE_ENFORCE_EQ(
31  in.size(),
32  1,
33  "New shape must not be specified by the input blob and the "
34  "argument `shape` at the same time.");
35 
36  // Infer the actual new shape
37  auto actualNewShape = helper.GetRepeatedArgument<int64_t>("shape");
38 
39  // Copy over the dimensions for those that are specified zero
40  // and check the eligibility of input
41  for (int i = 0; i < actualNewShape.size(); ++i) {
42  CAFFE_ENFORCE_GE(
43  actualNewShape[i],
44  -1,
45  "The dimensions in argument `shape` "
46  "must not be a negative number.");
47 
48  if (actualNewShape[i] == 0) {
49  CAFFE_ENFORCE_LT(
50  i,
51  in[0].dims_size(),
52  "Argument `shape` has a dimension set to zero that exceeds "
53  "the original dimension size.");
54  actualNewShape[i] = in[0].dims(i);
55  }
56  }
57 
58  // Check if the new shape is valid and fills in the missing dimension
59  // specified by -1.
60  int64_t totalSize = 1;
61  for (const auto d : in[0].dims()) {
62  totalSize *= d;
63  }
64  int64_t size = 1;
65  int unknownIdx = -1;
66  for (int i = 0; i < actualNewShape.size(); ++i) {
67  const auto dim = actualNewShape[i];
68  if (dim == -1) {
69  CAFFE_ENFORCE(
70  unknownIdx == -1,
71  "Argument `shape` has more than one missing dimension.");
72  unknownIdx = i;
73  } else {
74  size *= dim;
75  }
76  }
77 
78  if (unknownIdx != -1) {
79  CAFFE_ENFORCE(
80  totalSize % size == 0,
81  "Argument `shape` does not agree with the input data.",
82  " (",
83  totalSize,
84  " vs ",
85  size,
86  ")");
87  actualNewShape[unknownIdx] = totalSize / size;
88  } else {
89  CAFFE_ENFORCE_EQ(
90  totalSize,
91  size,
92  "Argument `shape` does not agree with the input data.",
93  " (",
94  totalSize,
95  " != ",
96  size,
97  ")");
98  }
99 
100  out[0].set_data_type(in[0].data_type());
101  for (const auto d : actualNewShape) {
102  out[0].add_dims(d);
103  }
104  return out;
105  })
106  .AllowInplace({{0, 0}})
107  .SetDoc(R"DOC(
108 Reshape the input tensor similar to numpy.reshape.
109 
110 It takes a tensor as input and an optional tensor specifying the new shape.
111 When the second input is absent, an extra argument `shape` must be specified.
112 It outputs the reshaped tensor as well as the original shape.
113 
114 At most one dimension of the new shape can be -1. In this case, the value is
115 inferred from the size of the tensor and the remaining dimensions. A dimension
116 could also be 0, in which case the actual dimension value is going to be copied
117 from the input tensor.
118 )DOC")
119  .Arg("shape", "New shape")
120  .Input(0, "data", "An input tensor.")
121  .Input(1, "new_shape", "New shape.")
122  .Output(0, "reshaped", "Reshaped data.")
123  .Output(1, "old_shape", "Original shape.")
124  .InheritOnnxSchema("Reshape");
125 
127  using GradientMakerBase::GradientMakerBase;
128  vector<OperatorDef> GetGradientDefs() override {
129  return SingleGradientDef(
130  "Reshape",
131  "",
132  vector<string>{GO(0), O(1)},
133  vector<string>{GI(0), "_" + GI(0) + "_dims"});
134  }
135 
136  // Argument `shape` is no longer needed in backprop.
137  bool CopyArguments() const override {
138  return false;
139  }
140 };
141 
142 REGISTER_GRADIENT(Reshape, GetReshapeGradient);
143 
144 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...