// Reshape operator (fragment of a scraped source listing; many original
// lines, including the class header and closing braces, are not present).
//
// Visible behavior:
//   * Output(0) receives Input(0)'s data with a new shape.
//   * The new shape comes either from the repeated `shape` argument
//     (read into new_shape_ in the constructor) or, when two inputs are
//     given, from the 1-D tensor Input(1).
//   * Output(1) receives the original dimensions of Input(0).
1 #ifndef CAFFE2_OPERATORS_RESHAPE_OP_H_ 2 #define CAFFE2_OPERATORS_RESHAPE_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 13 template <
typename F,
class Context>
// NOTE(review): template parameter F is not referenced anywhere in the
// visible portion of this class; its role cannot be confirmed from here.
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment: caches the repeated int64 argument "shape".
19 new_shape_(OperatorBase::GetRepeatedArgument<int64_t>(
"shape")) {}
// Entry point.
// With two inputs, the shape tensor Input(1) drives the reshape; the
// dispatch for that case is on lines not shown here. Otherwise the
// `shape` argument is used and the work runs with T = int64_t.
21 bool RunOnDevice()
override {
22 if (InputSize() == 2) {
27 return this->
template DoRunWithType<int64_t>();
// Typed wrapper: reshapes Input(0) into Output(0). The declaration of the
// template parameter T is on a line not shown here.
31 bool DoRunWithType() {
32 DoRunWithTypeImpl<T>(Input(0), Output(0));
// Does the actual reshape of `input` into `output` (parameter list
// partially missing from this listing).
38 void DoRunWithTypeImpl(
// Start from the shape given by the `shape` argument.
41 vector<int64_t> actual_new_shape = new_shape_;
// Shape supplied as a second input.
// The `shape` argument must then be absent; the enforce call that carries
// the message below starts on a line not shown here.
42 if (InputSize() == 2) {
45 "New shape is specified by the input blob, do not pass in " 46 "the argument `shape`.");
48 auto& shape = Input(1);
49 CAFFE_ENFORCE(shape.ndim() == 1,
"Shape should be 1-D");
51 const T* shape_data = shape.template data<T>();
// Copy the shape values from the operator's Context into a host-side
// vector (CopyBytes<Context, CPUContext>), then widen them to int64_t.
// NOTE(review): `&tmpv[0]` is undefined behavior when shape.size() == 0
// (operator[] on an empty vector); `tmpv.data()` would be safe. Also,
// `tmpv.begin() + shape.size()` is the same as `tmpv.end()`.
54 std::vector<T> tmpv(shape.size());
55 context_.template CopyBytes<Context, CPUContext>(
56 shape.size() *
sizeof(T), shape_data, &tmpv[0]);
57 actual_new_shape.assign(tmpv.begin(), tmpv.begin() + shape.size());
// A 0 entry in the new shape means "keep the input's size at this
// position" and is replaced with input.dim(i).
// NOTE(review): `int i` is compared against an unsigned size(). If the new
// shape has more entries than input.ndim(), input.dim(i) would be out of
// range. Original lines 64-69 are not visible, so any guard cannot be
// confirmed.
61 for (
int i = 0; i < actual_new_shape.size(); ++i) {
62 if (actual_new_shape[i] == 0) {
63 actual_new_shape[i] = input.
dim(i);
// Presumably the total number of elements in the input
// (product of all dims from 0).
70 auto total_size = input.size_from_dim(0);
// Scan the new shape.
// At most one entry may be left for inference (see the error message
// below); its position is tracked in `unknown_idx`. `size` is presumably
// the product of the known entries. Its declaration and update, and the
// marker for a missing entry, are on lines not shown here.
73 for (
int i = 0; i < actual_new_shape.size(); ++i) {
74 const auto dim = actual_new_shape[i];
78 "Argument `shape` has more than one missing dimension.");
// Infer the missing dimension. The input size must be divisible by the
// product of the known dimensions.
// NOTE(review): if `size` can be 0 (a known dimension of 0), the `%` and
// `/` below divide by zero unless guarded on lines not shown.
85 if (unknown_idx != -1) {
87 total_size % size == 0,
88 "Argument `shape` does not agree with the input data.",
94 actual_new_shape[unknown_idx] = total_size / size;
// Presumably the no-missing-dimension branch, checking that the new shape
// covers exactly total_size elements (enforce call partially missing).
99 "Argument `shape` does not agree with the input data.",
// Output(1): a 1-D tensor of length input.ndim() holding the input's
// original dimensions.
// NOTE(review): this issues one math::Set call per dimension; since the
// dims are already known on the host, a single bulk copy could replace
// the loop.
108 auto* old_shape = Output(1);
109 old_shape->Resize(input.
ndim());
110 T* old_shape_data = old_shape->template mutable_data<T>();
111 for (
int i = 0; i < input.
ndim(); ++i) {
112 math::Set<T, Context>(1, input.
dim(i), old_shape_data + i, &context_);
// Give Output(0) the final shape. Data is copied only when output and
// input are different tensors; when they are the same tensor, only the
// shape changes. The CopyItems arguments are on lines not shown.
115 output->
Resize(actual_new_shape);
116 if (output != &input) {
118 context_.template CopyItems<Context, Context>(
// Shape read from the repeated `shape` argument in the constructor.
127 vector<int64_t> new_shape_;
132 #endif // CAFFE2_OPERATORS_RESHAPE_OP_H_ const TypeMeta & meta() const
Returns the TypeMeta object associated with the current data type.
TIndex dim(const int i) const
Returns the i-th dimension of the tensor.
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Resize(Ts...dim_source)
Resizes a tensor.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const void * raw_data() const
Returns a const raw void* pointer of the underlying storage.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
int ndim() const
Returns the number of dimensions of the data.
void * raw_mutable_data(const TypeMeta &meta)
Returns a mutable raw pointer of the underlying storage.