Caffe2 - C++ API
A deep learning, cross platform ML framework
map_ops.h
1 #ifndef CAFFE2_OPERATORS_MAP_OPS_H_
2 #define CAFFE2_OPERATORS_MAP_OPS_H_
3 
4 #include <algorithm>
5 #include <iterator>
6 #include <string>
7 #include <typeinfo>
8 #include <unordered_map>
9 #include <utility>
10 #include <vector>
11 
12 #include "caffe2/core/blob_serialization.h"
13 #include "caffe2/core/context.h"
14 #include "caffe2/core/operator.h"
15 
16 namespace caffe2 {
17 
/// Maps a C++ scalar type to the human-readable name used when building the
/// serialized type string of a map blob (see MapTypeTraits::MapTypeName).
/// Types without an explicit specialization report "unknown".
template <typename T>
struct TypeNameTraits {
  static constexpr const char* name = "unknown";
};

template <>
struct TypeNameTraits<int64_t> {
  static constexpr const char* name = "int64_t";
};

template <>
struct TypeNameTraits<int32_t> {
  static constexpr const char* name = "int32_t";
};

/// Type information for a KEY_T -> VALUE_T map blob.
template <typename KEY_T, typename VALUE_T>
struct MapTypeTraits {
  using MapType = std::unordered_map<KEY_T, VALUE_T>;

  /// Returns the type string that MapSerializer writes into BlobProto::type,
  /// e.g. "(std::unordered_map<int64_t, int32_t>)".
  // NOTE(review): the exact parenthesized spelling presumably has to match
  // the key under which the matching MapDeserializer is registered — confirm
  // against the registration site before changing the format.
  static std::string MapTypeName() {
    return std::string("(std::unordered_map<") + TypeNameTraits<KEY_T>::name +
        ", " + TypeNameTraits<VALUE_T>::name + ">)";
  }
};

// The four key/value combinations supported by the map operators below.
using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
46 
47 template <class Context>
48 class CreateMapOp final : public Operator<Context> {
49  public:
50  USE_OPERATOR_CONTEXT_FUNCTIONS;
51  CreateMapOp(const OperatorDef& operator_def, Workspace* ws)
52  : Operator<Context>(operator_def, ws) {}
53  ~CreateMapOp() {}
54 
55  bool RunOnDevice() override {
56  TensorProto::DataType key_dtype =
57  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
58  "key_dtype", TensorProto_DataType_INT32));
59 
61  this, DataTypeToTypeMeta(key_dtype));
62  }
63 
64  template <typename KEY_T>
65  bool DoRunWithType() {
66  TensorProto::DataType value_dtype =
67  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
68  "value_dtype", TensorProto_DataType_INT32));
69 
70  return DispatchHelper<
72  KEY_T>::call(this, DataTypeToTypeMeta(value_dtype));
73  }
74 
75  template <typename KEY_T, typename VALUE_T>
76  bool DoRunWithType2() {
77  // clear to make sure the map is empty
78  OperatorBase::Output<typename MapTypeTraits<KEY_T, VALUE_T>::MapType>(MAP)
79  ->clear();
80  return true;
81  }
82 
83  template <typename KEY_T>
84  bool DoRunWithOtherType2() {
85  TensorProto::DataType value_dtype =
86  static_cast<TensorProto::DataType>(OperatorBase::GetSingleArgument<int>(
87  "value_dtype", TensorProto_DataType_INT32));
88 
89  CAFFE_THROW(
90  "CreateMap is not implemented on value tensor of type ",
91  DataTypeToTypeMeta(value_dtype).name(),
92  "Consider adding it a type in the list DispatchHelper");
93  }
94 
95  OUTPUT_TAGS(MAP);
96 };
97 
98 template <class Context>
99 class KeyValueToMapOp final : public Operator<Context> {
100  public:
101  USE_OPERATOR_CONTEXT_FUNCTIONS;
102  KeyValueToMapOp(const OperatorDef& operator_def, Workspace* ws)
103  : Operator<Context>(operator_def, ws) {}
104  ~KeyValueToMapOp() {}
105 
106  bool RunOnDevice() override {
108  this, Input(KEYS));
109  }
110 
111  template <typename KEY_T>
112  bool DoRunWithType() {
113  return DispatchHelper<
115  KEY_T>::call(this, Input(VALUES));
116  }
117 
118  template <typename KEY_T, typename VALUE_T>
119  bool DoRunWithType2() {
120  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
121  const auto& key_input = Input(KEYS);
122  const auto& value_input = Input(VALUES);
123 
124  CAFFE_ENFORCE_EQ(key_input.size(), value_input.size());
125 
126  auto* key_data = key_input.template data<KEY_T>();
127  auto* value_data = value_input.template data<VALUE_T>();
128 
129  auto* map_data = OperatorBase::Output<MapType>(MAP);
130 
131  for (int i = 0; i < key_input.size(); ++i) {
132  map_data->emplace(key_data[i], value_data[i]);
133  }
134 
135  return true;
136  }
137 
138  template <typename KEY_T>
139  bool DoRunWithOtherType2() {
140  CAFFE_THROW(
141  "KeyValueToMap is not implemented on value tensor of type ",
142  Input(VALUES).meta().name(),
143  "Consider adding it a type in the list DispatchHelper");
144  }
145 
146  INPUT_TAGS(KEYS, VALUES);
147  OUTPUT_TAGS(MAP);
148 };
149 
150 template <class Context>
151 class MapToKeyValueOp final : public Operator<Context> {
152  public:
153  USE_OPERATOR_CONTEXT_FUNCTIONS;
154  MapToKeyValueOp(const OperatorDef& operator_def, Workspace* ws)
155  : Operator<Context>(operator_def, ws) {}
156  ~MapToKeyValueOp() {}
157 
158  bool RunOnDevice() override {
160  MapType64To64,
161  MapType64To32,
162  MapType32To32,
163  MapType32To64>>::call(this, OperatorBase::InputBlob(MAP));
164  }
165 
166  template <typename MAP_T>
167  bool DoRunWithType() {
168  using key_type = typename MAP_T::key_type;
169  using mapped_type = typename MAP_T::mapped_type;
170  auto& map_data = OperatorBase::Input<MAP_T>(MAP);
171  auto* key_output = Output(KEYS);
172  auto* value_output = Output(VALUES);
173  key_output->Resize(map_data.size());
174  value_output->Resize(map_data.size());
175  auto* key_data = key_output->template mutable_data<key_type>();
176  auto* value_data = value_output->template mutable_data<mapped_type>();
177 
178  for (const auto& it : map_data) {
179  *key_data = it.first;
180  *value_data = it.second;
181  key_data++;
182  value_data++;
183  }
184 
185  return true;
186  }
187 
188  INPUT_TAGS(MAP);
189  OUTPUT_TAGS(KEYS, VALUES);
190 };
191 
192 template <typename KEY_T, typename VALUE_T>
194  public:
195  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
196 
197  void Serialize(
198  const Blob& blob,
199  const string& name,
200  BlobSerializerBase::SerializationAcceptor acceptor) override {
201  CAFFE_ENFORCE(blob.IsType<MapType>());
202  const MapType& map_data = blob.template Get<MapType>();
203  TIndex sz = map_data.size();
204  Tensor<CPUContext> key_tensor;
205  key_tensor.Resize(sz);
206  Tensor<CPUContext> value_tensor;
207  value_tensor.Resize(sz);
208  auto* key_data = key_tensor.mutable_data<KEY_T>();
209  auto* value_data = value_tensor.mutable_data<VALUE_T>();
210  for (const auto& it : map_data) {
211  *key_data = it.first;
212  *value_data = it.second;
213  key_data++;
214  value_data++;
215  }
216 
217  TensorProtos tensor_protos;
219  ser.Serialize(
220  key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.size());
221  ser.Serialize(
222  value_tensor, name, tensor_protos.add_protos(), 0, value_tensor.size());
223 
224  BlobProto blob_proto;
225  blob_proto.set_name(name);
226  blob_proto.set_type(MapTypeTraits<KEY_T, VALUE_T>::MapTypeName());
227  blob_proto.set_content(tensor_protos.SerializeAsString());
228  acceptor(name, blob_proto.SerializeAsString());
229  }
230 };
231 
232 template <typename KEY_T, typename VALUE_T>
234  public:
235  using MapType = typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
236 
237  void Deserialize(const BlobProto& proto, Blob* blob) override {
238  TensorProtos tensor_protos;
239  CAFFE_ENFORCE(
240  tensor_protos.ParseFromString(proto.content()),
241  "Fail to parse TensorProtos");
243  Tensor<CPUContext> key_tensor, value_tensor;
244  deser.Deserialize(tensor_protos.protos(0), &key_tensor);
245  deser.Deserialize(tensor_protos.protos(1), &value_tensor);
246  auto* key_data = key_tensor.data<KEY_T>();
247  auto* value_data = value_tensor.data<VALUE_T>();
248 
249  auto* map_ptr = blob->template GetMutable<MapType>();
250  for (int i = 0; i < key_tensor.size(); ++i) {
251  map_ptr->emplace(key_data[i], value_data[i]);
252  }
253  }
254 };
255 
256 } // namespace caffe2
257 
258 #endif // CAFFE2_OPERATORS_MAP_OPS_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
const T * data() const
Returns a typed pointer of the underlying storage.
Definition: tensor.h:484
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
T * mutable_data()
Returns a typed pointer of the underlying storage.
Definition: tensor.h:578
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:58
BlobSerializerBase is an abstract class that serializes a blob to a string.