1 #ifndef CAFFE2_OPERATORS_REDUCE_OPS_H_ 2 #define CAFFE2_OPERATORS_REDUCE_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/types.h" 8 #include "caffe2/utils/math.h" 9 #include "caffe2/utils/proto_utils.h" 13 template <
typename T,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
20 axes_ = OperatorBase::GetRepeatedArgument<int>(
"axes");
21 keepdims_ = OperatorBase::GetSingleArgument<int>(
"keepdims", 1);
24 bool RunOnDevice()
override {
25 int ndim = Input(0).ndim();
29 std::iota(axes_.begin(), axes_.end(), 0);
31 std::sort(axes_.begin(), axes_.end());
32 CAFFE_ENFORCE(axes_.front() >= 0,
"Axes ids must be non-negative.");
35 "Axes ids must be smaller than the dimensions of input.");
41 vector<TIndex> y_dims = X.dims();
42 TIndex Y_size = X.size();
43 for (TIndex
id = axes_.size() - 1;
id >= 0;
id--) {
44 TIndex reduced_axis = axes_[id];
45 Y_size /= y_dims[reduced_axis];
47 y_dims[reduced_axis] = 1;
49 y_dims.erase(y_dims.begin() + reduced_axis);
57 const_cast<vector<TIndex>&
>(X.dims()),
58 Y->template mutable_data<T>(),
73 vector<TIndex>& Y_dims,
77 std::vector<int> axes_;
81 template <
typename T,
class Context>
84 USE_OPERATOR_CONTEXT_FUNCTIONS;
97 vector<TIndex>& Y_dims,
98 int keepdims)
override;
101 template <
typename T,
class Context>
104 USE_OPERATOR_CONTEXT_FUNCTIONS;
113 vector<TIndex>& dims,
117 vector<TIndex>& Y_dims,
118 int keepdims)
override;
123 #endif // CAFFE2_OPERATORS_REDUCE_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...