Caffe2 - C++ API
A deep learning, cross-platform ML framework
db_throughput.cc
1 
17 #include <cstdio>
18 #include <thread>
19 #include <vector>
20 
21 #include "caffe2/core/db.h"
22 #include "caffe2/core/init.h"
23 #include "caffe2/core/timer.h"
24 #include "caffe2/core/logging.h"
25 
26 CAFFE2_DEFINE_string(input_db, "", "The input db.");
27 CAFFE2_DEFINE_string(input_db_type, "", "The input db type.");
28 CAFFE2_DEFINE_int(report_interval, 1000, "The report interval.");
29 CAFFE2_DEFINE_int(repeat, 10, "The number to repeat the throughput test.");
30 CAFFE2_DEFINE_bool(use_reader, false, "If true, use the reader interface.");
31 CAFFE2_DEFINE_int(num_read_threads, 1,
32  "The number of concurrent reading threads.");
33 
34 using caffe2::db::Cursor;
35 using caffe2::db::DB;
37 using caffe2::string;
38 
39 void TestThroughputWithDB() {
40  std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
41  caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db, caffe2::db::READ));
42  std::unique_ptr<Cursor> cursor(in_db->NewCursor());
43  for (int iter_id = 0; iter_id < caffe2::FLAGS_repeat; ++iter_id) {
44  caffe2::Timer timer;
45  for (int i = 0; i < caffe2::FLAGS_report_interval; ++i) {
46  string key = cursor->key();
47  string value = cursor->value();
48  //VLOG(1) << "Key " << key;
49  cursor->Next();
50  if (!cursor->Valid()) {
51  cursor->SeekToFirst();
52  }
53  }
54  double elapsed_seconds = timer.Seconds();
55  printf("Iteration %03d, took %4.5f seconds, throughput %f items/sec.\n",
56  iter_id, elapsed_seconds,
57  caffe2::FLAGS_report_interval / elapsed_seconds);
58  }
59 }
60 
61 void TestThroughputWithReaderWorker(const DBReader* reader, int thread_id) {
62  string key, value;
63  for (int iter_id = 0; iter_id < caffe2::FLAGS_repeat; ++iter_id) {
64  caffe2::Timer timer;
65  for (int i = 0; i < caffe2::FLAGS_report_interval; ++i) {
66  reader->Read(&key, &value);
67  }
68  double elapsed_seconds = timer.Seconds();
69  printf("Thread %03d iteration %03d, took %4.5f seconds, "
70  "throughput %f items/sec.\n",
71  thread_id, iter_id, elapsed_seconds,
72  caffe2::FLAGS_report_interval / elapsed_seconds);
73  }
74 }
75 
76 void TestThroughputWithReader() {
77  caffe2::db::DBReader reader(
78  caffe2::FLAGS_input_db_type, caffe2::FLAGS_input_db);
79  std::vector<std::unique_ptr<std::thread>> reading_threads(
80  caffe2::FLAGS_num_read_threads);
81  for (int i = 0; i < reading_threads.size(); ++i) {
82  reading_threads[i].reset(new std::thread(
83  TestThroughputWithReaderWorker, &reader, i));
84  }
85  for (int i = 0; i < reading_threads.size(); ++i) {
86  reading_threads[i]->join();
87  }
88 }
89 
90 int main(int argc, char** argv) {
91  caffe2::GlobalInit(&argc, &argv);
92  if (caffe2::FLAGS_use_reader) {
93  TestThroughputWithReader();
94  } else {
95  TestThroughputWithDB();
96  }
97  return 0;
98 }
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:18
void Read(string *key, string *value) const
Reads one key-value pair from the db and advances to the next entry.
Definition: db.h:222
An abstract cursor class for iterating over a database while reading.
Definition: db.h:22
A reader wrapper around a DB that can also be serialized.
Definition: db.h:144
float Seconds()
Returns the elapsed time in seconds.
Definition: timer.h:40
An abstract class for accessing a database of key-value pairs.
Definition: db.h:80
A simple timer object for measuring time.
Definition: timer.h:16