Caffe2 - C++ API
A deep learning, cross platform ML framework
tile_op.h
1 #ifndef CAFFE2_OPERATORS_TILE_OP_H_
2 #define CAFFE2_OPERATORS_TILE_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 // Copy a Blob n times along a specified axis.
13 template <class Context>
14 class TileOp : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17  TileOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<Context>(operator_def, ws),
19  tiles_(OperatorBase::GetSingleArgument<int32_t>("tiles", 1)),
20  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 0)) {}
21  ~TileOp() {}
22 
23  bool RunOnDevice() override {
24  const auto& input = Input(0);
25  std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
26  if (InputSize() > 1) {
27  // We potentially have tiles and/or axis specified as inputs
28  // as well. We will check for them in that order. In other words:
29  // InputSize() == 2: tiles is specified
30  // InputSize() == 3: tiles is specified and axis.
31  // Anything specified as input will override the arguments
32  CAFFE_ENFORCE(
33  Input(1).ndim() == 1 && Input(1).size() == 1,
34  "Input `tiles` should be a vector of size 1.");
35 
36  const auto& input1 = Input(1);
37  context_.template CopyItems<Context, CPUContext>(
38  input1.meta(),
39  1,
40  static_cast<const char*>(input1.raw_data()),
41  &(temp_params[0]));
42 
43  if (InputSize() > 2) {
44  CAFFE_ENFORCE(
45  Input(2).ndim() == 1 && Input(2).size() == 1,
46  "Input `axis` should be a vector of size 1.");
47 
48  const auto& input2 = Input(2);
49  context_.template CopyItems<Context, CPUContext>(
50  input2.meta(),
51  1,
52  static_cast<const char*>(input2.raw_data()),
53  &(temp_params[1]));
54  } else {
55  CAFFE_ENFORCE(
57  "Argument `axis` is missing and was not specified as input.");
58  }
59  } else {
60  CAFFE_ENFORCE(
62  "Argument `tiles` is missing and was not specified as input.");
63  CAFFE_ENFORCE(
65  "Argument `axis` is missing and was not specified as input.");
66  }
67 
68  tiles_ = temp_params[0];
69  axis_ = temp_params[1];
70 
71  auto* output = Output(0);
72  const auto axis = input.canonical_axis_index(axis_);
73 
74  // reshape output to be input tiled along the axis
75  vector<TIndex> output_dims(input.dims());
76  output_dims[axis_] = output_dims[axis_] * tiles_;
77  output->Resize(output_dims);
78 
79  // size up to (and not including) axis
80  const auto outer_dim = input.size_to_dim(axis);
81  // size from axis up
82  const auto inner_dim = input.size_from_dim(axis);
83 
92  const char* input_data = static_cast<const char*>(input.raw_data());
93  char* output_data =
94  static_cast<char*>(output->raw_mutable_data(input.meta()));
95 
96  DoTile(
97  input.meta(),
98  input.itemsize(),
99  outer_dim,
100  inner_dim,
101  input_data,
102  output_data);
103 
104  return true;
105  }
106 
107  private:
108  void DoTile(
109  const TypeMeta& meta,
110  int item_size,
111  int outer_dim,
112  int inner_dim,
113  const char* input_data,
114  char* output_data) {
115  for (auto i = 0; i < outer_dim; ++i) {
116  for (auto t = 0; t < tiles_; ++t) {
117  context_.template CopyItems<Context, Context>(
118  meta, inner_dim, input_data, output_data);
119  output_data += inner_dim * item_size;
120  }
121  input_data += inner_dim * item_size;
122  }
123  }
124 
125  int32_t tiles_;
126  int32_t axis_;
127 };
128 
129 template <typename T, class Context>
130 class TileGradientOp : public Operator<Context> {
131  public:
132  USE_OPERATOR_CONTEXT_FUNCTIONS;
133  TileGradientOp(const OperatorDef& operator_def, Workspace* ws)
134  : Operator<Context>(operator_def, ws),
135  tiles_(OperatorBase::GetSingleArgument<int32_t>("tiles", 1)),
136  axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 0)) {}
137  ~TileGradientOp() {}
138 
139  bool RunOnDevice() override {
140  std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
141  if (InputSize() > 1) {
142  // We potentially have tiles and/or axis specified as inputs
143  // as well. We will check for them in that order. In other words:
144  // InputSize() == 2: tiles is specified
145  // InputSize() == 3: tiles is specified and axis.
146  // Anything specified as input will override the arguments
147  CAFFE_ENFORCE(
148  Input(1).ndim() == 1 && Input(1).size() == 1,
149  "Input `tiles` should be a vector of size 1.");
150 
151  const auto& input1 = Input(1);
152  context_.template CopyItems<Context, CPUContext>(
153  input1.meta(),
154  1,
155  static_cast<const char*>(input1.raw_data()),
156  &(temp_params[0]));
157 
158  if (InputSize() > 2) {
159  CAFFE_ENFORCE(
160  Input(2).ndim() == 1 && Input(2).size() == 1,
161  "Input `axis` should be a vector of size 1.");
162 
163  const auto& input2 = Input(2);
164  context_.template CopyItems<Context, CPUContext>(
165  input2.meta(),
166  1,
167  static_cast<const char*>(input2.raw_data()),
168  &(temp_params[1]));
169  } else {
170  CAFFE_ENFORCE(
172  "Argument `axis` is missing and was not specified as input.");
173  }
174  } else {
175  CAFFE_ENFORCE(
176  OperatorBase::HasArgument("tiles"),
177  "Argument `tiles` is missing and was not specified as input.");
178  CAFFE_ENFORCE(
180  "Argument `axis` is missing and was not specified as input.");
181  }
182 
183  tiles_ = temp_params[0];
184  axis_ = temp_params[1];
185 
186  const auto& input = Input(0);
187  auto* output = Output(0);
188  const auto axis = input.canonical_axis_index(axis_);
189 
190  // reshape output to be input "untiled" along the axis
191  vector<TIndex> output_dims(input.dims());
192  output_dims[axis_] = output_dims[axis_] / tiles_;
193  output->Resize(output_dims);
194 
195  // size up to (and not including) axis
196  const auto outer_dim = output->size_to_dim(axis);
197  // size from axis up
198  const auto inner_dim = output->size_from_dim(axis);
199 
210  const char* input_data = static_cast<const char*>(input.raw_data());
211  char* output_data =
212  static_cast<char*>(output->raw_mutable_data(input.meta()));
213 
214  DoTileGradient(
215  input.meta(),
216  input.itemsize(),
217  outer_dim,
218  inner_dim,
219  input_data,
220  output_data);
221 
222  return true;
223  }
224 
225  private:
226  void DoTileGradient(
227  const TypeMeta& meta,
228  int item_size,
229  int outer_dim,
230  int inner_dim,
231  const char* input_data,
232  char* output_data) {
233  for (auto i = 0; i < outer_dim; ++i) {
234  context_.template CopyItems<Context, Context>(
235  meta, inner_dim, input_data, output_data);
236  input_data += inner_dim * item_size;
237  for (auto t = 1; t < tiles_; ++t) {
238  math::Axpy<T, Context>(
239  inner_dim,
240  T(1),
241  reinterpret_cast<const T*>(input_data),
242  reinterpret_cast<T*>(output_data),
243  &context_);
244  input_data += inner_dim * item_size;
245  }
246  output_data += inner_dim * item_size;
247  }
248  }
249 
250  int32_t tiles_;
251  int32_t axis_;
252 };
253 
254 } // namespace caffe2
255 
256 #endif // CAFFE2_OPERATORS_TILE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
bool RunOnDevice() override
Definition: tile_op.h:23
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:88
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37
bool RunOnDevice() override
Definition: tile_op.h:139