// NOTE(review): this text looks like an extract of a rendered source listing.
// The bare integers (1, 7, 8, ...) appear to be listing line numbers, and
// several numbered lines (13-14, 17-22, 26-28, 31-37, 42-46, 48, 50-51) are
// absent, so the macro below is incomplete as shown and will not compile.
// Restore the missing lines from the original file before building.
//
// NAIVE_FUNCTOR(name, op, input_type, output_type) declares a struct named
// Naive<name>Functor and then opens a REGISTER_CPU_OPERATOR call that mentions
// BinaryElementwiseOp and that functor. Visible members of the struct:
//   * Run<b_is_scalar, T, R>(n, a, b, out, ctx): for each i in [0, n) stores
//     op(a[i], b[i]) into out[i]; when b_is_scalar is non-zero b[0] is used
//     for every element instead of b[i].
//   * RunWithBroadcast<T, R>: for i in [0, pre), j in [0, n) stores
//     op(a[i * n + j], b[j]) into out[i * n + j] (parameter list not visible).
//   * RunWithBroadcast2<T, R>: for i in [0, pre), j in [0, n), k in [0, post)
//     stores op(a[(i * n + j) * post + k], b[j]) into the same flat index of
//     out (parameter list not visible).
// NOTE(review): in Run the bound n is size_t while the index i is int, so the
// loop condition mixes signed and unsigned operands.
//
// The trailing part defines NAIVE_LT(x, y) as ((x) < (y)) and instantiates the
// macro as LT with input_type NumericTypes and output_type FixedType<bool>.
1 #include "caffe2/operators/elementwise_op.h" 7 #define NAIVE_FUNCTOR(name, op, input_type, output_type) \ 8 struct Naive##name##Functor { \ 9 template <int b_is_scalar, typename T, typename R> \ 10 inline void Run(size_t n, const T* a, const T* b, R* out, CPUContext*) { \ 11 for (int i = 0; i < n; ++i) { \ 12 out[i] = op(a[i], b[b_is_scalar ? 0 : i]); \ 15 template <typename T, typename R> \ 16 void RunWithBroadcast( \ 23 for (int i = 0; i < pre; ++i) { \ 24 for (int j = 0; j < n; ++j) { \ 25 out[i * n + j] = op(a[i * n + j], b[j]); \ 29 template <typename T, typename R> \ 30 void RunWithBroadcast2( \ 38 for (int i = 0; i < pre; ++i) { \ 39 for (int j = 0; j < n; ++j) { \ 40 for (int k = 0; k < post; ++k) { \ 41 out[(i * n + j) * post + k] = op(a[(i * n + j) * post + k], b[j]); \ 47 REGISTER_CPU_OPERATOR( \ 49 BinaryElementwiseOp< \ 52 Naive##name##Functor, \ 55 #define NAIVE_LT(x, y) ((x) < (y)) 56 NAIVE_FUNCTOR(LT, NAIVE_LT, NumericTypes, FixedType<bool>);
// NOTE(review): the leading integers on each line (58, 59, 61, ...) look like
// line numbers from a rendered source listing rather than program text.
//
// Each line below defines a two-argument comparison/logic macro and passes it
// to NAIVE_FUNCTOR as (name, op, input_type, output_type). Every instantiation
// uses FixedType<bool> as the output_type argument.

