1 #ifndef CAFFE2_OPERATORS_POW_OP_H_ 2 #define CAFFE2_OPERATORS_POW_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 10 #include "caffe2/operators/elementwise_op.h" 18 class TypeMap = SameTypeAsInput>
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
25 OP_SINGLE_ARG(
bool,
"broadcast", enable_broadcast_, 0),
26 OP_SINGLE_ARG(
int,
"axis", axis_, -1),
27 OP_SINGLE_ARG(
string,
"axis_str", axis_str_,
""),
28 OP_SINGLE_ARG(
string,
"order", order_,
"NCHW"),
30 if ((InputSize() == 1) &&
HasArgument(
"exponent")) {
31 exponent_ = this->
template GetSingleArgument<float>(
33 }
else if (InputSize() == 2) {
35 if (enable_broadcast_) {
41 "Args axis and axis_str cannot be used simultaneously.");
42 }
else if (axis_str_.size()) {
45 axis_str_.size(), 1,
"Unsupported axis string", axis_str_);
46 size_t semantic_axis_ = order_.find(axis_str_);
50 "Unrecognizable axis string ",
52 " from order string ",
54 axis_ = semantic_axis_;
58 axis_ == -1 && axis_str_.size() == 0,
59 "Do not specify axis or axis_str if broadcast is not enabled.");
63 "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
67 bool RunOnDevice()
override {
72 bool DoRunWithType() {
73 if ((InputSize() == 1) &&
HasArgument(
"exponent")) {
74 const auto& A = Input(0);
77 const T* Adata = A.template data<T>();
79 C->template mutable_data<typename TypeMap::template type<T>>();
80 functor_.template Run<true, T, float, T>(
81 A.size(), Adata, NULL, exponent_, Cdata, &context_);
82 }
else if (InputSize() == 2) {
83 const auto& A = Input(0);
84 const auto& B = Input(1);
87 &B != C || !enable_broadcast_,
88 "In-place is allowed only with the first tensor when broadcasting");
90 const T* Adata = A.template data<T>();
91 const T* Bdata = B.template data<T>();
93 C->template mutable_data<typename TypeMap::template type<T>>();
94 if (!enable_broadcast_) {
98 "Dimension mismatch - did you forget to set broadcast=1?");
99 functor_.template Run<false, T, T, T>(
100 A.size(), Adata, Bdata, 0, Cdata, &context_);
101 }
else if (B.size() == 1) {
102 functor_.template Run<true, T, T, T>(
103 A.size(), Adata, Bdata, 0, Cdata, &context_);
106 std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
108 functor_.template RunWithBroadcast<T, T, T>(
109 Adata, Bdata, Cdata, pre, n, &context_);
111 functor_.template RunWithBroadcast2<T, T, T>(
112 Adata, Bdata, Cdata, pre, n, post, &context_);
117 "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
123 bool enable_broadcast_;
133 #endif // CAFFE2_OPERATORS_POW_OP_H_ Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.