// Include guard and core Caffe2 headers, followed by the opening of a template
// declaration parameterized on an element type T and a Context type.
// NOTE(review): the numbers embedded in these lines jump (2 -> 8, 14 -> 17,
// 17 -> 21), so the class name / base class and the constructor signature are
// not visible in this excerpt -- confirm against the full header.
1 #ifndef CAFFE2_OPERATORS_ARG_OPS_H_ 2 #define CAFFE2_OPERATORS_ARG_OPS_H_ 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/types.h" 14 template <
typename T,
class Context>
17 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Tail of a constructor initializer list (the constructor body is empty).
// Presumably OP_SINGLE_ARG(type, name, member, default) binds the operator
// argument `name` to `member`, falling back to `default` -- TODO confirm
// against the macro definition in caffe2/core/operator.h.
// "axis" is stored in axis_ with -1 given as the last macro argument.
21 OP_SINGLE_ARG(
int,
"axis", axis_, -1),
// "keepdims" is stored in keep_dims_ with true given as the last macro argument.
22 OP_SINGLE_ARG(
bool,
"keepdims", keep_dims_,
true) {}
// Operator entry point. From the visible lines: reads input 0, validates
// axis_ against the input rank, builds the output shape Y_dims from every
// input dimension except axis_, and accumulates the element counts before
// (prev_size) and after (next_size) axis_.
// NOTE(review): several lines of this method are missing from this excerpt
// (the embedded numbering skips 26, 28-30, 35-37, 41-44, 48-49, 51-55), so the
// declarations of prev_size / next_size, the loop closing braces, any use of
// keep_dims_, and the return statement are not shown here.
24 bool RunOnDevice()
override {
25 const auto& X = Input(0);
27 const int ndim = X.ndim();
// NOTE(review): the "axis" argument is set up with -1 as its last macro
// argument, but the checks below require 0 <= axis_ < ndim. Presumably the
// lines not visible here normalize a negative axis before these checks --
// confirm.
31 CAFFE_ENFORCE_GE(axis_, 0);
32 CAFFE_ENFORCE_LT(axis_, ndim);
33 const std::vector<TIndex>& X_dims = X.dims();
34 std::vector<TIndex> Y_dims;
// Dimensions before axis_: appended to Y_dims and multiplied into prev_size.
38 for (
int i = 0; i < axis_; ++i) {
39 Y_dims.push_back(X_dims[i]);
40 prev_size *= X_dims[i];
// Dimensions after axis_: appended to Y_dims and multiplied into next_size.
// The axis_ dimension itself is not pushed by either visible loop.
45 for (
int i = axis_ + 1; i < ndim; ++i) {
46 Y_dims.push_back(X_dims[i]);
47 next_size *= X_dims[i];
// Extent of the input along axis_.
50 const TIndex n = X_dims[axis_];
// Last argument of a call whose beginning is not visible: the output buffer
// of Y, requested as TIndex elements.
56 Y->template mutable_data<TIndex>());
// Two parameters of a declaration whose first line is not visible in this
// excerpt. Presumably they receive the prev_size / next_size products computed
// in RunOnDevice -- TODO confirm against the full header.
62 const TIndex prev_size,
63 const TIndex next_size,
// Member initialized from the "keepdims" operator argument in the constructor.
69 const bool keep_dims_;
// Second template declaration in this header, again parameterized on T and
// Context.
// NOTE(review): the line naming the class (and any base class) is not visible
// in this excerpt; only the context macro and two parameters of a member
// declaration are shown.
72 template <
typename T,
class Context>
75 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Parameters named like the prev_size / next_size values computed in the
// RunOnDevice shown earlier in this file; the enclosing function name is not
// visible here.
83 const TIndex prev_size,
84 const TIndex next_size,
// Third template declaration in this header, parameterized on T and Context,
// with the same visible shape as the preceding one.
// NOTE(review): the line naming the class (and any base class) is not visible
// in this excerpt; only the context macro and two parameters of a member
// declaration are shown.
89 template <
typename T,
class Context>
92 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Parameters named like the prev_size / next_size values computed in the
// RunOnDevice shown earlier in this file; the enclosing function name is not
// visible here.
100 const TIndex prev_size,
101 const TIndex next_size,
108 #endif // CAFFE2_OPERATORS_ARG_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...