1 #ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ 2 #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/operators/elementwise_op.h" 10 #include <unordered_set> 14 template <
class Context>
17 USE_OPERATOR_FUNCTIONS(Context);
// Ties the "broadcast_on_rows" operator argument to enable_broadcast_, with 0
// as the fallback value. NOTE(review): OP_SINGLE_ARG is defined elsewhere;
// presumably it initializes the member from the operator definition - confirm.
22 OP_SINGLE_ARG(
bool,
"broadcast_on_rows", enable_broadcast_, 0) {}
24 bool RunOnDevice()
override {
// Selects between two tensors using a boolean mask and writes the result to
// Output(0). T is the element type of the left/right inputs and of the output.
31 bool DoRunWithType() {
// Input(0) is the boolean selector; Input(1) and Input(2) are the candidates.
32 auto& select = Input(0);
33 auto& left = Input(1);
34 auto& right = Input(2);
35 auto* output = Output(0);
// Shape validation. With row broadcasting the selector must be 1-D, its length
// must equal right.dim(0), and left/right must have identical dims.
36 if (enable_broadcast_) {
37 CAFFE_ENFORCE_EQ(select.ndim(), 1);
38 CAFFE_ENFORCE_EQ(select.dim(0), right.dim(0));
39 CAFFE_ENFORCE_EQ(left.dims(), right.dims());
// NOTE(review): the next two checks require the selector to match both inputs
// exactly; presumably they belong to the non-broadcast branch - confirm.
41 CAFFE_ENFORCE_EQ(select.dims(), left.dims());
42 CAFFE_ENFORCE_EQ(select.dims(), right.dims());
// The output takes the shape of the left input.
44 output->ResizeLike(left);
// The selector is read as bool; inputs and output are accessed as T.
46 const bool* select_data = select.template data<bool>();
47 const T* left_data = left.template data<T>();
48 const T* right_data = right.template data<T>();
49 T* output_data = output->template mutable_data<T>();
// Broadcast path: one selector entry governs a whole block of block_size
// elements starting at i * block_size.
// NOTE(review): size_from_dim(1) is presumably the element count of all
// dimensions after the first (one "row") - confirm.
51 if (enable_broadcast_) {
52 size_t block_size = left.size_from_dim(1);
// NOTE(review): the index is an int compared against select.size(); if size()
// is a 64-bit count this mixes widths and could overflow for very large
// tensors - confirm and consider a wider index type.
53 for (
int i = 0; i < select.size(); i++) {
54 size_t offset = i * block_size;
// Both copies target output_data + offset. NOTE(review): presumably one copies
// from left and the other from right depending on select_data[i] - confirm.
56 context_.template CopyItems<Context, Context>(
60 output_data + offset);
62 context_.template CopyItems<Context, Context>(
66 output_data + offset);
// Element-wise path: each output element is left[i] when select[i] is true,
// otherwise right[i].
70 for (
int i = 0; i < select.size(); ++i) {
71 output_data[i] = select_data[i] ? left_data[i] : right_data[i];
// When true, DoRunWithType validates the selector as 1-D and copies whole
// blocks per selector entry instead of selecting element by element.
78 bool enable_broadcast_;
// One hash set per supported element type: int32, int64, bool and string.
82 std::unordered_set<int32_t> int32_values_;
83 std::unordered_set<int64_t> int64_values_;
84 std::unordered_set<bool> bool_values_;
85 std::unordered_set<std::string> string_values_;
// Starts out false. NOTE(review): presumably flipped once values have been
// stored via set() - confirm.
86 bool has_values_ =
false;
// Accessor returning a mutable reference to a set of T.
// NOTE(review): presumably resolves to the matching typed member above - confirm.
90 std::unordered_set<T>&
get();
// Inserts every element of args into the set returned by get<T>(). Duplicates
// collapse because the container is an unordered_set.
93 void set(
const std::vector<T>& args) {
95 auto& values = get<T>();
96 values.insert(args.begin(), args.end());
104 template <
class Context>
106 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Name of the repeated operator argument from which the membership values are
// read (passed to GetRepeatedArgument).
109 static constexpr
const char* VALUE_TAG =
"value";
// Reads the integer "dtype" argument (TensorProto_DataType_UNDEFINED when the
// argument is absent) and casts it to TensorProto_DataType.
117 static_cast<TensorProto_DataType
>(OperatorBase::GetSingleArgument<int>(
118 "dtype", TensorProto_DataType_UNDEFINED));
// For each supported dtype, load the repeated VALUE_TAG argument with the
// matching C++ element type and store it in values_.
120 case TensorProto_DataType_INT32:
121 values_.set(OperatorBase::GetRepeatedArgument<int32_t>(VALUE_TAG));
123 case TensorProto_DataType_INT64:
124 values_.set(OperatorBase::GetRepeatedArgument<int64_t>(VALUE_TAG));
126 case TensorProto_DataType_BOOL:
127 values_.set(OperatorBase::GetRepeatedArgument<bool>(VALUE_TAG));
129 case TensorProto_DataType_STRING:
130 values_.set(OperatorBase::GetRepeatedArgument<std::string>(VALUE_TAG));
// NOTE(review): for UNDEFINED the values are presumably loaded later, when
// DoRunWithType finds that values_ has no values - confirm.
132 case TensorProto_DataType_UNDEFINED:
// NOTE(review): presumably the default case, rejecting any other dtype - confirm.
137 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
142 bool RunOnDevice()
override {
// Membership test: for every element of Input(0), writes to Output(0) whether
// that element is present in the configured value set. T is the input element
// type; the output is always bool.
147 template <
typename T>
148 bool DoRunWithType() {
149 auto& input = Input(0);
150 auto* output = Output(0);
// The output has the same shape as the input.
151 output->ResizeLike(input);
// Lazy load: if no values have been stored yet, read the repeated VALUE_TAG
// argument as T now.
153 if (!values_.has_values()) {
154 values_.set(OperatorBase::GetRepeatedArgument<T>(VALUE_TAG));
156 const auto& values = values_.get<T>();
158 const T* input_data = input.template data<T>();
159 bool* output_data = output->template mutable_data<bool>();
// NOTE(review): the index is an int compared against input.size(); confirm the
// widths match for very large tensors.
160 for (
int i = 0; i < input.size(); ++i) {
// True exactly when the element is found in the set.
161 output_data[i] = values.find(input_data[i]) != values.end();
172 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...