Caffe2 - C++ API
A deep learning, cross-platform ML framework
reshape_op.h
1 #ifndef CAFFE2_OPERATORS_RESHAPE_OP_H_
2 #define CAFFE2_OPERATORS_RESHAPE_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 // Takes a shape and data tensor and reshapes it
13 template <typename F, class Context>
14 class ReshapeOp : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<Context>(operator_def, ws),
19  new_shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
20 
21  bool RunOnDevice() override {
22  if (InputSize() == 2) {
23  return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
24  }
25  CAFFE_ENFORCE(
26  OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
27  return this->template DoRunWithType<int64_t>();
28  }
29 
30  template <typename T>
31  bool DoRunWithType() {
32  DoRunWithTypeImpl<T>(Input(0), Output(0));
33  return true;
34  }
35 
36  protected:
37  template <typename T>
38  void DoRunWithTypeImpl(
39  const Tensor<Context>& input,
40  Tensor<Context>* output) {
41  vector<int64_t> actual_new_shape = new_shape_;
42  if (InputSize() == 2) {
43  CAFFE_ENFORCE(
44  !OperatorBase::HasArgument("shape"),
45  "New shape is specified by the input blob, do not pass in "
46  "the argument `shape`.");
47 
48  auto& shape = Input(1);
49  CAFFE_ENFORCE(shape.ndim() == 1, "Shape should be 1-D");
50 
51  const T* shape_data = shape.template data<T>();
52 
53  // Bit awkward, but needed so works on both CPU and CUDA contexts
54  std::vector<T> tmpv(shape.size());
55  context_.template CopyBytes<Context, CPUContext>(
56  shape.size() * sizeof(T), shape_data, &tmpv[0]);
57  actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.size());
58  }
59 
60  // Copy over the dimensions for those that are specified zero.
61  for (int i = 0; i < actual_new_shape.size(); ++i) {
62  if (actual_new_shape[i] == 0) {
63  actual_new_shape[i] = input.dim(i);
64  }
65  }
66 
67  // Checks if the new shape is valid and fills in the missing dimension
68  // specified by -1.
69  // NOTE: At most one dimension can be -1.
70  auto total_size = input.size_from_dim(0);
71  T size = 1;
72  int unknown_idx = -1;
73  for (int i = 0; i < actual_new_shape.size(); ++i) {
74  const auto dim = actual_new_shape[i];
75  if (dim == -1) {
76  CAFFE_ENFORCE(
77  unknown_idx == -1,
78  "Argument `shape` has more than one missing dimension.");
79  unknown_idx = i;
80  } else {
81  size *= dim;
82  }
83  }
84 
85  if (unknown_idx != -1) {
86  CAFFE_ENFORCE(
87  total_size % size == 0,
88  "Argument `shape` does not agree with the input data.",
89  " (",
90  total_size,
91  " vs ",
92  size,
93  ")");
94  actual_new_shape[unknown_idx] = total_size / size;
95  } else {
96  CAFFE_ENFORCE_EQ(
97  total_size,
98  size,
99  "Argument `shape` does not agree with the input data.",
100  " (",
101  total_size,
102  " != ",
103  size,
104  ")");
105  }
106 
107  // Write the original shape to the second output.
108  auto* old_shape = Output(1);
109  old_shape->Resize(input.ndim());
110  T* old_shape_data = old_shape->template mutable_data<T>();
111  for (int i = 0; i < input.ndim(); ++i) {
112  math::Set<T, Context>(1, input.dim(i), old_shape_data + i, &context_);
113  }
114 
115  output->Resize(actual_new_shape);
116  if (output != &input) {
117  // If we are not doing in-place computation, a copy is needed.
118  context_.template CopyItems<Context, Context>(
119  input.meta(),
120  input.size(),
121  input.raw_data(),
122  output->raw_mutable_data(input.meta()));
123  }
124  }
125 
126  private:
127  vector<int64_t> new_shape_;
128 };
129 
130 } // namespace caffe2
131 
132 #endif // CAFFE2_OPERATORS_RESHAPE_OP_H_
const TypeMeta & meta() const
Returns the TypeMeta object associated with the current data type.
Definition: tensor.h:648
TIndex dim(const int i) const
Returns the i-th dimension of the tensor.
Definition: tensor.h:671
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:93
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const void * raw_data() const
Returns a const raw void* pointer of the underlying storage.
Definition: tensor.h:472
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37
int ndim() const
Returns the number of dimensions of the data.
Definition: tensor.h:589
void * raw_mutable_data(const TypeMeta &meta)
Returns a mutable raw pointer of the underlying storage.
Definition: tensor.h:510