// Caffe2 load/save operator header (caffe2/operators/load_save_op.h).
// NOTE(review): this text is garbled — the original file's line numbers are
// fused into the lines and many lines are missing. Code is left byte-identical;
// comments are best-effort reconstruction from what is visible.
1 #ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ 2 #define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ 6 #include <unordered_set> 8 #include "caffe2/core/blob_serialization.h" 9 #include "caffe2/core/context.h" 10 #include "caffe2/core/db.h" 11 #include "caffe2/core/logging.h" 12 #include "caffe2/core/operator.h" 13 #include "caffe2/utils/math.h" 14 #include "caffe2/utils/proto_utils.h" 23 std::set<int32_t> seen_chunks_ids;
// Fragment of the BlobState bookkeeping struct (name established by the
// std::unordered_map<string, BlobState> uses later in this file). It tracks
// chunked deserialization progress: total expected size, size loaded so far,
// whether the blob is a tensor, and (above) which chunk ids have been seen.
// The constructor just forwards the three fields; the struct declaration line
// itself is missing from this view.
26 int64_t total_size = 0,
27 int64_t current_size = 0,
28 bool is_tensor =
false)
29 : total_size(total_size),
30 current_size(current_size),
31 is_tensor(is_tensor) {}
37 using db::Transaction;
// Fragment of an operator that reports whether a db exists on disk
// (presumably DBExistsOp — the class declaration line is missing; verify
// against the original file).
39 template <
class Context>
42 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment: reads the "absolute_path", "db_name" and "db_type"
// arguments from the operator definition.
47 OperatorBase::GetSingleArgument<int>(
"absolute_path",
false)),
48 db_name_(OperatorBase::GetSingleArgument<string>(
"db_name",
"")),
49 db_type_(OperatorBase::GetSingleArgument<string>(
"db_type",
"")) {}
// RunOnDevice: resolves the db path (relative to the workspace root unless
// absolute_path_ is set) and writes a single bool output that is true iff
// the db exists.
51 bool RunOnDevice()
override {
53 absolute_path_ ? db_name_ : (ws_->RootFolder() +
"/" + db_name_);
54 auto* output = Output(0);
56 bool* exists = output->template mutable_data<bool>();
58 *exists = caffe2::db::DBExists(db_type_, full_db_name);
// Fragment of the LoadOp constructor. Parses all load-related arguments and
// validates their combinations; the class declaration line and several
// statements are missing from this view.
69 template <
class Context>
72 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Arguments: absolute_path (treat db paths as absolute), add_prefix /
// strip_prefix (blob-name rewriting), db / dbs (single or multiple source
// dbs), db_type, keep_device (preserve the serialized device option),
// load_all (load every key found), allow_incomplete (tolerate missing
// blobs), source_blob_names (explicit db-key -> output mapping).
77 OperatorBase::GetSingleArgument<int>(
"absolute_path",
false)),
78 add_prefix_(OperatorBase::GetSingleArgument<string>(
"add_prefix",
"")),
80 OperatorBase::GetSingleArgument<string>(
"strip_prefix",
"")),
81 db_name_(OperatorBase::GetSingleArgument<string>(
"db",
"")),
82 db_names_(OperatorBase::GetRepeatedArgument<string>(
"dbs")),
83 db_type_(OperatorBase::GetSingleArgument<string>(
"db_type",
"")),
84 keep_device_(OperatorBase::GetSingleArgument<int>(
"keep_device", 0)),
85 load_all_(OperatorBase::GetSingleArgument<int>(
"load_all", 0)),
87 OperatorBase::GetSingleArgument<bool>(
"allow_incomplete",
false)),
89 OperatorBase::GetRepeatedArgument<string>(
"source_blob_names")) {
// No DBReader inputs: dbs must be named via arguments. "db" is folded into
// db_names_ so the rest of the code only deals with the list form.
90 if (InputSize() == 0) {
91 CAFFE_ENFORCE_GT(db_type_.size(), 0,
"Must specify a db type.");
92 if (db_names_.empty()) {
93 CAFFE_ENFORCE_GT(db_name_.size(), 0,
"Must specify a db name.");
94 db_names_.push_back(db_name_);
// Reject duplicate db names (std::set::insert().second is false on dup).
97 std::set<std::string> db_name_set;
98 for (
const string& db_name : db_names_) {
99 CAFFE_ENFORCE_GT(db_name.size(), 0,
"Db name should not be empty.");
101 db_name_set.insert(db_name).second,
102 "Duplicated db name: ",
// source_blob_names must match the output count and is mutually exclusive
// with both strip_prefix and load_all.
108 CAFFE_ENFORCE(blob_names_.empty() || blob_names_.size() == OutputSize(),
109 "Number of output blobs and source_blob_names mismatch.");
110 CAFFE_ENFORCE(blob_names_.empty() || strip_prefix_.empty(),
111 "strip_prefix and source_blob_names are mutually exclusive.");
112 CAFFE_ENFORCE(blob_names_.empty() || !load_all_,
113 "cannot load_all_ while using source_blob_names.");
// Default the expected blob names to the operator's output names, then
// build the name -> output-index map, rejecting duplicates.
118 if(blob_names_.empty()) {
119 for (
const string& name : operator_def.output()) {
120 blob_names_.push_back(name);
124 std::set<std::string> name_set;
125 for (
const string& name : blob_names_) {
126 CAFFE_ENFORCE(name_set.insert(name).second,
127 "Duplicated source blob name: ", name);
128 output_indices_[name] = idx++;
// Rewrites the device option of a serialized BlobProto to the current
// device (used unless keep_device_ is set); defined elsewhere per Context.
133 void SetCurrentDevice(BlobProto* proto);
// RunOnDevice fragment: loads blobs either from DBReader inputs or from the
// dbs named in db_names_, accumulating per-blob chunk state in blob_states.
135 bool RunOnDevice()
override {
136 int total_loaded_blobs = 0;
137 std::unordered_map<string, BlobState> blob_states;
138 if (InputSize() > 0) {
// Inputs are db::DBReader blobs; read via each reader's cursor.
139 for (
int i = 0; i < InputSize(); ++i) {
140 const db::DBReader& reader = OperatorBase::Input<db::DBReader>(i);
141 extract(i, reader.
cursor(), &blob_states, &total_loaded_blobs);
// Otherwise open each named db read-only and extract from a fresh cursor.
144 for (
int i = 0; i < db_names_.size(); ++i) {
145 string full_db_name = absolute_path_
147 : (ws_->RootFolder() +
"/" + db_names_[i]);
148 std::unique_ptr<DB> in_db(
149 caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
150 CAFFE_ENFORCE(in_db.get(),
"Cannot open db: ", full_db_name);
151 std::unique_ptr<Cursor> cursor(in_db->NewCursor());
152 extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
// Every loaded blob must have all its chunks accounted for.
156 validateBlobStates(blob_states);
157 // Loaded all the needed blobs.
158 if (load_all_ || total_loaded_blobs == OutputSize()) {
159 VLOG(1) <<
"Loaded " << total_loaded_blobs <<
" blobs fully from db(s)";
// Partial load: either tolerated (allow_incomplete_) or an error; the
// error path logs each output name that was never seen in the db(s).
164 if (allow_incomplete_) {
165 VLOG(1) <<
"Loaded " << total_loaded_blobs <<
" blobs out of " 166 << OutputSize() <<
" blobs from db(s).";
168 for (
const string& output_name : this->debug_def().output()) {
169 if (blob_states.count(output_name) == 0) {
170 LOG(ERROR) <<
"Failed to load blob: " << output_name;
// extract fragment: dispatches one db's worth of work — extractAll() when
// load_all_ is set (every key becomes a workspace blob), otherwise
// extractFrom() restricted to the operator's declared outputs.
188 std::unordered_map<string, BlobState>* blob_states,
189 int* total_loaded_blobs) {
191 extractAll(db_id, cursor, blob_states, total_loaded_blobs);
196 OperatorBase::Outputs(),
// extractAll fragment: walks the cursor over every key in the db, rejects a
// key already claimed by a different db, parses the value as a BlobProto,
// and materializes it as a workspace blob via ProcessBlob.
205 std::unordered_map<string, BlobState>* blob_states,
206 int* total_loaded_blobs) {
207 CAFFE_ENFORCE(cursor,
"cursor is not valid");
208 int loaded_blobs = 0;
209 for (; cursor->Valid(); cursor->Next()) {
// Key is the db key with any chunk-id suffix stripped and prefixes applied.
210 const auto key = buildBlobNameFromDbKey(cursor->key());
// The same blob name may appear multiple times within one db (chunks), but
// never across two different dbs.
211 if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
212 CAFFE_THROW(
"Duplicate Key ", key,
" is found!\n");
214 key_to_dbid_[key] = db_id;
219 proto.ParseFromString(cursor->value()),
"Couldn't parse Proto");
// Presumably guarded by !keep_device_ (condition line missing from view).
223 SetCurrentDevice(&proto);
225 Blob* blob = ws_->CreateBlob(key);
226 ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
228 *total_loaded_blobs += loaded_blobs;
// extractFrom fragment: like extractAll, but only deserializes keys that map
// to one of the operator's declared outputs, and can stop early once every
// output has been fully loaded.
234 const vector<Blob*>& outputs,
235 std::unordered_map<string, BlobState>* blob_states,
236 int* total_loaded_blobs) {
237 CAFFE_ENFORCE(cursor);
238 int loaded_blobs = 0;
239 for (; cursor->Valid(); cursor->Next()) {
240 const auto key = buildBlobNameFromDbKey(cursor->key());
// Keys that are not wanted outputs are skipped (continue line not visible).
241 if (!output_indices_.count(key)) {
242 VLOG(1) <<
"Key " << key <<
" not used. Skipping.";
// Same cross-db duplicate-key guard as extractAll.
244 if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
245 CAFFE_THROW(
"Duplicate Key ", key,
" is found!\n");
247 key_to_dbid_[key] = db_id;
250 VLOG(2) <<
"Deserializing blob " << key;
252 CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
256 SetCurrentDevice(&proto);
// Deserialize directly into the pre-allocated output blob for this key.
258 auto blobIndex = output_indices_[key];
259 Blob* blob = outputs.at(blobIndex);
260 ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
// Early exit: all requested outputs have been loaded; no need to keep
// scanning the rest of the db.
262 if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
268 *total_loaded_blobs += loaded_blobs;
// Maps a raw db key to a blob name: drop the chunk-id suffix (everything
// from kChunkIdSeparator on), optionally strip strip_prefix_ (everything up
// to and including its first occurrence), then prepend add_prefix_.
// NOTE(review): the `return key;` and closing braces are missing from this
// garbled view.
271 string buildBlobNameFromDbKey(
const string& dbKey) {
272 string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
273 if (!strip_prefix_.empty()) {
274 auto match_pos = key.find(strip_prefix_);
275 if (match_pos != string::npos) {
276 key = key.substr(match_pos + strip_prefix_.size());
279 key = add_prefix_ + key;
// ProcessBlob fragment: deserializes one BlobProto (possibly one chunk of a
// larger blob) into `blob` and updates the per-key BlobState chunk
// accounting. A blob counts as "loaded" only once all its chunks arrived.
288 const BlobProto& proto,
289 std::unordered_map<string, BlobState>* blob_states_ptr,
292 auto& blob_states = *blob_states_ptr;
293 if (blob_states.count(key) == 0) {
// Case 1: generic chunked content — total chunk count comes from the proto.
302 if (proto.has_content_num_chunks()) {
303 if (!blob_states.count(key)) {
304 blob_states[key] = BlobState(proto.content_num_chunks());
// Each chunk id may be seen at most once.
// NOTE(review): "occured" in the message below is a typo ("occurred") in
// the original runtime string; left untouched here.
308 .seen_chunks_ids.insert(proto.content_chunk_id())
310 "Chunk with the same id has occured twice for: ",
313 proto.content_chunk_id() >= 0 &&
314 proto.content_chunk_id() < blob_states[key].total_size,
315 "Chunk id has to be not less than 0 and " 316 "less than content_num_chunks for key: ",
318 blob_states[key].current_size++;
// A content-chunked entry must not also claim to be a tensor, and must not
// exceed the declared chunk count.
320 !blob_states[key].is_tensor,
321 "Proto with content_chunks can not store tensor: ",
324 blob_states[key].current_size <= blob_states[key].total_size,
325 "Found an extra part for an already filled blob: ",
327 if (blob_states[key].current_size == blob_states[key].total_size) {
// Case 2: non-tensor, non-chunked proto — must appear exactly once.
332 if (!proto.has_tensor()) {
335 CAFFE_ENFORCE(blob_states.count(key) == 0,
"Blob duplicated: ", key);
336 blob_states[key] = BlobState();
// Case 3: tensor proto, possibly split into segments. A repeat appearance
// must be a partial tensor carrying a segment; current_size advances by the
// segment length.
340 CAFFE_ENFORCE(proto.has_tensor());
341 if (blob_states.count(key)) {
342 CAFFE_ENFORCE(blob_states[key].is_tensor,
"Must be tensor ", key);
344 blob_states[key].current_size < blob_states[key].total_size,
345 "Found an extra part for an already filled tensor: ",
348 proto.tensor().has_segment(),
349 "Partial tensor must have a segment: ",
351 blob_states[key].current_size +=
352 proto.tensor().segment().end() - proto.tensor().segment().begin();
354 blob_states[key].current_size <= blob_states[key].total_size,
355 "Tensor parts are bigger than target size for tensor: ",
// First sight of a tensor: total size is the product of dims; if this proto
// is itself a segment, current_size is just that segment's length.
358 const auto& dims = proto.tensor().dims();
359 int64_t total_size = 1;
360 for (
const auto& dim : dims) {
363 auto current_size = total_size;
364 if (proto.tensor().has_segment()) {
366 proto.tensor().segment().end() - proto.tensor().segment().begin();
369 BlobState(total_size, current_size,
true );
// Fully assembled tensor: presumably increments *loaded_blobs (statement
// not visible in this view).
372 if (blob_states[key].current_size == blob_states[key].total_size) {
// validateBlobStates fragment: after all dbs are read, every tracked blob
// must have current_size == total_size, i.e. no chunk/segment is missing.
377 void validateBlobStates(
378 const std::unordered_map<string, BlobState>& blob_states) {
379 for (
const auto& iter : blob_states) {
380 const BlobState& blob_state = iter.second;
382 blob_state.current_size == blob_state.total_size,
383 "Data size mismatch for blob ",
386 blob_state.total_size,
388 blob_state.current_size);
// LoadOp member fragment. output_indices_ maps blob name -> output slot;
// key_to_dbid_ remembers which db first supplied each key (for the
// cross-db duplicate check); blob_names_ are the expected blob names.
395 string strip_prefix_;
397 std::vector<std::string> db_names_;
401 bool allow_incomplete_;
402 std::map<string, int> output_indices_;
403 std::map<string, int> key_to_dbid_;
404 std::vector<std::string> blob_names_;
// Fragment of the SaveOp constructor (class declaration line missing).
// Arguments: absolute_path, strip_prefix (rewrite input names before
// saving), db (target path), db_type, blob_name_overrides (explicit names
// for each input, mutually exclusive with strip_prefix).
407 template <
class Context>
410 USE_OPERATOR_CONTEXT_FUNCTIONS;
415 OperatorBase::GetSingleArgument<int>(
"absolute_path",
false)),
417 OperatorBase::GetSingleArgument<string>(
"strip_prefix",
"")),
418 db_name_(OperatorBase::GetSingleArgument<string>(
"db",
"")),
419 db_type_(OperatorBase::GetSingleArgument<string>(
"db_type",
"")),
421 OperatorBase::GetRepeatedArgument<string>(
"blob_name_overrides")) {
422 CAFFE_ENFORCE_GT(db_name_.size(), 0,
"Must specify a db name.");
423 CAFFE_ENFORCE_GT(db_type_.size(), 0,
"Must specify a db type.");
// Overrides, when given, must cover every input exactly.
425 blob_names_.empty() ||
426 blob_names_.size() == OperatorBase::Inputs().size(),
427 "Number of blobs and blob_name_overrides mismatch.");
429 blob_names_.empty() || strip_prefix_.empty(),
430 "strip_prefix and blob_name_overrides are mutually exclusive.");
// No overrides: derive each saved name from the input name, dropping
// everything up to and including the first strip_prefix_ occurrence, and
// reject duplicate resulting names.
432 if (blob_names_.empty()) {
433 std::set<std::string> input_names;
434 blob_names_.resize(OperatorBase::Inputs().size());
435 for (
int i = 0; i < blob_names_.size(); ++i) {
437 if (strip_prefix_.empty()) {
438 name = operator_def.input(i);
440 auto match_pos = operator_def.input(i).find(strip_prefix_);
441 if (match_pos == string::npos) {
442 name = operator_def.input(i);
444 name = operator_def.input(i).substr(
445 match_pos + strip_prefix_.size(), string::npos);
449 input_names.insert(name).second,
"Duplicated input: ", name);
450 blob_names_[i] = name;
// SaveOp::RunOnDevice fragment: opens (creates) the target db, then
// serializes every input blob into it. Each serialized chunk is committed
// in its own transaction by the acceptor callback.
455 bool RunOnDevice()
override {
456 string full_db_name =
457 absolute_path_ ? db_name_ : (ws_->RootFolder() +
"/" + db_name_);
458 std::unique_ptr<DB> out_db(
459 caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
460 CAFFE_ENFORCE(out_db.get(),
"Cannot open db for writing: ", full_db_name);
// Acceptor receives (name, serialized bytes) pairs from Blob::Serialize and
// writes each one through a fresh transaction. Captures out_db by reference;
// safe because serialization completes before out_db goes out of scope.
462 BlobSerializerBase::SerializationAcceptor acceptor = [&](
463 const std::string& blobName,
const std::string& data) {
465 VLOG(2) <<
"Sending " << blobName <<
" blob's data of size " 466 << data.size() <<
" to db";
467 auto transaction = out_db->NewTransaction();
468 transaction->Put(blobName, data);
469 transaction->Commit();
472 const vector<const Blob*>& inputs = OperatorBase::Inputs();
473 for (
int i = 0; i < inputs.size(); ++i) {
474 inputs[i]->Serialize(blob_names_[i], acceptor);
// SaveOp member fragment: prefix to strip from input names, and the final
// per-input names used as db keys.
483 string strip_prefix_;
486 std::vector<std::string> blob_names_;
489 template <
typename... Ts>
490 string FormatString(
const string& pattern, Ts... values) {
497 int written = sprintf(buffer, pattern.c_str(), values...);
498 if (written < 0 || written + 1 > 1024) {
499 LOG(FATAL) <<
"FormatString fails: total bytes written " << written;
501 return string(buffer);
// CheckpointOp fragment: periodically runs a SaveOp with a db path produced
// from a printf-style pattern and the current iteration counter.
517 template <
class Context>
// Constructor: "db" is a format pattern (e.g. containing %d for the iter),
// "every" is the checkpoint interval in iterations.
522 db_pattern_(OperatorBase::GetSingleArgument<string>(
"db",
"")),
523 every_(OperatorBase::GetSingleArgument<int>(
"every", 1)),
525 save_op_def_(operator_def) {
527 db_pattern_.size(), 0,
"Must specify a checkpoint file pattern.");
528 CAFFE_ENFORCE_GT(every_, 0,
"Checkpoint interval should be positive.");
// Presumably warns when every_ == 1 (condition line missing from view).
// NOTE(review): "checkpointting" is a typo ("checkpointing") in the original
// runtime log string; left untouched here.
531 LOG(WARNING) <<
"It seems that we are checkpointting every iteration. " 532 <<
"Is that intended?";
// Reuse this op's own def as the SaveOp def, retargeting its type.
534 save_op_def_.set_type(
"Save");
// RunOnDevice: input 0 is the iteration counter tensor; when the counter is
// a multiple of every_, rewrite the "db" argument of the cached Save def
// with the formatted path (presumably followed by running the Save op —
// those lines are missing from this view).
537 bool RunOnDevice()
override {
539 OperatorBase::Input<TensorCPU>(0).
template data<int64_t>()[0];
540 if (iter % every_ == 0) {
541 GetMutableArgument(
"db",
true, &save_op_def_)
542 ->set_s(FormatString(db_pattern_, iter));
554 OperatorDef save_op_def_;
559 #endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ Blob is a general container that hosts a typed pointer.
A reader wrapper for DB that also allows us to serialize it.
Cursor * cursor() const
Returns the underlying cursor of the db reader.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
T * Reset(T *allocated)
Sets the underlying object to the allocated one.
void Deserialize(const string &content)
Deserializes from a string containing either BlobProto or TensorProto.