1 #ifndef CAFFE2_OPERATORS_MAP_OPS_H_ 2 #define CAFFE2_OPERATORS_MAP_OPS_H_ 8 #include <unordered_map> 12 #include "caffe2/core/blob_serialization.h" 13 #include "caffe2/core/context.h" 14 #include "caffe2/core/operator.h" 20 static constexpr
const char* name =
"unknown";
25 static constexpr
const char* name =
"int64_t";
30 static constexpr
const char* name =
"int32_t";
// Traits keyed on a (KEY_T, VALUE_T) pair. The visible members are the
// MapType alias and a MapTypeName() helper.
// NOTE(review): the struct header and the body of MapTypeName() are missing
// from this listing, so its return value is not documented here.
33 template <
typename KEY_T,
typename VALUE_T>
// The container backing every map blob is a std::unordered_map.
35 using MapType = std::unordered_map<KEY_T, VALUE_T>;
36 static string MapTypeName() {
// Convenience aliases for the four integer key/value combinations, each
// resolved through MapTypeTraits<KEY, VALUE>::MapType. The alias name reads
// <key width>To<value width>.
42 using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
43 using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
44 using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
45 using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
47 template <
class Context>
50 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator entry point: reads the "key_dtype" argument and passes its
// TypeMeta, together with `this`, to a dispatcher.
// NOTE(review): the head of the dispatcher expression is missing from this
// listing; only its argument list is visible.
55 bool RunOnDevice()
override {
// The argument is fetched as an int (TensorProto_DataType_INT32 when not
// supplied) and cast to the TensorProto::DataType enum.
56 TensorProto::DataType key_dtype =
57 static_cast<TensorProto::DataType
>(OperatorBase::GetSingleArgument<int>(
58 "key_dtype", TensorProto_DataType_INT32));
61 this, DataTypeToTypeMeta(key_dtype));
// Second dispatch stage, instantiated per key type KEY_T: reads the
// "value_dtype" argument and dispatches again on its TypeMeta.
// NOTE(review): the head of the `...KEY_T>::call(` expression is missing from
// this listing.
64 template <
typename KEY_T>
65 bool DoRunWithType() {
// Fetched as an int (TensorProto_DataType_INT32 when not supplied) and cast
// to the TensorProto::DataType enum.
66 TensorProto::DataType value_dtype =
67 static_cast<TensorProto::DataType
>(OperatorBase::GetSingleArgument<int>(
68 "value_dtype", TensorProto_DataType_INT32));
72 KEY_T>::call(
this, DataTypeToTypeMeta(value_dtype));
// Final dispatch stage for a resolved (KEY_T, VALUE_T) pair.
75 template <
typename KEY_T,
typename VALUE_T>
76 bool DoRunWithType2() {
// Requests the MAP output typed as MapTypeTraits<KEY_T, VALUE_T>::MapType.
// NOTE(review): the rest of this statement and the return are missing from
// this listing; presumably this only materialises an empty map — confirm.
78 OperatorBase::Output<typename MapTypeTraits<KEY_T, VALUE_T>::MapType>(MAP)
// Handler that builds the "not implemented" diagnostic for a value dtype.
// NOTE(review): presumably the fallback for value dtypes outside the dispatch
// list, and the message is presumably raised via a throw/enforce macro whose
// opening line is missing from this listing — confirm.
83 template <
typename KEY_T>
84 bool DoRunWithOtherType2() {
// Re-reads "value_dtype" (INT32 when not supplied) only to name the type in
// the message below.
85 TensorProto::DataType value_dtype =
86 static_cast<TensorProto::DataType
>(OperatorBase::GetSingleArgument<int>(
87 "value_dtype", TensorProto_DataType_INT32));
// NOTE(review): the three message parts have no separator between the type
// name and "Consider", so if they are concatenated as-is the text runs
// together (e.g. "...of type <name>Consider adding..."). "adding it a type in"
// also looks like a typo for "adding it as a type to". Left unchanged here
// because these are runtime strings.
90 "CreateMap is not implemented on value tensor of type ",
91 DataTypeToTypeMeta(value_dtype).name(),
92 "Consider adding it a type in the list DispatchHelper");
98 template <
class Context>
101 USE_OPERATOR_CONTEXT_FUNCTIONS;
106 bool RunOnDevice()
override {
// Second dispatch stage, instantiated per key type KEY_T: forwards the
// VALUES input to a `...KEY_T>::call` dispatcher.
// NOTE(review): the head of the dispatcher expression is missing from this
// listing; presumably it selects the value type from the tensor's meta.
111 template <
typename KEY_T>
112 bool DoRunWithType() {
115 KEY_T>::call(
this, Input(VALUES));
// Builds the MAP output from the KEYS and VALUES inputs for a resolved
// (KEY_T, VALUE_T) pair: element i of KEYS is mapped to element i of VALUES.
118 template <
typename KEY_T,
typename VALUE_T>
119 bool DoRunWithType2() {
120 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
121 const auto& key_input = Input(KEYS);
122 const auto& value_input = Input(VALUES);
// Both inputs must hold the same number of elements; only total size is
// compared here.
124 CAFFE_ENFORCE_EQ(key_input.size(), value_input.size());
126 auto* key_data = key_input.template data<KEY_T>();
127 auto* value_data = value_input.template data<VALUE_T>();
129 auto* map_data = OperatorBase::Output<MapType>(MAP);
// NOTE(review): `int i` is compared against size(), which the accompanying
// reference text lists as returning TIndex; if TIndex is wider than int this
// would misbehave on very large inputs — confirm.
131 for (
int i = 0; i < key_input.size(); ++i) {
// emplace() on an unordered_map does not overwrite an existing key, so the
// first occurrence of a duplicated key is the one that is kept.
132 map_data->emplace(key_data[i], value_data[i]);
// Handler that builds the "not implemented" diagnostic, naming the type via
// the VALUES input's meta().
// NOTE(review): presumably the fallback for value types outside the dispatch
// list, and the message is presumably raised via a throw/enforce macro whose
// opening line is missing from this listing — confirm.
138 template <
typename KEY_T>
139 bool DoRunWithOtherType2() {
// NOTE(review): no separator between the type name and "Consider", so if the
// parts are concatenated as-is the message runs together. "adding it a type
// in" also looks like a typo. Left unchanged because these are runtime strings.
141 "KeyValueToMap is not implemented on value tensor of type ",
142 Input(VALUES).meta().name(),
143 "Consider adding it a type in the list DispatchHelper");
146 INPUT_TAGS(KEYS, VALUES);
150 template <
class Context>
153 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator entry point: dispatches on the MAP input blob over a list of map
// types whose last entry is MapType32To64.
// NOTE(review): the start of the dispatcher expression (and the other listed
// map types) is missing from this listing.
158 bool RunOnDevice()
override {
163 MapType32To64>>::call(
this, OperatorBase::InputBlob(MAP));
// Flattens the MAP input (of concrete type MAP_T) into two outputs, KEYS and
// VALUES, each resized to the number of entries in the map.
166 template <
typename MAP_T>
167 bool DoRunWithType() {
// Element types of the output tensors come from the map's own typedefs.
168 using key_type =
typename MAP_T::key_type;
169 using mapped_type =
typename MAP_T::mapped_type;
170 auto& map_data = OperatorBase::Input<MAP_T>(MAP);
171 auto* key_output = Output(KEYS);
172 auto* value_output = Output(VALUES);
173 key_output->Resize(map_data.size());
174 value_output->Resize(map_data.size());
175 auto* key_data = key_output->template mutable_data<key_type>();
176 auto* value_data = value_output->template mutable_data<mapped_type>();
// Entries are written in the map's iteration order, so keys and values stay
// aligned with each other.
// NOTE(review): the pointer increments and loop close are missing from this
// listing; presumably both cursors advance once per entry — confirm.
178 for (
const auto& it : map_data) {
179 *key_data = it.first;
180 *value_data = it.second;
189 OUTPUT_TAGS(KEYS, VALUES);
192 template <
typename KEY_T,
typename VALUE_T>
195 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
200 BlobSerializerBase::SerializationAcceptor acceptor)
override {
// Serializes a map blob as two tensors, keys first and values second, packed
// into a TensorProtos message. That message becomes the content of a
// BlobProto, which is handed to `acceptor` in serialized form.
// NOTE(review): the start of the signature, the tensor declarations and the
// serializer call heads are missing from this listing.
// The blob must hold exactly this MapType.
201 CAFFE_ENFORCE(blob.
IsType<MapType>());
202 const MapType& map_data = blob.template Get<MapType>();
203 TIndex sz = map_data.size();
209 auto* value_data = value_tensor.
mutable_data<VALUE_T>();
// Copy entries in the map's iteration order.
// NOTE(review): pointer increments are not visible in this listing.
210 for (
const auto& it : map_data) {
211 *key_data = it.first;
212 *value_data = it.second;
217 TensorProtos tensor_protos;
// Each tensor is serialized over its full range [0, size()) into its own
// newly added proto; both reuse the blob's `name`.
220 key_tensor, name, tensor_protos.add_protos(), 0, key_tensor.
size());
222 value_tensor, name, tensor_protos.add_protos(), 0, value_tensor.
size());
224 BlobProto blob_proto;
225 blob_proto.set_name(name);
227 blob_proto.set_content(tensor_protos.SerializeAsString());
228 acceptor(name, blob_proto.SerializeAsString());
232 template <
typename KEY_T,
typename VALUE_T>
235 using MapType =
typename MapTypeTraits<KEY_T, VALUE_T>::MapType;
// Rebuilds a map blob from a BlobProto whose content is a serialized
// TensorProtos: protos(0) is read as the key tensor and protos(1) as the
// value tensor, and each (key, value) pair is emplaced into the blob's map.
// NOTE(review): the declarations of `deser`, `key_tensor` and `value_tensor`
// and the enforce macro head are missing from this listing.
237 void Deserialize(
const BlobProto& proto,
Blob* blob)
override {
238 TensorProtos tensor_protos;
240 tensor_protos.ParseFromString(proto.content()),
241 "Fail to parse TensorProtos");
// NOTE(review): no check that exactly two protos are present, or that the
// key and value tensors have equal size, is visible in this listing — confirm
// one exists, since the loop below is bounded by the key tensor only.
244 deser.Deserialize(tensor_protos.protos(0), &key_tensor);
245 deser.Deserialize(tensor_protos.protos(1), &value_tensor);
246 auto* key_data = key_tensor.
data<KEY_T>();
247 auto* value_data = value_tensor.
data<VALUE_T>();
249 auto* map_ptr = blob->template GetMutable<MapType>();
// NOTE(review): `int i` is compared against size(), listed as TIndex in the
// accompanying reference text — possible narrowing on very large tensors.
250 for (
int i = 0; i < key_tensor.
size(); ++i) {
251 map_ptr->emplace(key_data[i], value_data[i]);
258 #endif // CAFFE2_OPERATORS_MAP_OPS_H_ Blob is a general container that hosts a typed pointer.
const T * data() const
Returns a typed pointer of the underlying storage.
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
T * mutable_data()
Returns a typed pointer of the underlying storage.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
bool IsType() const
Checks if the content stored in the blob is of type T.
BlobSerializerBase is an abstract class that serializes a blob to a string.