Caffe2 - C++ API
A deep-learning, cross-platform ML framework
elementwise_logical_ops.h
1 #ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
2 #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/elementwise_op.h"
9 
10 #include <unordered_set>
11 
12 namespace caffe2 {
13 
14 template <class Context>
15 class WhereOp final : public Operator<Context> {
16  public:
17  USE_OPERATOR_FUNCTIONS(Context);
18  USE_DISPATCH_HELPER;
19 
20  WhereOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws),
22  OP_SINGLE_ARG(bool, "broadcast_on_rows", enable_broadcast_, 0) {}
23 
// Runs DoRunWithType<T> through DispatchHelper; T is presumably chosen from the
// element type of Input(1) ("left"), which DoRunWithType reads as T.
// NOTE(review): original line 26 (the DispatchHelper template argument list,
// i.e. the supported element types) is missing from this extraction; restore
// it from the upstream source before compiling.
24  bool RunOnDevice() override {
25  return DispatchHelper<
27  call(this, Input(1));
28  }
29 
30  template <typename T>
31  bool DoRunWithType() {
32  auto& select = Input(0);
33  auto& left = Input(1);
34  auto& right = Input(2);
35  auto* output = Output(0);
36  if (enable_broadcast_) {
37  CAFFE_ENFORCE_EQ(select.ndim(), 1);
38  CAFFE_ENFORCE_EQ(select.dim(0), right.dim(0));
39  CAFFE_ENFORCE_EQ(left.dims(), right.dims());
40  } else {
41  CAFFE_ENFORCE_EQ(select.dims(), left.dims());
42  CAFFE_ENFORCE_EQ(select.dims(), right.dims());
43  }
44  output->ResizeLike(left);
45 
46  const bool* select_data = select.template data<bool>();
47  const T* left_data = left.template data<T>();
48  const T* right_data = right.template data<T>();
49  T* output_data = output->template mutable_data<T>();
50 
51  if (enable_broadcast_) {
52  size_t block_size = left.size_from_dim(1);
53  for (int i = 0; i < select.size(); i++) {
54  size_t offset = i * block_size;
55  if (select_data[i]) {
56  context_.template CopyItems<Context, Context>(
57  output->meta(),
58  block_size,
59  left_data + offset,
60  output_data + offset);
61  } else {
62  context_.template CopyItems<Context, Context>(
63  output->meta(),
64  block_size,
65  right_data + offset,
66  output_data + offset);
67  }
68  }
69  } else {
70  for (int i = 0; i < select.size(); ++i) {
71  output_data[i] = select_data[i] ? left_data[i] : right_data[i];
72  }
73  }
74  return true;
75  }
76 
77  private:
78  bool enable_broadcast_;
79 };
80 
// NOTE(review): the class header for this block (original line 81) is missing
// from this extraction. From its use as `IsMemberOfValueHolder values_;` in
// IsMemberOfOp, this is presumably the body of `class IsMemberOfValueHolder {`.
// Confirm against the upstream source.
//
// One lookup set per supported element type. set<T>() inserts into whichever
// set get<T>() returns (presumably the one whose element type is T).
82  std::unordered_set<int32_t> int32_values_;
83  std::unordered_set<int64_t> int64_values_;
84  std::unordered_set<bool> bool_values_;
85  std::unordered_set<std::string> string_values_;
// Becomes true on the first call to set<T>(), even if it inserts nothing.
86  bool has_values_ = false;
87 
88  public:
// Returns the lookup set for element type T. It is only declared here; the
// per-type definitions are not visible in this file (presumably explicit
// specializations in the matching .cc file; TODO confirm).
89  template <typename T>
90  std::unordered_set<T>& get();
91 
92  template <typename T>
93  void set(const std::vector<T>& args) {
94  has_values_ = true;
95  auto& values = get<T>();
96  values.insert(args.begin(), args.end());
97  }
98 
99  bool has_values() {
100  return has_values_;
101  }
102 };
103 
104 template <class Context>
105 class IsMemberOfOp final : public Operator<Context> {
106  USE_OPERATOR_CONTEXT_FUNCTIONS;
107  USE_DISPATCH_HELPER;
108 
// Name of the repeated argument that holds the values to test membership
// against. It is read in the constructor when 'dtype' is given, otherwise
// lazily in DoRunWithType.
109  static constexpr const char* VALUE_TAG = "value";
110 
111  public:
// NOTE(review): original line 112 is missing from this extraction, presumably
// a member declaration. Confirm against the upstream source.
113 
114  IsMemberOfOp(const OperatorDef& op, Workspace* ws)
115  : Operator<Context>(op, ws) {
116  auto dtype =
117  static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
118  "dtype", TensorProto_DataType_UNDEFINED));
119  switch (dtype) {
120  case TensorProto_DataType_INT32:
121  values_.set(OperatorBase::GetRepeatedArgument<int32_t>(VALUE_TAG));
122  break;
123  case TensorProto_DataType_INT64:
124  values_.set(OperatorBase::GetRepeatedArgument<int64_t>(VALUE_TAG));
125  break;
126  case TensorProto_DataType_BOOL:
127  values_.set(OperatorBase::GetRepeatedArgument<bool>(VALUE_TAG));
128  break;
129  case TensorProto_DataType_STRING:
130  values_.set(OperatorBase::GetRepeatedArgument<std::string>(VALUE_TAG));
131  break;
132  case TensorProto_DataType_UNDEFINED:
133  // If dtype is not provided, values_ will be filled the first time that
134  // DoRunWithType is called.
135  break;
136  default:
137  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
138  }
139  }
140  virtual ~IsMemberOfOp() noexcept {}
141 
// Runs DoRunWithType<T> through DispatchHelper; T is presumably chosen from the
// element type of Input(0), which DoRunWithType reads as T.
// NOTE(review): original line 144 (the DispatchHelper type list and the
// `::call(...)` invocation) is missing from this extraction; restore it from
// the upstream source before compiling.
142  bool RunOnDevice() override {
143  return DispatchHelper<
145  }
146 
147  template <typename T>
148  bool DoRunWithType() {
149  auto& input = Input(0);
150  auto* output = Output(0);
151  output->ResizeLike(input);
152 
153  if (!values_.has_values()) {
154  values_.set(OperatorBase::GetRepeatedArgument<T>(VALUE_TAG));
155  }
156  const auto& values = values_.get<T>();
157 
158  const T* input_data = input.template data<T>();
159  bool* output_data = output->template mutable_data<bool>();
160  for (int i = 0; i < input.size(); ++i) {
161  output_data[i] = values.find(input_data[i]) != values.end();
162  }
163  return true;
164  }
165 
166  protected:
167  IsMemberOfValueHolder values_;
168 };
169 
170 } // namespace caffe2
171 
172 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...