Caffe2 - C++ API
A deep learning, cross-platform ML framework
fully_connected_op.cc
1 #include "caffe2/mobile/contrib/arm-compute/core/context.h"
2 #include "caffe2/mobile/contrib/arm-compute/core/operator.h"
3 
4 #include "caffe2/operators/fully_connected_op.h"
5 
6 namespace caffe2 {
7 
8 template <typename T> class GLFullyConnectedOp final : public Operator<GLContext> {
9 public:
10  GLFullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
11  : Operator<GLContext>(operator_def, ws) {}
12  virtual ~GLFullyConnectedOp() noexcept {}
13  USE_OPERATOR_FUNCTIONS(GLContext);
14  bool RunOnDevice() override;
15 private:
16  arm_compute::GCFullyConnectedLayer fc_layer_;
17  bool first_run_ = true, second_run_ = true;
18  GLContext::deleted_unique_ptr<const GLTensor<T>> X_, W_, B_;
19 };
20 
21 template <typename T>
23 
24  auto Xblob = OperatorBase::Inputs()[0];
25  auto *Wblob = OperatorBase::Inputs()[1];
26  auto *Bblob = OperatorBase::Inputs()[2];
27 
28  if (first_run_) {
29  X_ = GLContext::getGLTensor<T>(Xblob);
30  W_ = GLContext::getGLTensor<T>(Wblob);
31  B_ = GLContext::getGLTensor<T>(Bblob);
32  }
33 
34  auto M = X_->dim32(0);
35  auto CIn = X_->dim32(1);
36  auto Height = X_->dim32(2);
37  auto Width = X_->dim32(3);
38  auto N = W_->dim32(0);
39 
40  CAFFE_ENFORCE_EQ(1, B_->ndim());
41  CAFFE_ENFORCE_EQ(N, B_->dim32(0));
42 
43  vector<TIndex> output_dims = {M, N};
44  GLTensor<T> *Y =
45  OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>();
46  if (first_run_) {
47  first_run_ = false;
48  Y->Resize(output_dims);
49 
50  fc_layer_.configure(X_->get_underlying(), W_->get_underlying(),
51  B_->get_underlying(), Y->get_underlying(), true, false);
52  } else {
53  X_->lazy_allocate(Xblob, second_run_, true);
54  W_->lazy_allocate(Wblob, second_run_, second_run_);
55  B_->lazy_allocate(Bblob, second_run_, second_run_);
56  if (second_run_) {
57  second_run_ = false;
58  Y->allocate();
59  }
60  fc_layer_.run();
61  }
62 
63  return true;
64 }
65 
66 REGISTER_GL_OPERATOR(FC, GLFullyConnectedOp<DataType>);
67 
68 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...