Caffe2 - C++ API
A deep learning, cross platform ML framework
resize_op.cc
1 #include "caffe2/mobile/contrib/arm-compute/core/context.h"
2 #include "caffe2/mobile/contrib/arm-compute/core/operator.h"
3 #include "caffe2/operators/resize_op.h"
4 
5 namespace caffe2 {
6 
7 template<typename T>
8 class GLResizeNearestOp final : public Operator<GLContext> {
9 public:
10  GLResizeNearestOp(const OperatorDef &operator_def, Workspace *ws)
11  : Operator<GLContext>(operator_def, ws), width_scale_(1), height_scale_(1) {
12  if (HasArgument("width_scale")) {
13  width_scale_ = static_cast<float>(
14  OperatorBase::GetSingleArgument<float>("width_scale", 1));
15  }
16  if (HasArgument("height_scale")) {
17  height_scale_ = static_cast<float>(
18  OperatorBase::GetSingleArgument<float>("height_scale", 1));
19  }
20  CAFFE_ENFORCE_GT(width_scale_, 0);
21  CAFFE_ENFORCE_GT(height_scale_, 0);
22  }
23  virtual ~GLResizeNearestOp() noexcept {}
24  USE_OPERATOR_FUNCTIONS(GLContext);
25  bool RunOnDevice() override;
26 private:
27  float width_scale_;
28  float height_scale_;
29  arm_compute::GCScale resize_layer_;
30  bool first_run_ = true, second_run_ = true;
31  GLContext::deleted_unique_ptr<const GLTensor<T>> X_;
32 };
33 
34 template <typename T>
36 
37  auto* Xblob = OperatorBase::Inputs()[0];
38 
39  if (first_run_) {
40  X_ = GLContext::getGLTensor<T>(Xblob);
41  }
42 
43  auto N = X_->dim32(0);
44  auto C = X_->dim32(1);
45  auto H = X_->dim32(2);
46  auto W = X_->dim32(3);
47 
48  GLTensor<T> *Y =
49  OperatorBase::Outputs()[0]->template GetMutable<GLTensor<T>>();
50  if (first_run_) {
51  vector<TIndex> output_dims = {N, C, H * height_scale_, W * width_scale_};
52  Y->Resize(output_dims);
53  first_run_ = false;
54  resize_layer_.configure(X_->get_underlying(), Y->get_underlying(), arm_compute::InterpolationPolicy::NEAREST_NEIGHBOR, arm_compute::BorderMode::UNDEFINED);
55  } else {
56  X_->lazy_allocate(Xblob, second_run_, true);
57  if (second_run_) {
58  second_run_ = false;
59  Y->allocate();
60  }
61  resize_layer_.run();
62  }
63 
64  return true;
65 }
66 
// NOTE(review): presumably registers GLResizeNearestOp<DataType> as the GL
// implementation of the "ResizeNearest" operator; REGISTER_GL_OPERATOR and
// DataType are defined outside this file -- confirm against their definitions.
67 REGISTER_GL_OPERATOR(ResizeNearest, GLResizeNearestOp<DataType>);
68 
69 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37