1 #ifndef CAFFE2_CORE_DB_H_ 2 #define CAFFE2_CORE_DB_H_ 6 #include "caffe2/core/blob_serialization.h" 7 #include "caffe2/core/registry.h" 8 #include "caffe2/proto/caffe2.pb.h" 17 enum Mode { READ, WRITE, NEW };
/**
 * Seek to a specific key (or if the key does not exist, seek to the
 * immediate next). Pure virtual: every concrete cursor must implement it.
 * Callers can ask SupportsSeek() first to find out whether seeking is
 * available.
 */
31 virtual void Seek(
const string&
key) = 0;
/**
 * Reports whether this cursor can honor Seek(). The base implementation
 * answers false; presumably cursors with real key-based seeking override it
 * to return true (confirm per implementation).
 */
32 virtual bool SupportsSeek() {
return false; }
/**
 * Go to the next location in the database. Pure virtual.
 */
40 virtual void Next() = 0;
/**
 * Returns the current key, by value. Pure virtual.
 */
44 virtual string key() = 0;
/**
 * Returns the current value, by value. Pure virtual.
 */
48 virtual string value() = 0;
/**
 * Returns whether the current location is valid - for example, the cursor
 * may have moved past the end of the database. Pure virtual.
 */
53 virtual bool Valid() = 0;
// Cursor objects are not meant to be copied.
// NOTE(review): DISABLE_COPY_AND_ASSIGN is defined outside this header;
// presumably it removes the copy constructor and copy assignment operator -
// confirm against the macro's definition.
55 DISABLE_COPY_AND_ASSIGN(
Cursor);
/**
 * Puts a key/value pair into the database as part of this write
 * transaction. Pure virtual.
 * NOTE(review): whether the pair becomes visible before Commit() is called
 * is not established by this header - confirm per implementation.
 */
68 virtual void Put(
const string&
key,
const string&
value) = 0;
/**
 * Commits the writes made through this transaction. Pure virtual.
 * NOTE(review): durability/flush semantics are implementation-defined -
 * confirm per implementation.
 */
72 virtual void Commit() = 0;
/**
 * Base-class constructor. Only the access mode is kept (in mode_); the
 * string argument is deliberately unnamed and is not used here.
 * NOTE(review): the string is presumably the database source, matching the
 * (const string&, Mode) signature used by Caffe2DBRegistry, and subclasses
 * are expected to consume it themselves - confirm.
 */
82 DB(
const string& , Mode mode) : mode_(mode) {}
/**
 * Closes the database. Pure virtual.
 * NOTE(review): whether the object may still be used after Close() is not
 * stated here - confirm per implementation.
 */
87 virtual void Close() = 0;
/**
 * Creates a cursor for reading this database. The caller owns the returned
 * cursor through the unique_ptr. Pure virtual.
 */
92 virtual std::unique_ptr<Cursor> NewCursor() = 0;
/**
 * Creates a transaction for writing to this database. The caller owns the
 * returned transaction through the unique_ptr. Pure virtual.
 */
