Caffe2 - C++ API
A deep learning, cross platform ML framework
index_ops.cc
1 #include <atomic>
2 #include <limits>
3 #include <mutex>
4 #include <sstream>
5 #include <unordered_map>
6 #include <vector>
7 #include "caffe2/core/blob_serialization.h"
8 #include "caffe2/core/operator.h"
9 #include "caffe2/core/tensor.h"
10 
11 namespace caffe2 {
12 namespace {
// Key types an index can be built over; the operators below hand this list to
// DispatchHelper to select the matching DoRunWithType<T>().
13 using IndexKeyTypes = TensorTypes<int32_t, int64_t, std::string>;
// Integer type of the ids stored in and returned by an index.
14 using TIndexValue = int64_t;
15 } // namespace
16 
17 struct IndexBase {
18  public:
19  IndexBase(TIndexValue maxElements, const TypeMeta& type)
20  : maxElements_{maxElements}
21  , meta_(type)
22  , frozen_{false} {}
23 
24  void Freeze() { frozen_ = true; }
25 
26  bool isFrozen() const {
27  return frozen_;
28  }
29 
30  int64_t maxElements() const {
31  return maxElements_;
32  }
33 
34  virtual ~IndexBase() {}
35 
36  const TypeMeta& Type() const { return meta_; }
37 
38  TIndexValue Size() {
39  std::lock_guard<std::mutex> guard(dictMutex_);
40  return nextId_;
41  }
42 
43  protected:
44  int64_t maxElements_;
45  TypeMeta meta_;
46  TIndexValue nextId_{1}; // guarded by dictMutex_
47  std::atomic<bool> frozen_{false};
48  std::mutex dictMutex_;
49 };
50 
51 template<typename T>
52 struct Index: IndexBase {
53  explicit Index(TIndexValue maxElements)
54  : IndexBase(maxElements, TypeMeta::Make<T>()) {}
55 
56  void Get(const T* keys, TIndexValue* values, size_t numKeys) {
57  if (frozen_) {
58  FrozenGet(keys, values, numKeys);
59  return;
60  }
61  std::lock_guard<std::mutex> lock(dictMutex_);
62  for (int i = 0; i < numKeys; ++i) {
63  auto it = dict_.find(keys[i]);
64  if (it != dict_.end()) {
65  values[i] = it->second;
66  } else if (nextId_ < maxElements_) {
67  auto newValue = nextId_++;
68  dict_.insert({keys[i], newValue});
69  values[i] = newValue;
70  } else {
71  CAFFE_THROW("Dict max size reached");
72  }
73  }
74  }
75 
76  bool Load(const T* keys, size_t numKeys) {
77  CAFFE_ENFORCE(
78  numKeys <= maxElements_,
79  "Cannot load index: Tensor is larger than max_elements.");
80  decltype(dict_) dict;
81  for (int i = 0; i < numKeys; ++i) {
82  CAFFE_ENFORCE(
83  dict.insert({keys[i], i + 1}).second,
84  "Repeated elements found: cannot load into dictionary.");
85  }
86  // assume no `get` is inflight while this happens
87  {
88  std::lock_guard<std::mutex> lock(dictMutex_);
89  // let the old dict get destructed outside of the lock
90  dict_.swap(dict);
91  nextId_ = numKeys + 1;
92  }
93  return true;
94  }
95 
96  template<typename Ctx>
97  bool Store(Tensor<Ctx>* out) {
98  std::lock_guard<std::mutex> lock(dictMutex_);
99  out->Resize(nextId_ - 1);
100  auto outData = out->template mutable_data<T>();
101  for (const auto& entry : dict_) {
102  outData[entry.second - 1] = entry.first;
103  }
104  return true;
105  }
106 
107  private:
108  void FrozenGet(const T* keys, TIndexValue* values, size_t numKeys) {
109  for (int i = 0; i < numKeys; ++i) {
110  auto it = dict_.find(keys[i]);
111  values[i] = it != dict_.end() ? it->second : 0;
112  }
113  }
114 
115  std::unordered_map<T, TIndexValue> dict_;
116 };
117 
118 // TODO(azzolini): support sizes larger than int32
119 template<class T>
120 class IndexCreateOp: public Operator<CPUContext> {
121  public:
122  IndexCreateOp(const OperatorDef& operator_def, Workspace* ws)
123  : Operator(operator_def, ws),
124  maxElements_(OperatorBase::GetSingleArgument<int>(
125  "max_elements",
126  std::numeric_limits<int>::max())) {}
127 
128  bool RunOnDevice() override {
129  *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) =
130  std::unique_ptr<IndexBase>(new Index<T>(maxElements_));
131  return true;
132  }
133 
134  private:
135  TIndexValue maxElements_;
136 };
137 
138 class IndexGetOp: public Operator<CPUContext> {
139  public:
140  IndexGetOp(const OperatorDef& operator_def, Workspace* ws)
141  : Operator(operator_def, ws) {}
142 
143  bool RunOnDevice() override {
144  return DispatchHelper<IndexKeyTypes>::call(this, Input(1));
145  }
146  template <typename T>
147  bool DoRunWithType() {
148  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
149  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
150  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
151  const auto& keys = Input(1);
152  auto* values = Output(0);
153  values->ResizeLike(keys);
154  dict->Get(keys.data<T>(), values->mutable_data<TIndexValue>(), keys.size());
155  return true;
156  }
157 };
158 
159 class IndexLoadOp: public Operator<CPUContext> {
160  public:
161  IndexLoadOp(const OperatorDef& operator_def, Workspace* ws)
162  : Operator(operator_def, ws),
163  skipFirstEntry_(
164  OperatorBase::GetSingleArgument<int>("skip_first_entry", 0)) {}
165 
166  bool RunOnDevice() override {
167  return DispatchHelper<IndexKeyTypes>::call(this, Input(1));
168  }
169  template <typename T>
170  bool DoRunWithType() {
171  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
172  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
173  CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys.");
174  const auto& keys = Input(1);
175  const auto* keys_data = keys.data<T>();
176  auto keys_size = keys.size();
177  if (skipFirstEntry_) {
178  CAFFE_ENFORCE(keys.size() > 0);
179  ++keys_data;
180  --keys_size;
181  }
182  return dict->Load(keys_data, keys_size);
183  }
184 
185  private:
186  bool skipFirstEntry_;
187 };
188 
189 class IndexStoreOp: public Operator<CPUContext> {
190  public:
191  IndexStoreOp(const OperatorDef& operator_def, Workspace* ws)
192  : Operator(operator_def, ws) {}
193 
194  bool RunOnDevice() override {
195  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
196  return DispatchHelper<IndexKeyTypes>::call(this, base->Type());
197  }
198 
199  template <typename T>
200  bool DoRunWithType() {
201  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
202  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
203  CAFFE_ENFORCE(dict);
204  return dict->Store(Output(0));
205  }
206 };
207 
208 class IndexFreezeOp: public Operator<CPUContext> {
209  public:
210  IndexFreezeOp(const OperatorDef& operator_def, Workspace* ws)
211  : Operator(operator_def, ws) {}
212 
213  bool RunOnDevice() override {
214  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
215  base->Freeze();
216  return true;
217  }
218 };
219 
220 class IndexSizeOp : public Operator<CPUContext> {
221  public:
222  IndexSizeOp(const OperatorDef& operator_def, Workspace* ws)
223  : Operator(operator_def, ws) {}
224 
225  bool RunOnDevice() override {
226  auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0);
227  auto* out = Output(0);
228  out->Resize(std::vector<TIndex>{});
229  *out->mutable_data<TIndexValue>() = base->Size();
230  return true;
231  }
232 };
233 
// Creation operators: one registered name per supported key type, all backed
// by the IndexCreateOp template.
234 REGISTER_CPU_OPERATOR(IntIndexCreate, IndexCreateOp<int32_t>);
235 REGISTER_CPU_OPERATOR(LongIndexCreate, IndexCreateOp<int64_t>);
236 REGISTER_CPU_OPERATOR(StringIndexCreate, IndexCreateOp<std::string>);
237 
// Operators that take an existing index handle as input 0.
238 REGISTER_CPU_OPERATOR(IndexGet, IndexGetOp);
239 REGISTER_CPU_OPERATOR(IndexLoad, IndexLoadOp);
240 REGISTER_CPU_OPERATOR(IndexStore, IndexStoreOp);
241 REGISTER_CPU_OPERATOR(IndexFreeze, IndexFreezeOp);
242 REGISTER_CPU_OPERATOR(IndexSize, IndexSizeOp);
243 
244 OPERATOR_SCHEMA(IntIndexCreate)
245  .NumInputs(0)
246  .NumOutputs(1)
247  .SetDoc(R"DOC(
248 Creates a dictionary that maps int32 keys to consecutive integers
249 from 1 to max_elements. Zero is reserved for unknown keys.
250 )DOC")
251  .Arg("max_elements", "Max number of elements, including the zero entry.")
252  .Output(0, "handler", "Pointer to an Index instance.");
253 
254 OPERATOR_SCHEMA(LongIndexCreate)
255  .NumInputs(0)
256  .NumOutputs(1)
257  .SetDoc(R"DOC(
258 Creates a dictionary that maps int64 keys to consecutive integers
259 from 1 to max_elements. Zero is reserved for unknown keys.
260 )DOC")
261  .Arg("max_elements", "Max number of elements, including the zero entry.")
262  .Output(0, "handler", "Pointer to an Index instance.");
263 
264 OPERATOR_SCHEMA(StringIndexCreate)
265  .NumInputs(0)
266  .NumOutputs(1)
267  .SetDoc(R"DOC(
268 Creates a dictionary that maps string keys to consecutive integers
269 from 1 to max_elements. Zero is reserved for unknown keys.
270 )DOC")
271  .Arg("max_elements", "Max number of elements, including the zero entry.")
272  .Output(0, "handle", "Pointer to an Index instance.");
273 
274 OPERATOR_SCHEMA(IndexGet)
275  .NumInputs(2)
276  .NumOutputs(1)
277  .SetDoc(R"DOC(
278 Given an index handle and a tensor of keys, return an Int tensor of same shape
279 containing the indices for each of the keys. If the index is frozen, unknown
280 entries are given index 0. Otherwise, new entries are added into the index.
281 If an insert is necessary but max_elements has been reached, fail.
282 )DOC")
283  .Input(0, "handle", "Pointer to an Index instance.")
284  .Input(1, "keys", "Tensor of keys to be looked up.")
285  .Output(0, "indices", "Indices for each of the keys.");
286 
287 OPERATOR_SCHEMA(IndexFreeze)
288  .NumInputs(1)
289  .NumOutputs(1)
290  .SetDoc(R"DOC(
291 Freezes the given index, disallowing creation of new index entries.
292 Should not be called concurrently with IndexGet.
293 )DOC")
294  .Input(0, "handle", "Pointer to an Index instance.")
295  .Output(0, "handle", "The input handle.")
296  .EnforceInplace({{0, 0}});
297 
298 OPERATOR_SCHEMA(IndexLoad)
299  .NumInputs(2)
300  .NumOutputs(1)
301  .SetDoc(R"DOC(
302 Loads the index from the given 1-D tensor. Elements in the tensor will be given
303 consecutive indexes starting at 1. Fails if tensor contains repeated elements.
304 )DOC")
305  .Input(0, "handle", "Pointer to an Index instance.")
306  .Input(1, "items", "1-D tensor with elements starting with index 1.")
307  .Output(0, "handle", "The input handle.")
308  .EnforceInplace({{0, 0}})
309  .Arg(
310  "skip_first_entry",
311  "If set, skips the first entry of the tensor. This allows "
312  "to load tensors that are aligned with an embedding, where the first "
313  "entry corresponds to the default 0 index entry.");
314 
315 OPERATOR_SCHEMA(IndexStore)
316  .NumInputs(1)
317  .NumOutputs(1)
318  .SetDoc(R"DOC(
319 Stores the keys of this index in a 1-D tensor. Since element 0 is reserved
320 for unknowns, the first element of the output tensor will be element of index 1.
321 )DOC")
322  .Input(0, "handle", "Pointer to an Index instance.")
323  .Output(0, "items", "1-D tensor with elements starting with index 1.");
324 
325 OPERATOR_SCHEMA(IndexSize)
326  .NumInputs(1)
327  .NumOutputs(1)
328  .SetDoc(R"DOC(
329 Returns the number of entries currently present in the index.
330 )DOC")
331  .Input(0, "handle", "Pointer to an Index instance.")
332  .Output(0, "items", "Scalar int64 tensor with number of entries.");
333 
334 NO_GRADIENT(IndexGetOp);
335 NO_GRADIENT(IntIndexCreate);
336 NO_GRADIENT(LongIndexCreate);
337 NO_GRADIENT(StringIndexCreate);
338 SHOULD_NOT_DO_GRADIENT(IndexFreeze);
339 SHOULD_NOT_DO_GRADIENT(IndexLoad);
340 SHOULD_NOT_DO_GRADIENT(IndexStore);
341 SHOULD_NOT_DO_GRADIENT(IndexSize);
342 
344  public:
345  IndexSerializer() {}
346  ~IndexSerializer() {}
347 
348  void Serialize(
349  const Blob& blob,
350  const string& name,
351  SerializationAcceptor acceptor) override {
352  auto& base = blob.template Get<std::unique_ptr<IndexBase>>();
353  Blob tensor_blob;
354  auto* tensor_out = tensor_blob.template GetMutable<Tensor<CPUContext>>();
355 
356  if (base->Type().Match<std::string>()) {
357  doStore<std::string>(base, tensor_out);
358  } else if (base->Type().Match<int32_t>()) {
359  doStore<int32_t>(base, tensor_out);
360  } else if (base->Type().Match<int64_t>()) {
361  doStore<int64_t>(base, tensor_out);
362  } else {
363  CAFFE_THROW("Index of this type can't be serialized.");
364  }
365 
366  CAFFE_ENFORCE(
367  tensor_out->size() <= std::numeric_limits<int32_t>::max(),
368  "Index too large to be serialized.");
369  BlobProto blob_proto;
371  ser.Serialize(
372  *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->size());
373  blob_proto.set_name(name);
374  blob_proto.set_type("std::unique_ptr<caffe2::IndexBase>");
375 
376  std::ostringstream os;
377  os << base->maxElements() << " " << base->isFrozen();
378  blob_proto.set_content(os.str());
379 
380  acceptor(name, blob_proto.SerializeAsString());
381  }
382 
383  private:
384  template <typename T>
385  void doStore(
386  const std::unique_ptr<IndexBase>& base,
387  Tensor<CPUContext>* tensor_out) {
388  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get());
389  CAFFE_ENFORCE(dict, "Wrong dictionary type.");
390  dict->Store(tensor_out);
391  }
392 };
393 
395  public:
396  void Deserialize(const BlobProto& proto, Blob* blob) override {
398  Blob tensor_blob;
399  deser.Deserialize(proto, &tensor_blob);
400 
401  std::istringstream is(proto.content());
402  int64_t maxElements{std::numeric_limits<int64_t>::max()};
403  bool isFrozen{false};
404  is >> maxElements >> isFrozen;
405 
406  auto& tensor_in = tensor_blob.template Get<Tensor<CPUContext>>();
407  auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>();
408 
409  if (tensor_in.IsType<std::string>()) {
410  doLoad<std::string>(base, maxElements, tensor_in);
411  } else if (tensor_in.IsType<int32_t>()) {
412  doLoad<int32_t>(base, maxElements, tensor_in);
413  } else if (tensor_in.IsType<int64_t>()) {
414  doLoad<int64_t>(base, maxElements, tensor_in);
415  } else {
416  CAFFE_THROW("Index of this type cannot be deserialized.");
417  }
418 
419  if (isFrozen) {
420  (*base)->Freeze();
421  }
422  }
423 
424  private:
425  template <typename T>
426  void doLoad(
427  std::unique_ptr<IndexBase>* base,
428  int64_t maxElements,
429  const Tensor<CPUContext>& tensor_in) {
430  base->reset(new Index<T>(maxElements));
431  auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get());
432  dict->Load(tensor_in.data<T>(), tensor_in.size());
433  }
434 };
435 
436 CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>);
437 
438 REGISTER_BLOB_SERIALIZER(
439  (TypeMeta::Id<std::unique_ptr<caffe2::IndexBase>>()),
441 REGISTER_BLOB_DESERIALIZER(
442  std::unique_ptr<caffe2::IndexBase>,
444 
445 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
const T * data() const
Returns a typed pointer of the underlying storage.
Definition: tensor.h:484
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:93
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:593
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void Resize(Ts...dim_source)
Resizes a tensor.
Definition: tensor.h:288
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:88
BlobSerializerBase is an abstract class that serializes a blob to a string.