1 #ifndef CAFFE2_OPERATORS_MEAN_OPS_H_ 2 #define CAFFE2_OPERATORS_MEAN_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/types.h" 9 #include "caffe2/utils/math.h" 10 #include "caffe2/utils/proto_utils.h" 14 template <
class Context>
17 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 USE_SIMPLE_CTOR_DTOR(
MeanOp)
21 bool DoRunWithType() {
22 auto& input0 = Input(0);
23 auto* output = Output(0);
25 output->ResizeLike(input0);
26 output->CopyFrom(input0, &context_);
28 if (InputSize() == 1) {
33 for (
int i = 1; i < InputSize(); ++i) {
34 if (output->dims() != Input(i).dims()) {
36 "Check failed: output->dims() == Input(i).dims().",
37 "Description: Input #",
41 " should match output dimension: ",
46 T* output_data = output->template mutable_data<T>();
47 for (
int i = 1; i < InputSize(); ++i) {
51 Input(i).template data<T>(),
66 bool RunOnDevice()
override {
67 if (Input(0).
template IsType<float>()) {
68 return DoRunWithType<float>();
71 "Mean operator only supports 32-bit float, but",
72 " input was of type ",
73 Input(0).meta().name());
78 template <
class Context>
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
87 bool DoRunWithType() {
89 const auto* dY_data = dY.template data<T>();
92 int num_inputs = OutputSize();
93 float scale = 1.0f / num_inputs;
96 auto* dX0 = Output(0);
99 size, scale, dY_data, dX0->template mutable_data<T>(), &context_);
102 for (
int i = 1; i < num_inputs; i++) {
103 auto* cur_dX = Output(i);
104 cur_dX->ResizeLike(dY);
105 cur_dX->CopyFrom(*dX0, &context_);
111 bool RunOnDevice()
override {
112 if (Input(0).
template IsType<float>()) {
113 return DoRunWithType<float>();
116 "Mean operator only supports 32-bit float, but",
117 " input was of type ",
118 Input(0).meta().name());
125 #endif // CAFFE2_OPERATORS_MEAN_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...