// LE: (x) <= (y), input_type NumericTypes.
58 #define NAIVE_LE(x, y) ((x) <= (y)) 59 NAIVE_FUNCTOR(LE, NAIVE_LE, NumericTypes, FixedType<bool>);
// GT: (x) > (y), input_type NumericTypes.
61 #define NAIVE_GT(x, y) ((x) > (y)) 62 NAIVE_FUNCTOR(GT, NAIVE_GT, NumericTypes, FixedType<bool>);
// GE: (x) >= (y), input_type NumericTypes.
64 #define NAIVE_GE(x, y) ((x) >= (y)) 65 NAIVE_FUNCTOR(GE, NAIVE_GE, NumericTypes, FixedType<bool>);
// EQ: (x) == (y). Unlike the ordering comparisons above, this one is
// instantiated with IntBoolTypes rather than NumericTypes.
67 #define NAIVE_EQ(x, y) ((x) == (y)) 68 NAIVE_FUNCTOR(EQ, NAIVE_EQ, IntBoolTypes, FixedType<bool>);
// And / Or / Xor use the bitwise operators &, | and ^ and are instantiated
// with BoolTypes.
// NOTE(review): presumably BoolTypes contains only bool, in which case the
// bitwise forms give the same truth values as the logical ones - confirm.
70 #define NAIVE_AND(x, y) ((x) & (y)) 71 NAIVE_FUNCTOR(And, NAIVE_AND, BoolTypes, FixedType<bool>);
73 #define NAIVE_OR(x, y) ((x) | (y)) 74 NAIVE_FUNCTOR(Or, NAIVE_OR, BoolTypes, FixedType<bool>);
76 #define NAIVE_XOR(x, y) ((x) ^ (y)) 77 NAIVE_FUNCTOR(Xor, NAIVE_XOR, BoolTypes, FixedType<bool>);
// NOTE(review): the leading integers (81, 82, 87) look like line numbers from
// a rendered source listing. Listing lines 83-86 and the header of the
// enclosing type are not present, so this fragment is incomplete as shown.
//
// Function-call operator taking an element count n, a read-only bool buffer x,
// a writable bool buffer y, and an unnamed CPUContext pointer.
81 inline void operator()(
const int n,
const bool* x,
bool* y,
CPUContext*) {
// Index loop over [0, n); the loop body is not visible in this listing.
82 for (
int i = 0; i < n; ++i) {
// Start of an operator registration; its arguments are not visible here.
87 REGISTER_CPU_OPERATOR(
// Reduces the n values starting at x to a single value stored in *y.
// The value written is .sum() of a ConstEigenArrayMap<T> constructed over x
// with dimension arguments (n, 1).
// NOTE(review): the leading integers (92, 93) look like listing line numbers.
// Listing lines 91 and 94 (presumably the template header and the closing
// brace) are absent from this extract - confirm against the original file.
92 void SRLHelper::sum2one(
const T* x, T* y,
size_t n) {
93 *y = ConstEigenArrayMap<T>(x, n, 1).sum();
// NOTE(review): the leading integers (97, 103) look like listing line numbers.
// Listing lines 98-102 (the parameter list) and the closing brace are absent,
// so only the opening of the signature and one statement are visible.
97 void SRLHelper::RunWithBroadcastFront(
// Assigns to y, viewed as an EigenArrayMap<T> with dimensions (n, 1), the
// row-wise sums of x viewed as a ConstEigenArrayMap<T> with dimensions
// (n, pre).
// NOTE(review): presumably the map arguments are (data, rows, cols) with
// Eigen's default column-major layout, i.e. y[r] accumulates x[c * n + r]
// over c in [0, pre) - confirm against the EigenArrayMap definition.
103 EigenArrayMap<T>(y, n, 1) = ConstEigenArrayMap<T>(x, n, pre).rowwise().sum();
// NOTE(review): the leading integers (106, 107, 113) look like listing line
// numbers. Listing lines 108-112 (the parameter list) and the closing brace
// are absent from this extract.
106 template <
typename T>
107 void SRLHelper::RunWithBroadcastBack(
// Assigns to y, viewed as an EigenArrayMap<T> with dimensions (1, n), the
// column-wise sums of x viewed as a ConstEigenArrayMap<T> with dimensions
// (post, n).
// NOTE(review): presumably the map arguments are (data, rows, cols) with
// Eigen's default column-major layout, i.e. y[c] accumulates x[c * post + r]
// over r in [0, post) - confirm against the EigenArrayMap definition.
113 EigenArrayMap<T>(y, 1, n) = ConstEigenArrayMap<T>(x, post, n).colwise().sum();
// NOTE(review): the leading integers (116, 117, 124, ...) look like listing
// line numbers. Listing lines 118-123 (the parameter list), 125, and the
// closing braces are absent from this extract.
116 template <
typename T>
117 void SRLHelper::RunWithBroadcast2(
// Outer loop: one pass per output element y[i], i in [0, n).
124 for (
int i = 0; i < n; ++i) {
// NOTE(review): listing line 125 is missing here; presumably it initialises
// y[i] before the accumulation below - confirm against the original file.
// Inner loops visit every (j, k) pair with j in [0, pre) and k in [0, post).
126 for (
int j = 0; j < pre; ++j) {
127 for (
int k = 0; k < post; ++k) {
// Accumulate the input element at flat index (j * n + i) * post + k into y[i].
128 y[i] += a[(j * n + i) * post + k];
// NOTE(review): fragment of a function template whose name and signature are
// not visible (listing line 136 is absent). The leading integers look like
// listing line numbers; lines 139, 141, 144, 147-148, 150, 154 and everything
// after 155 are also missing, so several branch conditions cannot be seen.
135 template <
typename T>
// A and B are bound to the results of Input(0) and Input(1).
137 const auto& A = Input(0);
138 const auto& B = Input(1);
// Fails with "In-place is not allowed." when the address of B equals C.
// NOTE(review): C is declared on a line absent from this extract; presumably
// it is the output tensor pointer - confirm.
140 CAFFE_ENFORCE(&B != C,
"In-place is not allowed.");
// Typed pointers: read-only data of A and mutable data of C.
142 const T* Adata = A.template data<T>();
143 auto* Cdata = C->template mutable_data<T>();
// Path 1: reduce all A.size() elements of A into Cdata via sum2one.
// The condition selecting this path is not visible in this listing.
145 auto count = A.size();
146 SRLHelper::sum2one<T>(Adata, Cdata, count);
// Path 2: obtain (pre, n, post) from calculate_broadcast_sizes(A, B, axis_)
// and dispatch to one of the three broadcast reducers.
149 std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
// First branch (condition not visible in this listing): RunWithBroadcastFront.
151 SRLHelper::RunWithBroadcastFront<T>(Adata, Cdata, pre, n, &context_);
152 }
// pre == 1: RunWithBroadcastBack, called with (post, n).
else if (pre == 1) {
153 SRLHelper::RunWithBroadcastBack<T>(Adata, Cdata, post, n, &context_);
// Remaining case (the else line itself is not visible): RunWithBroadcast2.
155 SRLHelper::RunWithBroadcast2<T>(Adata, Cdata, pre, n, post, &context_);
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...