1 #include "caffe2/mobile/contrib/arm-compute/core/context.h" 2 #include "caffe2/mobile/contrib/arm-compute/core/operator.h" 4 #include "caffe2/operators/spatial_batch_norm_op.h" 12 is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
13 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5)),
14 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.9)),
15 order_(StringToStorageOrder(
16 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) { }
// Operator execution hook (overrides the base operator's RunOnDevice).
// NOTE(review): by Caffe2 convention the bool return signals success;
// the full definition is not visible here — confirm.
19 bool RunOnDevice()
override;
// Named input/output slots. The input tags are used as indices into
// OperatorBase::Inputs() (e.g. Inputs()[SCALE], Inputs()[EST_VAR]);
// presumably they expand to positional indices in the order listed — confirm
// against the INPUT_TAGS/OUTPUT_TAGS macro definitions.
// NOTE(review): only output slot 0 (Outputs()[0]) is referenced in the
// visible code; RUNNING_MEAN/RUNNING_VAR/SAVED_* are declared but their use
// (if any) is not shown.
25 INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
26 OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR);
// ARM Compute Library GLES batch-normalization function object. It is
// configured with the underlying tensors of X, Y, mean, var, bias and scale
// plus epsilon_.
28 arm_compute::GCBatchNormalizationLayer bn_layer_;
// Run-state flags, both initialised to true. second_run_ is passed as the
// flag arguments to lazy_allocate for the input tensors.
// NOTE(review): where these flags are cleared is not visible in this
// fragment — confirm they are reset after the first/second invocation.
29 bool first_run_ =
true, second_run_ =
true;
// GL tensor handles for the inputs, obtained via GLContext::getGLTensor<T>.
// They are stored as members rather than locals, presumably because
// bn_layer_ keeps references to their underlying tensors across calls —
// confirm.
30 GLContext::deleted_unique_ptr<const GLTensor<T>> X_, mean_, var_, bias_, scale_;
35 auto *XBlob = OperatorBase::Inputs()[0];
36 auto *scaleBlob = OperatorBase::Inputs()[SCALE];
37 auto *biasBlob = OperatorBase::Inputs()[BIAS];
38 auto *meanBlob = OperatorBase::Inputs()[EST_MEAN];
39 auto *varBlob = OperatorBase::Inputs()[EST_VAR];
42 X_ = GLContext::getGLTensor<T>(XBlob);
43 scale_ = GLContext::getGLTensor<T>(scaleBlob);
44 bias_ = GLContext::getGLTensor<T>(biasBlob);
45 mean_ = GLContext::getGLTensor<T>(meanBlob);
46 var_ = GLContext::getGLTensor<T>(varBlob);
49 auto C = X_->dim32(1);
50 CAFFE_ENFORCE_EQ(scale_->ndim(), 1);
51 CAFFE_ENFORCE_EQ(bias_->ndim(), 1);
52 CAFFE_ENFORCE_EQ(mean_->ndim(), 1);
53 CAFFE_ENFORCE_EQ(var_->ndim(), 1);
55 CAFFE_ENFORCE_EQ(scale_->dim32(0), C);
56 CAFFE_ENFORCE_EQ(bias_->dim32(0), C);
57 CAFFE_ENFORCE_EQ(mean_->dim32(0), C);
58 CAFFE_ENFORCE_EQ(var_->dim32(0), C);
61 OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>();
65 bn_layer_.configure(X_->get_underlying(), Y->get_underlying(),
66 mean_->get_underlying(), var_->get_underlying(),
67 bias_->get_underlying(), scale_->get_underlying(), epsilon_);
69 X_->lazy_allocate(XBlob, second_run_,
true);
70 scale_->lazy_allocate(scaleBlob, second_run_, second_run_);
71 bias_->lazy_allocate(biasBlob, second_run_, second_run_);
72 mean_->lazy_allocate(meanBlob, second_run_, second_run_);
73 var_->lazy_allocate(varBlob, second_run_, second_run_);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...