97 virtual std::unique_ptr<Transaction> NewTransaction() = 0;
// DB objects are not meant to be copied.
// NOTE(review): DISABLE_COPY_AND_ASSIGN is defined outside this header;
// presumably it removes the copy constructor and copy assignment operator -
// confirm against the macro's definition.
102 DISABLE_COPY_AND_ASSIGN(
DB);
// Declares Caffe2DBRegistry, the registry of DB implementations. Entries are
// created with (const string&, Mode) arguments - the same parameter types as
// the DB base-class constructor. CreateDB() looks implementations up here.
107 CAFFE_DECLARE_REGISTRY(Caffe2DBRegistry,
DB,
const string&, Mode);
// REGISTER_CAFFE2_DB(name, ...) registers a DB implementation in
// Caffe2DBRegistry under `name` by forwarding to CAFFE_REGISTER_CLASS.
//
// CreateDB asks Caffe2DBRegistry for the implementation registered as
// `db_type`, handing it `source` and `mode`. The lookup can come back empty
// (the log line below reports "not found db" in that case), so callers must
// check the result; DBReader enforces that it is non-null after calling this.
108 #define REGISTER_CAFFE2_DB(name, ...) \ 109 CAFFE_REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__) 117 inline unique_ptr<DB> CreateDB(
118 const string& db_type,
const string& source, Mode mode) {
119 auto result = Caffe2DBRegistry()->Create(db_type, source, mode);
// Verbose-level trace of whether the registry lookup produced a DB.
120 VLOG(1) << ((!result) ?
"not found db " :
"found db ") << db_type;
// Checks for an existing database by trying to open `full_db_name` of type
// `db_type` through CreateDB in READ mode.
// NOTE(review): how the opened handle is converted into the bool result is
// not covered by this comment - confirm against the full function body.
127 inline bool DBExists(
const string& db_type,
const string& full_db_name) {
133 std::unique_ptr<DB> db(
134 caffe2::db::CreateDB(db_type, full_db_name, caffe2::db::READ));
151 const string& db_type,
152 const string& source,
// Sharding defaults: a single shard whose id is 0.
153 const int32_t num_shards = 1,
154 const int32_t shard_id = 0) {
// Forward every argument unchanged to Open().
155 Open(db_type, source, num_shards, shard_id);
// Builds a reader from a DBReaderProto: opens the db named by the proto's
// db_type and source (Open()'s shard arguments keep their defaults), then,
// if the proto carries a key, positions the cursor at that key.
158 explicit DBReader(
const DBReaderProto& proto) {
159 Open(proto.db_type(), proto.source());
160 if (proto.has_key()) {
// Seeking is optional for cursors (SupportsSeek() defaults to false), so
// fail loudly instead of silently ignoring the requested key.
161 CAFFE_ENFORCE(cursor_->SupportsSeek(),
162 "Encountering a proto that needs seeking but the db type " 163 "does not support it.");
164 cursor_->Seek(proto.key());
// Builds a reader around a DB handed in by the caller rather than one opened
// by type and source; db_type_ and source_ therefore get placeholder strings.
170 explicit DBReader(std::unique_ptr<DB> db)
171 : db_type_(
"<memory-type>"),
172 source_(
"<memory-source>"),
// A null DB cannot produce a cursor, so reject it before asking for one.
174 CAFFE_ENFORCE(db_.get(),
"Passed null db");
175 cursor_ = db_->NewCursor();
179 const string& db_type,
180 const string& source,
// Sharding defaults: a single shard whose id is 0.
181 const int32_t num_shards = 1,
182 const int32_t shard_id = 0) {
// Open the database read-only. CreateDB can yield an empty pointer, so turn
// that into an error that names the source and the type.
189 db_ = CreateDB(db_type_, source_, READ);
190 CAFFE_ENFORCE(db_,
"Cannot open db: ", source_,
" of type ", db_type_);
// Validates the shard arguments, stores them, and creates the cursor.
191 InitializeCursor(num_shards, shard_id);
// Sharding defaults: a single shard whose id is 0.
196 const int32_t num_shards = 1,
197 const int32_t shard_id = 0) {
// Reject a null DB before any cursor is requested from it.
201 CAFFE_ENFORCE(db_.get(),
"Passed null db");
// Validates the shard arguments, stores them, and creates the cursor.
202 InitializeCursor(num_shards, shard_id);
// Read a key/value pair from the db and move to next.
// A reader without a cursor has not been opened; note this check runs
// before the lock below is taken.
223 CAFFE_ENFORCE(cursor_ !=
nullptr,
"Reader not initialized.");
// Hold reader_mutex_ for the rest of the scope while the cursor is used.
224 std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
// Copy out the entry the cursor currently points at.
225 *key = cursor_->key();
226 *value = cursor_->value();
// One iteration per shard.
// NOTE(review): presumably each iteration advances the cursor so that this
// reader skips rows belonging to the other shards - confirm against the
// full loop body.
229 for (
int s = 0; s < num_shards_; s++) {
// The cursor is no longer on a valid entry (e.g. it ran past the end).
231 if (!cursor_->Valid()) {
// Seeks to the first key.
// A reader without a cursor has not been opened; this check runs before the
// lock below is taken.
242 CAFFE_ENFORCE(cursor_ !=
nullptr,
"Reader not initialized.");
// Hold reader_mutex_ for the rest of the scope while the cursor is moved.
243 std::unique_lock<std::mutex> mutex_lock(reader_mutex_);
// Returns the underlying cursor of the db reader.
// Every call logs at ERROR level to steer callers towards Read(), because
// the pointer returned here is used without taking reader_mutex_.
255 LOG(ERROR) <<
"Usually for a DBReader you should use Read() to be " 256 "thread safe. Consider refactoring your code.";
// Non-owning pointer: the reader keeps ownership of the cursor.
257 return cursor_.get();
// Validates the sharding arguments, records them, and creates a fresh cursor
// from db_. Requires num_shards >= 1 and 0 <= shard_id < num_shards; each
// violated condition trips its own CAFFE_ENFORCE.
261 void InitializeCursor(
const int32_t num_shards,
const int32_t shard_id) {
262 CAFFE_ENFORCE(num_shards >= 1);
263 CAFFE_ENFORCE(shard_id >= 0);
264 CAFFE_ENFORCE(shard_id < num_shards);
// Both values were just checked to be non-negative, so storing them in the
// reader's members loses nothing even where the member type is unsigned.
265 num_shards_ = num_shards;
266 shard_id_ = shard_id;
// Replaces any cursor the reader previously held.
267 cursor_ = db_->NewCursor();
// Rewinds the cursor to the first key, then runs shard_id_ iterations, each
// of which requires the cursor to be valid.
// NOTE(review): presumably each iteration also advances the cursor so the
// reader starts on the row whose index equals its shard id - confirm against
// the full loop body.
271 void MoveToBeginning()
const {
272 cursor_->SeekToFirst();
273 for (
auto s = 0; s < shard_id_; s++) {
// NOTE(review): `s` and `shard_id_` are appended to the message back to back
// with no separator between them, so the two numbers appear to run together
// in the error text - consider adding one.
276 cursor_->Valid(),
"Db has less rows than shard id: ", s, shard_id_);
// Cursor the reader uses for all reads; created from the DB via NewCursor().
283 unique_ptr<Cursor> cursor_;
// Locked around cursor use; mutable so const member functions can lock it.
284 mutable std::mutex reader_mutex_;
// Number of shards, set by InitializeCursor().
// NOTE(review): stored unsigned while InitializeCursor() takes int32_t and
// Read()'s loop counter is a plain int, which yields a signed/unsigned
// comparison - consider unifying the types.
285 uint32_t num_shards_;
300 BlobSerializerBase::SerializationAcceptor acceptor)
override;
// Deserializes `proto` into `blob`. Marked `override`, so it implements a
// virtual declared by the base class.
// NOTE(review): the base is presumably BlobDeserializerBase and this the
// DBReader deserializer - confirm against the enclosing class declaration.
305 void Deserialize(
const BlobProto& proto,
Blob* blob)
override;
311 #endif // CAFFE2_CORE_DB_H_
virtual bool Valid()=0
Returns whether the current location is valid - for example, if we have reached the end of the database, return false.
virtual string key()=0
Returns the current key.
Blob is a general container that hosts a typed pointer.
virtual void Seek(const string &key)=0
Seek to a specific key (or if the key does not exist, seek to the immediate next).
void Read(string *key, string *value) const
Read a set of key and value from the db and move to next.
An abstract class for the current database transaction while writing.
An abstract class for the cursor of the database while reading.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
A reader wrapper for DB that also allows us to serialize it.
Cursor * cursor() const
Returns the underlying cursor of the db reader.
virtual void SeekToFirst()=0
Seek to the first key in the database.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
An abstract class for accessing a database of key-value pairs.
void SeekToFirst() const
Seeks to the first key.
virtual void Next()=0
Go to the next location in the database.
virtual string value()=0
Returns the current value.
BlobSerializerBase is an abstract class that serializes a blob to a string.