// Include guard, dependencies, and the opening of the tile operator class.
// NOTE(review): this listing is a scraped source view — several original
// lines (e.g. the class declaration and the constructor signature) are not
// present here, so only the member initializers are visible.
1 #ifndef CAFFE2_OPERATORS_TILE_OP_H_ 2 #define CAFFE2_OPERATORS_TILE_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 13 template <
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator arguments: `tiles` (number of copies, default 1) and `axis`
// (dimension to tile along, default 0). Either may later be overridden by
// an optional input tensor in RunOnDevice.
19 tiles_(OperatorBase::GetSingleArgument<int32_t>(
"tiles", 1)),
20 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 0)) {}
// Resolve the effective `tiles` and `axis` for this run. The values read from
// operator arguments seed `temp_params`; optional inputs 1 and 2 override them.
// NOTE(review): the enforce openers, the CopyItems destination arguments and
// the else/HasArgument branches are missing from this listing.
24 const auto& input = Input(0);
25 std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
// Input 1 (optional) supplies `tiles`; it must be a 1-element vector.
26 if (InputSize() > 1) {
33 Input(1).ndim() == 1 && Input(1).size() == 1,
34 "Input `tiles` should be a vector of size 1.");
36 const auto& input1 = Input(1);
// Copy the single value from the operator's Context into host (CPUContext)
// memory — presumably into temp_params[0]; the destination lines are not shown.
37 context_.template CopyItems<Context, CPUContext>(
40 static_cast<const char*
>(input1.raw_data()),
// Input 2 (optional) supplies `axis`, likewise a 1-element vector, presumably
// copied into temp_params[1].
43 if (InputSize() > 2) {
45 Input(2).ndim() == 1 && Input(2).size() == 1,
46 "Input `axis` should be a vector of size 1.");
48 const auto& input2 = Input(2);
49 context_.template CopyItems<Context, CPUContext>(
52 static_cast<const char*
>(input2.raw_data()),
// Presumably raised when a value is neither an operator argument nor an
// input; the guarding conditions are on lines missing from this listing.
57 "Argument `axis` is missing and was not specified as input.");
62 "Argument `tiles` is missing and was not specified as input.");
65 "Argument `axis` is missing and was not specified as input.");
// Persist the resolved values into the members read by the shape
// computation and the tiling loop.
68 tiles_ = temp_params[0];
69 axis_ = temp_params[1];
71 auto* output = Output(0);
72 const auto axis = input.canonical_axis_index(axis_);
75 vector<TIndex> output_dims(input.dims());
76 output_dims[axis_] = output_dims[axis_] * tiles_;
77 output->Resize(output_dims);
80 const auto outer_dim = input.size_to_dim(axis);
82 const auto inner_dim = input.size_from_dim(axis);
// Type-erased byte views of the input and output buffers. The output is
// allocated (raw_mutable_data) with the input's element type (input.meta()).
// NOTE(review): original line 93 — presumably the `char* output_data =`
// declaration — is missing from this listing.
92 const char* input_data =
static_cast<const char*
>(input.raw_data());
94 static_cast<char*
>(output->raw_mutable_data(input.meta()));
// Tiling loop (signature only partly visible). For each of the `outer_dim`
// outer slices, the `inner_dim`-item block at `input_data` is copied `tiles_`
// times back-to-back into `output_data`. Pointers advance in bytes
// (inner_dim * item_size — item_size is presumably bytes per element) because
// the data is type-erased. Judging by the original line numbers, the input
// advance (121) sits after the inner loop, i.e. once per outer slice.
// NOTE(review): `auto i = 0` / `auto t = 0` deduce `int`; if `outer_dim` is a
// 64-bit count, very large tensors could overflow the counter — confirm.
113 const char* input_data,
115 for (
auto i = 0; i < outer_dim; ++i) {
116 for (
auto t = 0; t < tiles_; ++t) {
117 context_.template CopyItems<Context, Context>(
118 meta, inner_dim, input_data, output_data);
119 output_data += inner_dim * item_size;
121 input_data += inner_dim * item_size;
// Opening of the tile gradient operator (class declaration and constructor
// signature are missing from this listing). Unlike the forward operator it
// is also templated on the element type `T`, which its accumulation loop
// passes to math::Axpy.
129 template <
typename T,
class Context>
132 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Same argument defaults as the forward operator: tiles = 1, axis = 0.
135 tiles_(OperatorBase::GetSingleArgument<int32_t>(
"tiles", 1)),
136 axis_(OperatorBase::GetSingleArgument<int32_t>(
"axis", 0)) {}
// Resolve the effective `tiles` and `axis` for this run, exactly as in the
// forward operator: argument values seed `temp_params`, and optional inputs
// 1 and 2 override them.
// NOTE(review): enforce openers, CopyItems destinations and else/HasArgument
// branches are missing from this listing.
140 std::array<int32_t, 2> temp_params = {{tiles_, axis_}};
// Input 1 (optional) supplies `tiles`; it must be a 1-element vector.
141 if (InputSize() > 1) {
148 Input(1).ndim() == 1 && Input(1).size() == 1,
149 "Input `tiles` should be a vector of size 1.");
151 const auto& input1 = Input(1);
// Copy to host (CPUContext) memory, presumably into temp_params[0].
152 context_.template CopyItems<Context, CPUContext>(
155 static_cast<const char*
>(input1.raw_data()),
// Input 2 (optional) supplies `axis`, presumably copied into temp_params[1].
158 if (InputSize() > 2) {
160 Input(2).ndim() == 1 && Input(2).size() == 1,
161 "Input `axis` should be a vector of size 1.");
163 const auto& input2 = Input(2);
164 context_.template CopyItems<Context, CPUContext>(
167 static_cast<const char*
>(input2.raw_data()),
// Presumably raised when a value is neither an argument nor an input; the
// guarding conditions are on lines missing from this listing.
172 "Argument `axis` is missing and was not specified as input.");
177 "Argument `tiles` is missing and was not specified as input.");
180 "Argument `axis` is missing and was not specified as input.");
// Persist the resolved values for the shape computation and the loop.
183 tiles_ = temp_params[0];
184 axis_ = temp_params[1];
186 const auto& input = Input(0);
187 auto* output = Output(0);
188 const auto axis = input.canonical_axis_index(axis_);
191 vector<TIndex> output_dims(input.dims());
192 output_dims[axis_] = output_dims[axis_] / tiles_;
193 output->Resize(output_dims);
196 const auto outer_dim = output->size_to_dim(axis);
198 const auto inner_dim = output->size_from_dim(axis);
// Type-erased byte views of the input and output buffers; the output is
// allocated with the input's element type (input.meta()).
// NOTE(review): the `char* output_data =` declaration line appears to be
// missing from this listing.
210 const char* input_data =
static_cast<const char*
>(input.raw_data());
212 static_cast<char*
>(output->raw_mutable_data(input.meta()));
// Reverse of tiling (signature only partly visible). For each of the
// `outer_dim` outer slices:
// - the first of the `tiles_` consecutive input blocks is copied into the
//   output block;
// - each remaining block is accumulated into it with math::Axpy.
// The Axpy size/scale arguments are on lines missing from this listing —
// presumably inner_dim and alpha = 1; confirm.
// The input pointer advances once per consumed block. The output pointer
// advances once per outer slice.
// NOTE(review): `auto i = 0` / `auto t = 1` deduce `int`; confirm that is wide
// enough for `outer_dim`.
231 const char* input_data,
233 for (
auto i = 0; i < outer_dim; ++i) {
234 context_.template CopyItems<Context, Context>(
235 meta, inner_dim, input_data, output_data);
236 input_data += inner_dim * item_size;
237 for (
auto t = 1; t < tiles_; ++t) {
238 math::Axpy<T, Context>(
241 reinterpret_cast<const T*
>(input_data),
242 reinterpret_cast<T*>(output_data),
244 input_data += inner_dim * item_size;
246 output_data += inner_dim * item_size;
256 #endif // CAFFE2_OPERATORS_TILE_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
bool RunOnDevice() override
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
bool RunOnDevice() override