1 #ifndef CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ 2 #define CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
// Constructor fragment. NOTE(review): this excerpt is missing the
// constructor signature and the heads of the CAFFE_ENFORCE calls whose
// messages appear below; confirm details against the full source.
// Reads the two-element `old_shape` and `new_shape` repeated arguments,
// validates them, resolves -1 entries in `new_shape`, and stores the column
// counts of both shapes in old_stride_ / new_stride_ for RunOnDevice.
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Both shape arguments are required.
18 "Argument `old_shape` is missing.");
21 "Argument `new_shape` is missing.");
23 vector<TIndex> old_shape =
24 OperatorBase::GetRepeatedArgument<TIndex>(
"old_shape");
25 vector<TIndex> new_shape =
26 OperatorBase::GetRepeatedArgument<TIndex>(
"new_shape");
// Each shape must describe a 2-D matrix: exactly two integers.
29 old_shape.size() == 2,
30 "Argument `old_shape` must contain exactly two integers.");
32 new_shape.size() == 2,
33 "Argument `new_shape` must contain exactly two integers.");
37 "The second dimension in argument `old_shape` must be positive.");
// Old stride = old column count; RunOnDevice flattens (row, col) as
// row * old_stride_ + col.
39 old_stride_ = old_shape[1];
// old_shape[0] == -1 gets special handling. In this excerpt the only visible
// requirement for that case is that new_shape[1] be positive.
// NOTE(review): presumably -1 means "row count unknown", so the new column
// count must be given explicitly -- confirm.
41 if (old_shape[0] == -1) {
44 "The second dimension in `new_shape` must be positive.");
// Otherwise the old row count must be positive, which makes the total
// element count known.
48 "The first dimension in `old_shape` must be positive.");
50 TIndex matrix_size = old_shape[0] * old_shape[1];
// A -1 in new_shape[0]: at most one dimension may be -1, and matrix_size
// must divide evenly by new_shape[1].
// NOTE(review): the line that fills in new_shape[0] is not in this excerpt;
// presumably it is matrix_size / new_shape[1].
52 if (new_shape[0] == -1) {
55 "Only one dimension in argument `new_shape` can be -1.");
57 matrix_size % new_shape[1] == 0,
58 "Argument `new_shape` does not agree with `old_shape`.");
// Otherwise new_shape[0] must be positive and new_shape[1] positive or -1.
61 new_shape[0] > 0 && (new_shape[1] == -1 || new_shape[1] > 0),
62 "Dimensions in argument `new_shape` must be positive or -1.");
// A -1 in new_shape[1] is inferred from the total element count.
63 if (new_shape[1] == -1) {
65 matrix_size % new_shape[0] == 0,
66 "Argument `new_shape` does not agree with `old_shape`.");
67 new_shape[1] = matrix_size / new_shape[0];
// Reshaping must preserve the number of elements.
70 new_shape[0] * new_shape[1] == matrix_size,
71 "Argument `new_shape` does not agree with `old_shape`.");
// New stride = new column count; RunOnDevice splits each flattened offset
// with it.
75 new_stride_ = new_shape[1];
// Remaps the coordinates of every stored (nonzero) entry from the old 2-D
// shape to the new one. Each (row, col) pair is flattened to the linear
// offset row * old_stride_ + col and then split with new_stride_ into a new
// row (quotient) and new column (remainder).
// NOTE(review): this excerpt omits several original lines (e.g. 83, 85, 88,
// 91-93, and 104 onward, which presumably include output sizing and the
// return); verify against the full source.
78 bool RunOnDevice()
override {
// NOTE(review): Input(0) is bound to `old_col`, but its error message says
// "Row index tensor"; Input(1) is bound to `old_row`, but its message says
// "Column index tensor". The two labels look swapped -- confirm the input
// order and correct the messages.
79 auto& old_col = Input(0);
80 CAFFE_ENFORCE(old_col.ndim() == 1,
"Row index tensor must be 1-D.");
81 auto& old_row = Input(1);
82 CAFFE_ENFORCE(old_row.ndim() == 1,
"Column index tensor must be 1-D.");
// nnz: number of stored entries. Both coordinate tensors must have this
// length.
84 const auto nnz = old_col.size();
86 old_row.size() == nnz,
87 "Column and row tensors must have the same size.");
89 auto* new_col = Output(0);
90 auto* new_row = Output(1);
// Column indices are accessed as TIndex and row indices as int, for both
// the inputs and the outputs.
94 const auto* old_col_data = old_col.template data<TIndex>();
95 const auto* old_row_data = old_row.template data<int>();
97 auto* new_col_data = new_col->template mutable_data<TIndex>();
98 auto* new_row_data = new_row->template mutable_data<int>();
// NOTE(review): `i` is an int, while nnz's type is deduced from size().
// If nnz can exceed INT_MAX, this counter could overflow; consider matching
// nnz's type.
100 for (
int i = 0; i < nnz; ++i) {
// Linear position of this entry in the old row-major layout, held in TIndex.
101 TIndex offset = old_row_data[i] * old_stride_ + old_col_data[i];
102 new_row_data[i] = offset / new_stride_;
103 new_col_data[i] = offset % new_stride_;
116 #endif // CAFFE2_OPERATORS_SPARSE_MATRIX_RESHAPE_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.