1 #ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ 2 #define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
17 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)),
18 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.9f)),
19 order_(StringToStorageOrder(
20 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))),
21 num_batches_(OperatorBase::GetSingleArgument<int>(
"num_batches", 1)) {
24 (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 5));
25 CAFFE_ENFORCE_GT(epsilon_, 0);
26 CAFFE_ENFORCE_GE(momentum_, 0);
27 CAFFE_ENFORCE_LE(momentum_, 1);
31 bool RunOnDevice()
override {
41 INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR, SUMS, SUMSQ);
42 OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR);
45 template <
class Context>
48 USE_OPERATOR_CONTEXT_FUNCTIONS;
51 is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
52 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5f)),
53 order_(StringToStorageOrder(
54 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))),
55 num_batches_(OperatorBase::GetSingleArgument<int>(
"num_batches", 1)) {
56 CAFFE_ENFORCE(InputSize() == 5 || InputSize() == 7);
57 CAFFE_ENFORCE(OutputSize() == 3);
61 bool RunOnDevice()
override {
79 OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
84 #endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...