Caffe2 - C++ API
A deep learning, cross platform ML framework
pow_op.h
1 #ifndef CAFFE2_OPERATORS_POW_OP_H_
2 #define CAFFE2_OPERATORS_POW_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 // definition of NumericTypes and SameTypeAsInput is in below header file
10 #include "caffe2/operators/elementwise_op.h"
11 
12 namespace caffe2 {
13 
14 template <
15  typename InputTypes,
16  class Context,
17  class Functor,
18  class TypeMap = SameTypeAsInput>
19 class PowOp : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22 
23  PowOp(const OperatorDef& operator_def, Workspace* ws)
24  : Operator<Context>(operator_def, ws),
25  OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
26  OP_SINGLE_ARG(int, "axis", axis_, -1),
27  OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
28  OP_SINGLE_ARG(string, "order", order_, "NCHW"),
29  functor_() {
30  if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
31  exponent_ = this->template GetSingleArgument<float>(
32  "exponent", 0); // based on pow_ops.h
33  } else if (InputSize() == 2) { // BinaryElementwiseOp
34  // Figure out the correct axis to use.
35  if (enable_broadcast_) {
36  if (axis_ != -1) {
37  // Get axis from an explicit axis argument.
38  CAFFE_ENFORCE_EQ(
39  axis_str_.size(),
40  0,
41  "Args axis and axis_str cannot be used simultaneously.");
42  } else if (axis_str_.size()) {
43  // Get the axis index semantically.
44  CAFFE_ENFORCE_EQ(
45  axis_str_.size(), 1, "Unsupported axis string", axis_str_);
46  size_t semantic_axis_ = order_.find(axis_str_);
47  CAFFE_ENFORCE_NE(
48  semantic_axis_,
49  string::npos,
50  "Unrecognizable axis string ",
51  axis_str_,
52  " from order string ",
53  order_);
54  axis_ = semantic_axis_;
55  }
56  } else {
57  CAFFE_ENFORCE(
58  axis_ == -1 && axis_str_.size() == 0,
59  "Do not specify axis or axis_str if broadcast is not enabled.");
60  }
61  } else {
62  CAFFE_THROW(
63  "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
64  }
65  }
66 
67  bool RunOnDevice() override {
68  return DispatchHelper<InputTypes>::call(this, Input(0));
69  }
70 
71  template <typename T>
72  bool DoRunWithType() {
73  if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
74  const auto& A = Input(0);
75  auto* C = Output(0);
76  C->ResizeLike(A);
77  const T* Adata = A.template data<T>();
78  auto* Cdata =
79  C->template mutable_data<typename TypeMap::template type<T>>();
80  functor_.template Run<true, T, float, T>(
81  A.size(), Adata, NULL, exponent_, Cdata, &context_);
82  } else if (InputSize() == 2) { // BinaryElementwiseOp
83  const auto& A = Input(0);
84  const auto& B = Input(1);
85  auto* C = Output(0);
86  CAFFE_ENFORCE(
87  &B != C || !enable_broadcast_,
88  "In-place is allowed only with the first tensor when broadcasting");
89  C->ResizeLike(A);
90  const T* Adata = A.template data<T>();
91  const T* Bdata = B.template data<T>();
92  auto* Cdata =
93  C->template mutable_data<typename TypeMap::template type<T>>();
94  if (!enable_broadcast_) {
95  CAFFE_ENFORCE_EQ(
96  A.dims(),
97  B.dims(),
98  "Dimension mismatch - did you forget to set broadcast=1?");
99  functor_.template Run<false, T, T, T>(
100  A.size(), Adata, Bdata, 0, Cdata, &context_);
101  } else if (B.size() == 1) {
102  functor_.template Run<true, T, T, T>(
103  A.size(), Adata, Bdata, 0, Cdata, &context_);
104  } else {
105  size_t pre, n, post;
106  std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
107  if (post == 1) {
108  functor_.template RunWithBroadcast<T, T, T>(
109  Adata, Bdata, Cdata, pre, n, &context_);
110  } else {
111  functor_.template RunWithBroadcast2<T, T, T>(
112  Adata, Bdata, Cdata, pre, n, post, &context_);
113  }
114  }
115  } else {
116  CAFFE_THROW(
117  "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
118  }
119  return true;
120  }
121 
122  private:
123  bool enable_broadcast_;
124  int axis_;
125  string axis_str_;
126  string order_;
127  float exponent_;
128  Functor functor_;
129 };
130 
131 } // namespace caffe2
132 
133 #endif // CAFFE2_OPERATORS_POW_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:37