Caffe2 - C++ API
A deep learning, cross platform ML framework
db.h
1 #ifndef CAFFE2_CORE_DB_H_
2 #define CAFFE2_CORE_DB_H_
3 
4 #include <mutex>
5 
6 #include "caffe2/core/blob_serialization.h"
7 #include "caffe2/core/registry.h"
8 #include "caffe2/proto/caffe2.pb.h"
9 
10 namespace caffe2 {
11 namespace db {
12 
/**
 * The mode in which a database is opened: reading an existing database,
 * writing to it, or creating a new one.
 */
enum Mode {
  READ = 0,
  WRITE = 1,
  NEW = 2,
};
18 
22 class Cursor {
23  public:
24  Cursor() { }
25  virtual ~Cursor() { }
31  virtual void Seek(const string& key) = 0;
32  virtual bool SupportsSeek() { return false; }
36  virtual void SeekToFirst() = 0;
40  virtual void Next() = 0;
44  virtual string key() = 0;
48  virtual string value() = 0;
53  virtual bool Valid() = 0;
54 
55  DISABLE_COPY_AND_ASSIGN(Cursor);
56 };
57 
61 class Transaction {
62  public:
63  Transaction() { }
64  virtual ~Transaction() { }
68  virtual void Put(const string& key, const string& value) = 0;
72  virtual void Commit() = 0;
73 
74  DISABLE_COPY_AND_ASSIGN(Transaction);
75 };
76 
80 class DB {
81  public:
82  DB(const string& /*source*/, Mode mode) : mode_(mode) {}
83  virtual ~DB() { }
87  virtual void Close() = 0;
92  virtual std::unique_ptr<Cursor> NewCursor() = 0;
97  virtual std::unique_ptr<Transaction> NewTransaction() = 0;
98 
99  protected:
100  Mode mode_;
101 
102  DISABLE_COPY_AND_ASSIGN(DB);
103 };
104 
105 // Database classes are registered by their names so we can do optional
106 // dependencies.
107 CAFFE_DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode);
108 #define REGISTER_CAFFE2_DB(name, ...) \
109  CAFFE_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__)
110 
117 inline unique_ptr<DB> CreateDB(
118  const string& db_type, const string& source, Mode mode) {
119  auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
120  VLOG(1) << ((!result) ? "not found db " : "found db ") << db_type;
121  return result;
122 }
123 
127 inline bool DBExists(const string& db_type, const string& full_db_name) {
128  // Warning! We assume that creating a DB throws an exception if the DB
129  // does not exist. If the DB constructor does not follow this design
130  // pattern,
131  // the returned output (the existence tensor) can be wrong.
132  try {
133  std::unique_ptr<DB> db(
134  caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
135  return true;
136  } catch (...) {
137  return false;
138  }
139 }
140 
144 class DBReader {
145  public:
146 
147  friend class DBReaderSerializer;
148  DBReader() {}
149 
150  DBReader(
151  const string& db_type,
152  const string& source,
153  const int32_t num_shards = 1,
154  const int32_t shard_id = 0) {
155  Open(db_type, source, num_shards, shard_id);
156  }
157 
158  explicit DBReader(const DBReaderProto& proto) {
159  Open(proto.db_type(), proto.source());
160  if (proto.has_key()) {
161  CAFFE_ENFORCE(cursor_->SupportsSeek(),
162  "Encountering a proto that needs seeking but the db type "
163  "does not support it.");
164  cursor_->Seek(proto.key());
165  }
166  num_shards_ = 1;
167  shard_id_ = 0;
168  }
169 
170  explicit DBReader(std::unique_ptr<DB> db)
171  : db_type_("<memory-type>"),
172  source_("<memory-source>"),
173  db_(std::move(db)) {
174  CAFFE_ENFORCE(db_.get(), "Passed null db");
175  cursor_ = db_->NewCursor();
176  }
177 
178  void Open(
179  const string& db_type,
180  const string& source,
181  const int32_t num_shards = 1,
182  const int32_t shard_id = 0) {
183  // Note(jiayq): resetting is needed when we re-open e.g. leveldb where no
184  // concurrent access is allowed.
185  cursor_.reset();
186  db_.reset();
187  db_type_ = db_type;
188  source_ = source;
189  db_ = CreateDB(db_type_, source_, READ);
190  CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_);
191  InitializeCursor(num_shards, shard_id);
192  }
193 
194  void Open(
195  unique_ptr<DB>&& db,
196  const int32_t num_shards = 1,
197  const int32_t shard_id = 0) {
198  cursor_.reset();
199  db_.reset();
200  db_ = std::move(db);
201  CAFFE_ENFORCE(db_.get(), "Passed null db");
202  InitializeCursor(num_shards, shard_id);
203  }
204 
205  public:
222  void Read(string* key, string* value) const {
223  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
224  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
225  *key = cursor_->key();
226  *value = cursor_->value();
227 
228  // In sharded mode, each read skips num_shards_ records
229  for (int s = 0; s < num_shards_; s++) {
230  cursor_->Next();
231  if (!cursor_->Valid()) {
232  MoveToBeginning();
233  break;
234  }
235  }
236  }
237 
241  void SeekToFirst() const {
242  CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized.");
243  std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
244  MoveToBeginning();
245  }
246 
254  inline Cursor* cursor() const {
255  LOG(ERROR) << "Usually for a DBReader you should use Read() to be "
256  "thread safe. Consider refactoring your code.";
257  return cursor_.get();
258  }
259 
260  private:
261  void InitializeCursor(const int32_t num_shards, const int32_t shard_id) {
262  CAFFE_ENFORCE(num_shards >= 1);
263  CAFFE_ENFORCE(shard_id >= 0);
264  CAFFE_ENFORCE(shard_id < num_shards);
265  num_shards_ = num_shards;
266  shard_id_ = shard_id;
267  cursor_ = db_->NewCursor();
268  SeekToFirst();
269  }
270 
271  void MoveToBeginning() const {
272  cursor_->SeekToFirst();
273  for (auto s = 0; s < shard_id_; s++) {
274  cursor_->Next();
275  CAFFE_ENFORCE(
276  cursor_->Valid(), "Db has less rows than shard id: ", s, shard_id_);
277  }
278  }
279 
280  string db_type_;
281  string source_;
282  unique_ptr<DB> db_;
283  unique_ptr<Cursor> cursor_;
284  mutable std::mutex reader_mutex_;
285  uint32_t num_shards_;
286  uint32_t shard_id_;
287 
288  DISABLE_COPY_AND_ASSIGN(DBReader);
289 };
290 
292  public:
297  void Serialize(
298  const Blob& blob,
299  const string& name,
300  BlobSerializerBase::SerializationAcceptor acceptor) override;
301 };
302 
304  public:
305  void Deserialize(const BlobProto& proto, Blob* blob) override;
306 };
307 
308 } // namespace db
309 } // namespace caffe2
310 
311 #endif // CAFFE2_CORE_DB_H_
virtual bool Valid()=0
Returns whether the current location is valid - for example, if we have reached the end of the database, it returns false.
virtual string key()=0
Returns the current key.
Blob is a general container that hosts a typed pointer.
Definition: blob.h:25
virtual void Seek(const string &key)=0
Seek to a specific key (or if the key does not exist, seek to the immediate next).
void Read(string *key, string *value) const
Read a set of key and value from the db and move to next.
Definition: db.h:222
An abstract class for the current database transaction while writing.
Definition: db.h:61
An abstract class for the cursor of the database while reading.
Definition: db.h:22
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
A reader wrapper for DB that also allows us to serialize it.
Definition: db.h:144
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Definition: db.h:254
virtual void SeekToFirst()=0
Seek to the first key in the database.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
void SeekToFirst() const
Seeks to the first key.
Definition: db.h:241
virtual void Next()=0
Go to the next location in the database.
virtual string value()=0
Returns the current value.
BlobSerializerBase is an abstract class that serializes a blob to a string.