Caffe2 - C++ API
A deep learning, cross platform ML framework
make_cifar_db.cc
1 
17 //
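// This script converts the CIFAR-10 dataset (or CIFAR-100, with --is_cifar100)
// from its binary distribution format into a Caffe2 training db and test db
// (leveldb by default; see --db) for use in classification.
// Usage:
//   make_cifar_db --input_folder=<dir> --output_train_db_name=<path>
//       --output_test_db_name=<path> [--db=<type>] [--is_cifar100]
// The CIFAR dataset can be downloaded from
// http://www.cs.toronto.edu/~kriz/cifar.html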
24 
25 #include <array>
26 #include <fstream> // NOLINT(readability/streams)
27 #include <sstream>
28 #include <string>
29 
30 #include "caffe2/core/common.h"
31 #include "caffe2/core/db.h"
32 #include "caffe2/core/init.h"
33 #include "caffe2/proto/caffe2.pb.h"
34 #include "caffe2/core/logging.h"
35 
36 CAFFE2_DEFINE_string(input_folder, "", "The input folder name.");
37 CAFFE2_DEFINE_string(output_train_db_name,
38  "", "The output training db name.");
39 CAFFE2_DEFINE_string(output_test_db_name,
40  "", "The output testing db name.");
41 CAFFE2_DEFINE_string(db, "leveldb", "The db type.");
42 CAFFE2_DEFINE_bool(is_cifar100, false,
43  "If set, convert cifar100. Otherwise do cifar10.");
44 
45 namespace caffe2 {
46 
47 using std::stringstream;
48 
49 const int kCIFARSize = 32;
50 const int kCIFARImageNBytes = kCIFARSize * kCIFARSize * 3;
51 const int kCIFAR10BatchSize = 10000;
52 const int kCIFAR10TestDataSize = 10000;
53 const int kCIFAR10TrainBatches = 5;
54 
55 const int kCIFAR100TrainDataSize = 50000;
56 const int kCIFAR100TestDataSize = 10000;
57 
58 void ReadImage(std::ifstream* file, int* label, char* buffer) {
59  char label_char;
60  if (caffe2::FLAGS_is_cifar100) {
61  // Skip the coarse label.
62  file->read(&label_char, 1);
63  }
64  file->read(&label_char, 1);
65  *label = label_char;
66  // Yes, there are better ways to do it, like in-place swap... but I am too
67  // lazy so let's just write it in a memory-wasteful way.
68  std::array<char, kCIFARImageNBytes> channel_first_storage;
69  file->read(channel_first_storage.data(), kCIFARImageNBytes);
70  for (int c = 0; c < 3; ++c) {
71  for (int i = 0; i < kCIFARSize * kCIFARSize; ++i) {
72  buffer[i * 3 + c] =
73  channel_first_storage[c * kCIFARSize * kCIFARSize + i];
74  }
75  }
76  return;
77 }
78 
79 void WriteToDB(const string& filename, const int num_items,
80  const int& offset, db::DB* db) {
81  TensorProtos protos;
82  TensorProto* data = protos.add_protos();
83  TensorProto* label = protos.add_protos();
84  data->set_data_type(TensorProto::BYTE);
85  data->add_dims(kCIFARSize);
86  data->add_dims(kCIFARSize);
87  data->add_dims(3);
88  label->set_data_type(TensorProto::INT32);
89  label->add_dims(1);
90  label->add_int32_data(0);
91 
92  LOG(INFO) << "Converting file " << filename;
93  std::ifstream data_file(filename.c_str(),
94  std::ios::in | std::ios::binary);
95  CAFFE_ENFORCE(data_file, "Unable to open file ", filename);
96  char str_buffer[kCIFARImageNBytes];
97  int label_value;
98  string serialized_protos;
99  std::unique_ptr<db::Transaction> transaction(db->NewTransaction());
100  for (int itemid = 0; itemid < num_items; ++itemid) {
101  ReadImage(&data_file, &label_value, str_buffer);
102  data->set_byte_data(str_buffer, kCIFARImageNBytes);
103  label->set_int32_data(0, label_value);
104  protos.SerializeToString(&serialized_protos);
105  snprintf(str_buffer, kCIFARImageNBytes, "%05d",
106  offset + itemid);
107  transaction->Put(string(str_buffer), serialized_protos);
108  }
109 }
110 
111 void ConvertCIFAR() {
112  std::unique_ptr<db::DB> train_db(
113  db::CreateDB(caffe2::FLAGS_db, caffe2::FLAGS_output_train_db_name,
114  db::NEW));
115  std::unique_ptr<db::DB> test_db(
116  db::CreateDB(caffe2::FLAGS_db, caffe2::FLAGS_output_test_db_name,
117  db::NEW));
118 
119  if (!caffe2::FLAGS_is_cifar100) {
120  // This is cifar 10.
121  for (int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) {
122  stringstream train_file;
123  train_file << caffe2::FLAGS_input_folder << "/data_batch_" << fileid + 1
124  << ".bin";
125  WriteToDB(train_file.str(), kCIFAR10BatchSize,
126  fileid * kCIFAR10BatchSize, train_db.get());
127  }
128  stringstream test_file;
129  test_file << caffe2::FLAGS_input_folder << "/test_batch.bin";
130  WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get());
131  } else {
132  // This is cifar 100.
133  stringstream train_file;
134  train_file << caffe2::FLAGS_input_folder << "/train.bin";
135  WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get());
136  stringstream test_file;
137  test_file << caffe2::FLAGS_input_folder << "/test.bin";
138  WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get());
139  }
140 }
141 
142 } // namespace caffe2
143 
144 int main(int argc, char** argv) {
145  caffe2::GlobalInit(&argc, &argv);
146  caffe2::ConvertCIFAR();
147  return 0;
148 }
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
Definition: init.cc:18
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...