Caffe2 - C++ API
A deep-learning, cross-platform ML framework
resize_op.h
1 
2 #pragma once
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class ResizeNearestOp final : public Operator<Context> {
11  public:
12  ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
13  : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
14  if (HasArgument("width_scale")) {
15  width_scale_ = static_cast<T>(
16  OperatorBase::GetSingleArgument<float>("width_scale", 1));
17  }
18  if (HasArgument("height_scale")) {
19  height_scale_ = static_cast<T>(
20  OperatorBase::GetSingleArgument<float>("height_scale", 1));
21  }
22  CAFFE_ENFORCE_GT(width_scale_, 0);
23  CAFFE_ENFORCE_GT(height_scale_, 0);
24  }
25  USE_OPERATOR_CONTEXT_FUNCTIONS;
26 
27  bool RunOnDevice() override;
28 
29  protected:
30  T width_scale_;
31  T height_scale_;
32 };
33 
34 template <typename T, class Context>
35 class ResizeNearestGradientOp final : public Operator<Context> {
36  public:
37  ResizeNearestGradientOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<Context>(operator_def, ws), width_scale_(1), height_scale_(1) {
39  width_scale_ = static_cast<T>(
40  OperatorBase::GetSingleArgument<float>("width_scale", 1));
41  height_scale_ = static_cast<T>(
42  OperatorBase::GetSingleArgument<float>("height_scale", 1));
43  CAFFE_ENFORCE_GT(width_scale_, 0);
44  CAFFE_ENFORCE_GT(height_scale_, 0);
45  }
46  USE_OPERATOR_CONTEXT_FUNCTIONS;
47 
48  bool RunOnDevice() override;
49 
50  protected:
51  T width_scale_;
52  T height_scale_;
53 };
54 
55 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37