#include <atomic>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <unordered_map>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

using IndexKeyTypes = TensorTypes<int32_t, int64_t, std::string>;
using TIndexValue = int64_t;
20 : maxElements_{maxElements}
24 void Freeze() { frozen_ =
true; }
26 bool isFrozen()
const {
30 int64_t maxElements()
const {
36 const TypeMeta& Type()
const {
return meta_; }
39 std::lock_guard<std::mutex> guard(dictMutex_);
46 TIndexValue nextId_{1};
47 std::atomic<bool> frozen_{
false};
48 std::mutex dictMutex_;
53 explicit Index(TIndexValue maxElements)
54 :
IndexBase(maxElements, TypeMeta::Make<T>()) {}
56 void Get(
const T* keys, TIndexValue* values,
size_t numKeys) {
58 FrozenGet(keys, values, numKeys);
61 std::lock_guard<std::mutex> lock(dictMutex_);
62 for (
int i = 0; i < numKeys; ++i) {
63 auto it = dict_.find(keys[i]);
64 if (it != dict_.end()) {
65 values[i] = it->second;
66 }
else if (nextId_ < maxElements_) {
67 auto newValue = nextId_++;
68 dict_.insert({keys[i], newValue});
71 CAFFE_THROW(
"Dict max size reached");
76 bool Load(
const T* keys,
size_t numKeys) {
78 numKeys <= maxElements_,
79 "Cannot load index: Tensor is larger than max_elements.");
81 for (
int i = 0; i < numKeys; ++i) {
83 dict.insert({keys[i], i + 1}).second,
84 "Repeated elements found: cannot load into dictionary.");
88 std::lock_guard<std::mutex> lock(dictMutex_);
91 nextId_ = numKeys + 1;
96 template<
typename Ctx>
98 std::lock_guard<std::mutex> lock(dictMutex_);
100 auto outData = out->template mutable_data<T>();
101 for (
const auto& entry : dict_) {
102 outData[entry.second - 1] = entry.first;
108 void FrozenGet(
const T* keys, TIndexValue* values,
size_t numKeys) {
109 for (
int i = 0; i < numKeys; ++i) {
110 auto it = dict_.find(keys[i]);
111 values[i] = it != dict_.end() ? it->second : 0;
115 std::unordered_map<T, TIndexValue> dict_;
124 maxElements_(OperatorBase::GetSingleArgument<int>(
126 std::numeric_limits<int>::max())) {}
128 bool RunOnDevice()
override {
129 *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) =
130 std::unique_ptr<IndexBase>(
new Index<T>(maxElements_));
135 TIndexValue maxElements_;
143 bool RunOnDevice()
override {
146 template <
typename T>
147 bool DoRunWithType() {
148 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
149 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
150 CAFFE_ENFORCE(dict,
"Wrong dictionary type given input keys.");
151 const auto& keys = Input(1);
152 auto* values = Output(0);
153 values->ResizeLike(keys);
154 dict->Get(keys.data<T>(), values->mutable_data<TIndexValue>(), keys.size());
164 OperatorBase::GetSingleArgument<int>(
"skip_first_entry", 0)) {}
166 bool RunOnDevice()
override {
169 template <
typename T>
170 bool DoRunWithType() {
171 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
172 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
173 CAFFE_ENFORCE(dict,
"Wrong dictionary type given input keys.");
174 const auto& keys = Input(1);
175 const auto* keys_data = keys.data<T>();
176 auto keys_size = keys.size();
177 if (skipFirstEntry_) {
178 CAFFE_ENFORCE(keys.size() > 0);
182 return dict->Load(keys_data, keys_size);
186 bool skipFirstEntry_;
194 bool RunOnDevice()
override {
195 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
199 template <
typename T>
200 bool DoRunWithType() {
201 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
202 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
204 return dict->Store(Output(0));
213 bool RunOnDevice()
override {
214 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
225 bool RunOnDevice()
override {
226 auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
227 auto* out = Output(0);
228 out->Resize(std::vector<TIndex>{});
229 *out->mutable_data<TIndexValue>() = base->Size();
244 OPERATOR_SCHEMA(IntIndexCreate)
248 Creates a dictionary that maps int32 keys to consecutive integers 249 from 1 to max_elements. Zero is reserved for unknown keys. 251 .Arg("max_elements",
"Max number of elements, including the zero entry.")
252 .Output(0,
"handler",
"Pointer to an Index instance.");
254 OPERATOR_SCHEMA(LongIndexCreate)
258 Creates a dictionary that maps int64 keys to consecutive integers 259 from 1 to max_elements. Zero is reserved for unknown keys. 261 .Arg("max_elements",
"Max number of elements, including the zero entry.")
262 .Output(0,
"handler",
"Pointer to an Index instance.");
264 OPERATOR_SCHEMA(StringIndexCreate)
268 Creates a dictionary that maps string keys to consecutive integers 269 from 1 to max_elements. Zero is reserved for unknown keys. 271 .Arg("max_elements",
"Max number of elements, including the zero entry.")
272 .Output(0,
"handle",
"Pointer to an Index instance.");
274 OPERATOR_SCHEMA(IndexGet)
278 Given an index handle and a tensor of keys, return an Int tensor of same shape 279 containing the indices for each of the keys. If the index is frozen, unknown 280 entries are given index 0. Otherwise, new entries are added into the index. 281 If an insert is necessary but max_elements has been reached, fail. 283 .Input(0, "handle",
"Pointer to an Index instance.")
284 .Input(1,
"keys",
"Tensor of keys to be looked up.")
285 .Output(0,
"indices",
"Indices for each of the keys.");
287 OPERATOR_SCHEMA(IndexFreeze)
291 Freezes the given index, disallowing creation of new index entries. 292 Should not be called concurrently with IndexGet. 294 .Input(0, "handle",
"Pointer to an Index instance.")
295 .Output(0,
"handle",
"The input handle.")
296 .EnforceInplace({{0, 0}});
298 OPERATOR_SCHEMA(IndexLoad)
302 Loads the index from the given 1-D tensor. Elements in the tensor will be given 303 consecutive indexes starting at 1. Fails if tensor contains repeated elements. 305 .Input(0, "handle",
"Pointer to an Index instance.")
306 .Input(1,
"items",
"1-D tensor with elements starting with index 1.")
307 .Output(0,
"handle",
"The input handle.")
308 .EnforceInplace({{0, 0}})
311 "If set, skips the first entry of the tensor. This allows " 312 "to load tensors that are aligned with an embedding, where the first " 313 "entry corresponds to the default 0 index entry.");
315 OPERATOR_SCHEMA(IndexStore)
319 Stores the keys of this index in a 1-D tensor. Since element 0 is reserved 320 for unknowns, the first element of the output tensor will be element of index 1. 322 .Input(0, "handle",
"Pointer to an Index instance.")
323 .Output(0,
"items",
"1-D tensor with elements starting with index 1.");
325 OPERATOR_SCHEMA(IndexSize)
329 Returns the number of entries currently present in the index. 331 .Input(0, "handle",
"Pointer to an Index instance.")
332 .Output(0,
"items",
"Scalar int64 tensor with number of entries.");
335 NO_GRADIENT(IntIndexCreate);
336 NO_GRADIENT(LongIndexCreate);
337 NO_GRADIENT(StringIndexCreate);
338 SHOULD_NOT_DO_GRADIENT(IndexFreeze);
339 SHOULD_NOT_DO_GRADIENT(IndexLoad);
340 SHOULD_NOT_DO_GRADIENT(IndexStore);
341 SHOULD_NOT_DO_GRADIENT(IndexSize);
351 SerializationAcceptor acceptor)
override {
352 auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
354 auto* tensor_out = tensor_blob.template GetMutable<Tensor<CPUContext>>();
356 if (base->Type().Match<std::string>()) {
357 doStore<std::string>(base, tensor_out);
358 }
else if (base->Type().Match<int32_t>()) {
359 doStore<int32_t>(base, tensor_out);
360 }
else if (base->Type().Match<int64_t>()) {
361 doStore<int64_t>(base, tensor_out);
363 CAFFE_THROW(
"Index of this type can't be serialized.");
367 tensor_out->size() <= std::numeric_limits<int32_t>::max(),
368 "Index too large to be serialized.");
369 BlobProto blob_proto;
372 *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->size());
373 blob_proto.set_name(name);
374 blob_proto.set_type(
"std::unique_ptr<caffe2::IndexBase>");
376 std::ostringstream os;
377 os << base->maxElements() <<
" " << base->isFrozen();
378 blob_proto.set_content(os.str());
380 acceptor(name, blob_proto.SerializeAsString());
384 template <
typename T>
386 const std::unique_ptr<IndexBase>& base,
388 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
389 CAFFE_ENFORCE(dict,
"Wrong dictionary type.");
390 dict->Store(tensor_out);
396 void Deserialize(
const BlobProto& proto,
Blob* blob)
override {
399 deser.Deserialize(proto, &tensor_blob);
401 std::istringstream is(proto.content());
402 int64_t maxElements{std::numeric_limits<int64_t>::max()};
403 bool isFrozen{
false};
404 is >> maxElements >> isFrozen;
406 auto& tensor_in = tensor_blob.template Get<Tensor<CPUContext>>();
407 auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>();
409 if (tensor_in.IsType<std::string>()) {
410 doLoad<std::string>(base, maxElements, tensor_in);
411 }
else if (tensor_in.IsType<int32_t>()) {
412 doLoad<int32_t>(base, maxElements, tensor_in);
413 }
else if (tensor_in.IsType<int64_t>()) {
414 doLoad<int64_t>(base, maxElements, tensor_in);
416 CAFFE_THROW(
"Index of this type cannot be deserialized.");
425 template <
typename T>
427 std::unique_ptr<IndexBase>* base,
430 base->reset(
new Index<T>(maxElements));
431 auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get());
432 dict->Load(tensor_in.
data<T>(), tensor_in.
size());
436 CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>);
438 REGISTER_BLOB_SERIALIZER(
441 REGISTER_BLOB_DESERIALIZER(
442 std::unique_ptr<caffe2::IndexBase>,
Blob is a general container that hosts a typed pointer.
const T * data() const
Returns a typed pointer of the underlying storage.
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
TIndex size() const
Returns the size (i.e., the total number of elements) of the tensor. [sentence truncated in the extracted text — completed from context]
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime. [sentence truncated in the extracted text — completed from context]
BlobSerializerBase is an abstract class that serializes a blob to a string.