2 #ifndef CAFFE2_OPERATORS_PREPEND_DIM_OP_H_ 3 #define CAFFE2_OPERATORS_PREPEND_DIM_OP_H_ 5 #include "caffe2/core/common_omp.h" 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 dim_size_(OperatorBase::GetSingleArgument<int64_t>(
"dim_size", 0)) {
20 dim_size_, 0,
"Argument dim_size must be greater than zero.");
23 bool RunOnDevice()
override {
24 auto& input = Input(0);
25 auto* output = Output(0);
27 CAFFE_ENFORCE(input.ndim() > 0,
"Input must be at least 1D.");
29 input.dim(0) % dim_size_ == 0,
30 "First dimension must be multiple of prepend_dim.");
32 vector<int64_t> actual_new_shape(input.ndim() + 1);
33 actual_new_shape[0] = dim_size_;
34 actual_new_shape[1] = input.dim(0) / dim_size_;
35 for (
int i = 1; i < input.dims().size(); ++i) {
36 actual_new_shape[i + 1] = input.dim(i);
38 output->Resize(actual_new_shape);
40 if (output != &input) {
42 context_.template CopyItems<Context, Context>(
46 output->raw_mutable_data(input.meta()));
55 template <
class Context>
58 USE_OPERATOR_CONTEXT_FUNCTIONS;
62 bool RunOnDevice()
override {
63 auto& input = Input(0);
64 auto* output = Output(0);
66 CAFFE_ENFORCE(input.ndim() > 1,
"Input must be at least 2D.");
68 vector<int64_t> actual_new_shape(input.ndim() - 1);
69 actual_new_shape[0] = input.dim(0) * input.dim(1);
70 for (
int i = 1; i < input.dims().size() - 1; ++i) {
71 actual_new_shape[i] = input.dim(i + 1);
73 output->Resize(actual_new_shape);
75 if (output != &input) {
77 context_.template CopyItems<Context, Context>(
81 output->raw_mutable_data(input.meta()));
92 #endif // CAFFE2_OPERATORS_PREPEND_DIM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...