Caffe2 - C++ API
A deep learning, cross-platform ML framework
spatial_batch_norm_op.cc
1 #include "caffe2/mobile/contrib/arm-compute/core/context.h"
2 #include "caffe2/mobile/contrib/arm-compute/core/operator.h"
3 
4 #include "caffe2/operators/spatial_batch_norm_op.h"
5 
6 namespace caffe2 {
7 
8 template <typename T> class GLSpatialBNOp final : public Operator<GLContext> {
9 public:
10  GLSpatialBNOp(const OperatorDef &operator_def, Workspace *ws)
11  : Operator<GLContext>(operator_def, ws),
12  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
13  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)),
14  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9)),
15  order_(StringToStorageOrder(
16  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) { }
17  virtual ~GLSpatialBNOp() noexcept {}
18  USE_OPERATOR_FUNCTIONS(GLContext);
19  bool RunOnDevice() override;
20  protected:
21  bool is_test_;
22  double epsilon_;
23  double momentum_;
24  StorageOrder order_;
25  INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
26  OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR);
27 private:
28  arm_compute::GCBatchNormalizationLayer bn_layer_;
29  bool first_run_ = true, second_run_ = true;
30  GLContext::deleted_unique_ptr<const GLTensor<T>> X_, mean_, var_, bias_, scale_;
31 };
32 
33 template <typename T>
35  auto *XBlob = OperatorBase::Inputs()[0];
36  auto *scaleBlob = OperatorBase::Inputs()[SCALE];
37  auto *biasBlob = OperatorBase::Inputs()[BIAS];
38  auto *meanBlob = OperatorBase::Inputs()[EST_MEAN];
39  auto *varBlob = OperatorBase::Inputs()[EST_VAR];
40 
41  if (first_run_) {
42  X_ = GLContext::getGLTensor<T>(XBlob);
43  scale_ = GLContext::getGLTensor<T>(scaleBlob);
44  bias_ = GLContext::getGLTensor<T>(biasBlob);
45  mean_ = GLContext::getGLTensor<T>(meanBlob);
46  var_ = GLContext::getGLTensor<T>(varBlob);
47  }
48 
49  auto C = X_->dim32(1);
50  CAFFE_ENFORCE_EQ(scale_->ndim(), 1);
51  CAFFE_ENFORCE_EQ(bias_->ndim(), 1);
52  CAFFE_ENFORCE_EQ(mean_->ndim(), 1);
53  CAFFE_ENFORCE_EQ(var_->ndim(), 1);
54 
55  CAFFE_ENFORCE_EQ(scale_->dim32(0), C);
56  CAFFE_ENFORCE_EQ(bias_->dim32(0), C);
57  CAFFE_ENFORCE_EQ(mean_->dim32(0), C);
58  CAFFE_ENFORCE_EQ(var_->dim32(0), C);
59 
60  GLTensor<T> *Y =
61  OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>();
62  if (first_run_) {
63  first_run_ = false;
64  Y->ResizeLike(*X_);
65  bn_layer_.configure(X_->get_underlying(), Y->get_underlying(),
66  mean_->get_underlying(), var_->get_underlying(),
67  bias_->get_underlying(), scale_->get_underlying(), epsilon_);
68  } else {
69  X_->lazy_allocate(XBlob, second_run_, true);
70  scale_->lazy_allocate(scaleBlob, second_run_, second_run_);
71  bias_->lazy_allocate(biasBlob, second_run_, second_run_);
72  mean_->lazy_allocate(meanBlob, second_run_, second_run_);
73  var_->lazy_allocate(varBlob, second_run_, second_run_);
74  if (second_run_) {
75  second_run_ = false;
76  Y->allocate();
77  }
78  bn_layer_.run();
79  }
80  return true;
81 }
82 
// Presumably registers GLSpatialBNOp as the GL-backend implementation of the
// "SpatialBN" operator (confirm against REGISTER_GL_OPERATOR). DataType is not
// defined in this file.
83 REGISTER_GL_OPERATOR(SpatialBN, GLSpatialBNOp<DataType>);
84 
85 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...