Caffe2 - C++ API
A deep learning, cross-platform ML framework
load_save_op.h
#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_
#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/context.h"
#include "caffe2/core/db.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

namespace {
struct BlobState {
  int64_t total_size;
  int64_t current_size;
  bool is_tensor;
  std::set<int32_t> seen_chunks_ids;

  explicit BlobState(
      int64_t total_size = 0,
      int64_t current_size = 0,
      bool is_tensor = false)
      : total_size(total_size),
        current_size(current_size),
        is_tensor(is_tensor) {}
};
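
// Illustration (hedged; not part of the original source): chunk bookkeeping
// for a blob serialized as three chunks.
//
//   BlobState state(/*total_size=*/3);  // content_num_chunks == 3
//   state.seen_chunks_ids.insert(0);    // chunk 0 arrives
//   state.current_size++;               // 1 of 3 chunks loaded
//   // ... chunks 1 and 2 are accounted for the same way; the blob counts
//   // as fully loaded once current_size == total_size.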
} // namespace

using db::Cursor;
using db::DB;
using db::Transaction;

template <class Context>
class DBExistsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  DBExistsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        db_name_(OperatorBase::GetSingleArgument<string>("db_name", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")) {}

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    auto* output = Output(0);
    output->Resize();
    bool* exists = output->template mutable_data<bool>();

    *exists = caffe2::db::DBExists(db_type_, full_db_name);
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  std::string db_name_;
  std::string db_type_;
};
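
// Usage sketch (a hedged illustration, not from this header; the db path and
// blob names are invented). DBExists writes a single bool to its output blob,
// which makes it handy for probing for a checkpoint before loading it:
//
//   OperatorDef def;
//   def.set_type("DBExists");
//   def.add_output("checkpoint_exists");
//   auto* name_arg = def.add_arg();
//   name_arg->set_name("db_name");
//   name_arg->set_s("checkpoints/model.db");
//   auto* type_arg = def.add_arg();
//   type_arg->set_name("db_type");
//   type_arg->set_s("leveldb");
//   Workspace ws;
//   auto op = CreateOperator(def, &ws);
//   op->Run();  // the "checkpoint_exists" blob now holds a bool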

template <class Context>
class LoadOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  LoadOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        add_prefix_(OperatorBase::GetSingleArgument<string>("add_prefix", "")),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_names_(OperatorBase::GetRepeatedArgument<string>("dbs")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        keep_device_(OperatorBase::GetSingleArgument<int>("keep_device", 0)),
        load_all_(OperatorBase::GetSingleArgument<int>("load_all", 0)),
        allow_incomplete_(
            OperatorBase::GetSingleArgument<bool>("allow_incomplete", false)),
        blob_names_(
            OperatorBase::GetRepeatedArgument<string>("source_blob_names")) {
    if (InputSize() == 0) {
      CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
      if (db_names_.empty()) {
        CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
        db_names_.push_back(db_name_);
        db_name_ = "";
      } else {
        std::set<std::string> db_name_set;
        for (const string& db_name : db_names_) {
          CAFFE_ENFORCE_GT(db_name.size(), 0, "Db name should not be empty.");
          CAFFE_ENFORCE(
              db_name_set.insert(db_name).second,
              "Duplicated db name: ",
              db_name);
        }
        db_name_ = "";
      }
    }
    CAFFE_ENFORCE(
        blob_names_.empty() || blob_names_.size() == OutputSize(),
        "Number of output blobs and source_blob_names mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and source_blob_names are mutually exclusive.");
    CAFFE_ENFORCE(
        blob_names_.empty() || !load_all_,
        "cannot load_all_ while using source_blob_names.");
    if (!load_all_) {
      // blob_names_ holds the source blob names in the file/db. If the
      // source_blob_names argument is not given, blob_names_ is inferred
      // from the operator outputs.
      if (blob_names_.empty()) {
        for (const string& name : operator_def.output()) {
          blob_names_.push_back(name);
        }
      }
      int idx = 0;
      std::set<std::string> name_set;
      for (const string& name : blob_names_) {
        CAFFE_ENFORCE(
            name_set.insert(name).second,
            "Duplicated source blob name: ",
            name);
        output_indices_[name] = idx++;
      }
    }
  }

  void SetCurrentDevice(BlobProto* proto);

  bool RunOnDevice() override {
    int total_loaded_blobs = 0;
    std::unordered_map<string, BlobState> blob_states;
    if (InputSize() > 0) {
      for (int i = 0; i < InputSize(); ++i) {
        const db::DBReader& reader = OperatorBase::Input<db::DBReader>(i);
        extract(i, reader.cursor(), &blob_states, &total_loaded_blobs);
      }
    } else {
      for (int i = 0; i < db_names_.size(); ++i) {
        string full_db_name = absolute_path_
            ? db_names_[i]
            : (ws_->RootFolder() + "/" + db_names_[i]);
        std::unique_ptr<DB> in_db(
            caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::READ));
        CAFFE_ENFORCE(in_db.get(), "Cannot open db: ", full_db_name);
        std::unique_ptr<Cursor> cursor(in_db->NewCursor());
        extract(i, cursor.get(), &blob_states, &total_loaded_blobs);
      }
    }

    validateBlobStates(blob_states);
    // Loaded all the needed blobs.
    if (load_all_ || total_loaded_blobs == OutputSize()) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs fully from db(s)";
      return true;
    }

    // Only loaded a subset of the blobs.
    if (allow_incomplete_) {
      VLOG(1) << "Loaded " << total_loaded_blobs << " blobs out of "
              << OutputSize() << " blobs from db(s).";
    } else {
      for (const string& output_name : this->debug_def().output()) {
        if (blob_states.count(output_name) == 0) {
          LOG(ERROR) << "Failed to load blob: " << output_name;
        }
      }
      CAFFE_THROW(
          "Expected to load ",
          OutputSize(),
          " blobs, got ",
          total_loaded_blobs,
          " only.\n");
    }

    return true;
  }

 private:
  void extract(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    if (load_all_) {
      extractAll(db_id, cursor, blob_states, total_loaded_blobs);
    } else {
      extractFrom(
          db_id,
          cursor,
          OperatorBase::Outputs(),
          blob_states,
          total_loaded_blobs);
    }
  }

  void extractAll(
      int db_id,
      Cursor* cursor,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor, "cursor is not valid");
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
        CAFFE_THROW("Duplicate Key ", key, " is found!\n");
      } else {
        key_to_dbid_[key] = db_id;
      }

      BlobProto proto;
      CAFFE_ENFORCE(
          proto.ParseFromString(cursor->value()), "Couldn't parse Proto");
      if (!keep_device_) {
        // If we are not keeping the device as the one specified in the
        // proto, we will set the current device.
        SetCurrentDevice(&proto);
      }
      Blob* blob = ws_->CreateBlob(key);
      ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);
    }
    *total_loaded_blobs += loaded_blobs;
  }

  void extractFrom(
      int db_id,
      Cursor* cursor,
      const vector<Blob*>& outputs,
      std::unordered_map<string, BlobState>* blob_states,
      int* total_loaded_blobs) {
    CAFFE_ENFORCE(cursor);
    int loaded_blobs = 0;
    for (; cursor->Valid(); cursor->Next()) {
      const auto key = buildBlobNameFromDbKey(cursor->key());
      if (!output_indices_.count(key)) {
        VLOG(1) << "Key " << key << " not used. Skipping.";
      } else {
        if (key_to_dbid_.count(key) && key_to_dbid_[key] != db_id) {
          CAFFE_THROW("Duplicate Key ", key, " is found!\n");
        } else {
          key_to_dbid_[key] = db_id;
        }

        VLOG(2) << "Deserializing blob " << key;
        BlobProto proto;
        CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));
        if (!keep_device_) {
          // If we are not keeping the device as the one specified in the
          // proto, we will set the current device.
          SetCurrentDevice(&proto);
        }
        auto blobIndex = output_indices_[key];
        Blob* blob = outputs.at(blobIndex);
        ProcessBlob(blob, proto, blob_states, key, &loaded_blobs);

        if (*total_loaded_blobs + loaded_blobs == OutputSize()) {
          break;
        }
      }
    }

    *total_loaded_blobs += loaded_blobs;
  }

  string buildBlobNameFromDbKey(const string& dbKey) {
    string key = dbKey.substr(0, dbKey.find(kChunkIdSeparator));
    if (!strip_prefix_.empty()) {
      auto match_pos = key.find(strip_prefix_);
      if (match_pos != string::npos) {
        key = key.substr(match_pos + strip_prefix_.size());
      }
    }
    key = add_prefix_ + key;
    return key;
  }
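
  // Example (names invented; the exact chunk suffix depends on
  // kChunkIdSeparator): with strip_prefix_ = "snapshot/" and add_prefix_ =
  // "gpu_0/", a db key whose base name is "snapshot/fc1_w" becomes the blob
  // name "gpu_0/fc1_w"; everything from kChunkIdSeparator onward is dropped
  // first, then the prefixes are applied.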

 private:
  // We are tracking sizes of already read tensor parts while reading data
  // chunks. This way we can make sure that all chunks were loaded in the end.
  void ProcessBlob(
      Blob* blob,
      const BlobProto& proto,
      std::unordered_map<string, BlobState>* blob_states_ptr,
      const string& key,
      int* loaded_blobs) {
    auto& blob_states = *blob_states_ptr;
    if (blob_states.count(key) == 0) {
      // We reset the blob so that any existing content is destroyed. This
      // is to guarantee correct device placement: if we are deserializing
      // into a TensorCUDA, without an explicit Reset we might be loading data
      // into an existing TensorCUDA that has pre-allocated memory on a
      // different GPU.
      blob->Reset();
    }
    blob->Deserialize(proto);
    if (proto.has_content_num_chunks()) {
      if (!blob_states.count(key)) {
        blob_states[key] = BlobState(proto.content_num_chunks());
      }
      CAFFE_ENFORCE(
          blob_states[key]
              .seen_chunks_ids.insert(proto.content_chunk_id())
              .second,
          "Chunk with the same id has occurred twice for: ",
          key);
      CAFFE_ENFORCE(
          proto.content_chunk_id() >= 0 &&
              proto.content_chunk_id() < blob_states[key].total_size,
          "Chunk id has to be not less than 0 and "
          "less than content_num_chunks for key: ",
          key);
      blob_states[key].current_size++;
      CAFFE_ENFORCE(
          !blob_states[key].is_tensor,
          "Proto with content_chunks cannot store a tensor: ",
          key);
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Found an extra part for an already filled blob: ",
          key);
      if (blob_states[key].current_size == blob_states[key].total_size) {
        (*loaded_blobs)++;
      }
      return;
    }
    if (!proto.has_tensor()) {
      // If a blob is divided into chunks, the field content_chunks has to be
      // set; otherwise only tensors can be seen multiple times as chunks.
      CAFFE_ENFORCE(blob_states.count(key) == 0, "Blob duplicated: ", key);
      blob_states[key] = BlobState();
      (*loaded_blobs)++;
      return;
    }
    CAFFE_ENFORCE(proto.has_tensor());
    if (blob_states.count(key)) {
      CAFFE_ENFORCE(blob_states[key].is_tensor, "Must be tensor ", key);
      CAFFE_ENFORCE(
          blob_states[key].current_size < blob_states[key].total_size,
          "Found an extra part for an already filled tensor: ",
          key);
      CAFFE_ENFORCE(
          proto.tensor().has_segment(),
          "Partial tensor must have a segment: ",
          key);
      blob_states[key].current_size +=
          proto.tensor().segment().end() - proto.tensor().segment().begin();
      CAFFE_ENFORCE(
          blob_states[key].current_size <= blob_states[key].total_size,
          "Tensor parts are bigger than target size for tensor: ",
          key);
    } else {
      const auto& dims = proto.tensor().dims();
      int64_t total_size = 1;
      for (const auto& dim : dims) {
        total_size *= dim;
      }
      auto current_size = total_size;
      if (proto.tensor().has_segment()) {
        current_size =
            proto.tensor().segment().end() - proto.tensor().segment().begin();
      }
      blob_states[key] =
          BlobState(total_size, current_size, true /* is_tensor */);
    }

    if (blob_states[key].current_size == blob_states[key].total_size) {
      (*loaded_blobs)++;
    }
  }
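
  // Worked example (hypothetical numbers): a 2x3 tensor has total_size = 6.
  // If it arrives as two chunks whose segments cover [0, 4) and [4, 6),
  // current_size grows to 4 and then 6, and loaded_blobs is incremented only
  // when the second chunk brings current_size up to total_size.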

  void validateBlobStates(
      const std::unordered_map<string, BlobState>& blob_states) {
    for (const auto& iter : blob_states) {
      const BlobState& blob_state = iter.second;
      CAFFE_ENFORCE(
          blob_state.current_size == blob_state.total_size,
          "Data size mismatch for blob ",
          iter.first,
          ". Expected: ",
          blob_state.total_size,
          " Read: ",
          blob_state.current_size);
    }
  }

  Workspace* ws_;
  bool absolute_path_;
  string add_prefix_;
  string strip_prefix_;
  string db_name_;
  std::vector<std::string> db_names_;
  string db_type_;
  bool keep_device_;
  bool load_all_;
  bool allow_incomplete_;
  std::map<string, int> output_indices_;
  std::map<string, int> key_to_dbid_;
  std::vector<std::string> blob_names_;
};
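
// Usage sketch (hedged; the argument names are the ones parsed by the
// constructor above, while the db path and blob names are invented):
//
//   OperatorDef def;
//   def.set_type("Load");
//   def.add_output("fc1_w");
//   def.add_output("fc1_b");
//   auto* db_arg = def.add_arg();
//   db_arg->set_name("db");
//   db_arg->set_s("checkpoints/model.db");
//   auto* type_arg = def.add_arg();
//   type_arg->set_name("db_type");
//   type_arg->set_s("leveldb");
//   Workspace ws;
//   auto op = CreateOperator(def, &ws);
//   op->Run();  // fills blobs fc1_w and fc1_b from the db
//
// Setting the "allow_incomplete" argument to true turns a partial load into
// a VLOG instead of a CAFFE_THROW, as RunOnDevice above shows.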

template <class Context>
class SaveOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  SaveOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        absolute_path_(
            OperatorBase::GetSingleArgument<int>("absolute_path", false)),
        strip_prefix_(
            OperatorBase::GetSingleArgument<string>("strip_prefix", "")),
        db_name_(OperatorBase::GetSingleArgument<string>("db", "")),
        db_type_(OperatorBase::GetSingleArgument<string>("db_type", "")),
        blob_names_(
            OperatorBase::GetRepeatedArgument<string>("blob_name_overrides")) {
    CAFFE_ENFORCE_GT(db_name_.size(), 0, "Must specify a db name.");
    CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
    CAFFE_ENFORCE(
        blob_names_.empty() ||
            blob_names_.size() == OperatorBase::Inputs().size(),
        "Number of blobs and blob_name_overrides mismatch.");
    CAFFE_ENFORCE(
        blob_names_.empty() || strip_prefix_.empty(),
        "strip_prefix and blob_name_overrides are mutually exclusive.");

    if (blob_names_.empty()) {
      std::set<std::string> input_names;
      blob_names_.resize(OperatorBase::Inputs().size());
      for (int i = 0; i < blob_names_.size(); ++i) {
        std::string name;
        if (strip_prefix_.empty()) {
          name = operator_def.input(i);
        } else {
          auto match_pos = operator_def.input(i).find(strip_prefix_);
          if (match_pos == string::npos) {
            name = operator_def.input(i);
          } else {
            name = operator_def.input(i).substr(
                match_pos + strip_prefix_.size(), string::npos);
          }
        }
        CAFFE_ENFORCE(
            input_names.insert(name).second, "Duplicated input: ", name);
        blob_names_[i] = name;
      }
    }
  }

  bool RunOnDevice() override {
    string full_db_name =
        absolute_path_ ? db_name_ : (ws_->RootFolder() + "/" + db_name_);
    std::unique_ptr<DB> out_db(
        caffe2::db::CreateDB(db_type_, full_db_name, caffe2::db::NEW));
    CAFFE_ENFORCE(out_db.get(), "Cannot open db for writing: ", full_db_name);

    BlobSerializerBase::SerializationAcceptor acceptor =
        [&](const std::string& blobName, const std::string& data) {
          // transaction should take care of locking
          VLOG(2) << "Sending " << blobName << " blob's data of size "
                  << data.size() << " to db";
          auto transaction = out_db->NewTransaction();
          transaction->Put(blobName, data);
          transaction->Commit();
        };

    const vector<const Blob*>& inputs = OperatorBase::Inputs();
    for (int i = 0; i < inputs.size(); ++i) {
      inputs[i]->Serialize(blob_names_[i], acceptor);
    }
    out_db->Close();
    return true;
  }

 private:
  Workspace* ws_;
  bool absolute_path_;
  string strip_prefix_;
  string db_name_;
  string db_type_;
  std::vector<std::string> blob_names_;
};
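
// Naming sketch (hedged; input and path names invented): with strip_prefix =
// "gpu_0/", an input blob "gpu_0/fc1_w" is stored under the db key "fc1_w",
// mirroring the substr logic in the constructor above.
//
//   OperatorDef def;
//   def.set_type("Save");
//   def.add_input("gpu_0/fc1_w");
//   auto* db_arg = def.add_arg();
//   db_arg->set_name("db");
//   db_arg->set_s("checkpoints/model.db");
//   auto* type_arg = def.add_arg();
//   type_arg->set_name("db_type");
//   type_arg->set_s("leveldb");
//   auto* strip_arg = def.add_arg();
//   strip_arg->set_name("strip_prefix");
//   strip_arg->set_s("gpu_0/");
//   SaveOp<CPUContext> op(def, &ws);  // ws must already hold gpu_0/fc1_w
//   op.Run();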

template <typename... Ts>
string FormatString(const string& pattern, Ts... values) {
  // Note(Yangqing): We believe that 1024 is enough, but who are we to assert
  // that? As a result, if things go wrong, we'll just throw in the towel and
  // quit loudly.
  // Yeah, I know that there is snprintf, but it is not present in *some*
  // platforms unfortunately.
  char buffer[1024];
  int written = sprintf(buffer, pattern.c_str(), values...);
  if (written < 0 || written + 1 > 1024) {
    LOG(FATAL) << "FormatString fails: total bytes written " << written;
  }
  return string(buffer);
  /*
   * The following is the snprintf version that is safe; enable it one day?
  unsigned int required =
      std::snprintf(nullptr, 0, pattern.c_str(), values...) + 1;
  char bytes[required];
  std::snprintf(bytes, required, pattern.c_str(), values...);
  return string(bytes);
  */
}
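
// Example (illustrative): FormatString("checkpoint_at_%d.pb", 100) returns
// "checkpoint_at_100.pb". Note that sprintf writes into the buffer before
// the length check runs, so a pattern expanding past 1024 bytes overflows
// the stack buffer; the commented-out snprintf variant measures the required
// size first and avoids that.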

// CheckpointOp is a wrapper over a SaveOp that basically allows flexible
// naming over iterations.
// The file pattern in the db argument should be a format string that can be
// passed into sprintf with an int argument specifying the current iteration.
// An example: "/path/to/my/checkpoint/checkpoint_at_%d.pb"
template <class Context>
class CheckpointOp final : public Operator<Context> {
 public:
  CheckpointOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        db_pattern_(OperatorBase::GetSingleArgument<string>("db", "")),
        every_(OperatorBase::GetSingleArgument<int>("every", 1)),
        ws_(ws),
        save_op_def_(operator_def) {
    CAFFE_ENFORCE_GT(
        db_pattern_.size(), 0, "Must specify a checkpoint file pattern.");
    CAFFE_ENFORCE_GT(every_, 0, "Checkpoint interval should be positive.");
    if (every_ == 1) {
      // Just issue a warning, but it's totally legal so we don't do anything.
      LOG(WARNING) << "It seems that we are checkpointing every iteration. "
                   << "Is that intended?";
    }
    save_op_def_.set_type("Save");
  }

  bool RunOnDevice() override {
    int64_t iter =
        OperatorBase::Input<TensorCPU>(0).template data<int64_t>()[0];
    if (iter % every_ == 0) {
      GetMutableArgument("db", true, &save_op_def_)
          ->set_s(FormatString(db_pattern_, iter));
      SaveOp<Context> sub_op(save_op_def_, ws_);
      return sub_op.Run();
    } else {
      return true;
    }
  }

 private:
  string db_pattern_;
  int every_;
  Workspace* ws_;
  OperatorDef save_op_def_;
};
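
// Usage sketch (hedged; the first input must be a CPU tensor holding the
// int64_t iteration counter, and the names are invented). The operator's
// whole input list, counter included, is forwarded to the wrapped Save op:
//
//   OperatorDef def;
//   def.set_type("Checkpoint");
//   def.add_input("iter");    // int64_t iteration counter blob
//   def.add_input("fc1_w");   // checkpointed alongside the counter
//   auto* db_arg = def.add_arg();
//   db_arg->set_name("db");
//   db_arg->set_s("checkpoints/checkpoint_at_%d.pb");
//   auto* type_arg = def.add_arg();
//   type_arg->set_name("db_type");
//   type_arg->set_s("leveldb");
//   auto* every_arg = def.add_arg();
//   every_arg->set_name("every");
//   every_arg->set_i(1000);   // save every 1000 iterations
//   CheckpointOp<CPUContext> op(def, &ws);
//   op.Run();  // writes e.g. checkpoints/checkpoint_at_1000.pb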

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_