// Caffe2 project headers, followed by the command-line flag definitions.
// Each flag is later read through a caffe2::FLAGS_<name> variable.
// Macro arguments are (name, second argument, help text); the second
// argument is presumably the default value -- confirm against the macro.
//
// image_file: name of the input image file (second argument: empty string).
24 #include "caffe2/core/common.h" 25 #include "caffe2/core/db.h" 26 #include "caffe2/core/init.h" 27 #include "caffe2/proto/caffe2.pb.h" 28 #include "caffe2/core/logging.h" 30 CAFFE2_DEFINE_string(image_file,
"",
"The input image file name.");
// label_file: name of the label file (second argument: empty string).
31 CAFFE2_DEFINE_string(label_file,
"",
"The label file name.");
// output_file: name of the output db; passed to convert_dataset as db_path.
32 CAFFE2_DEFINE_string(output_file,
"",
"The output db name.");
// db: db type handed to db::CreateDB (second argument: "leveldb").
33 CAFFE2_DEFINE_string(db,
"leveldb",
"The db type.");
// data_limit: stop after this many items. convert_dataset only applies the
// limit when the value is > 0, so the -1 given here means "no limit".
34 CAFFE2_DEFINE_int(data_limit, -1,
35 "If set, only output this number of data points.");
// channel_first: boolean flag (second argument: false), read as
// caffe2::FLAGS_channel_first when the data tensor's dims are set up.
36 CAFFE2_DEFINE_bool(channel_first,
false,
// NOTE(review): the help string below ends with a trailing space and no
// closing `);` is visible in this view -- the rest of the text is presumably
// on lines not shown here; confirm.
//
// swap_endian: returns `val` with its four bytes in reverse order
// (bytes A B C D become D C B A). Used to convert the 4-byte header fields
// read from the input files.
37 "If set, write the data as channel-first (CHW order) as the old " 41 uint32_t swap_endian(uint32_t val) {
// Step 1: swap the two bytes inside each 16-bit half (A B C D -> B A D C).
42 val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);
// Step 2: swap the two 16-bit halves (B A D C -> D C B A).
43 return (val << 16) | (val >> 16);
// convert_dataset: reads an image file and a label file, validates their
// headers, and writes one serialized record per item into a db.
//
//   image_filename  path of the image file (magic number 2051 expected)
//   label_filename  path of the label file (magic number 2049 expected)
//   db_path         location handed to db::CreateDB (opened with db::NEW)
//   data_limit      if > 0, logging reports the limit once that many items
//                   have been written; values <= 0 never match
//
// Failed checks go through CAFFE_ENFORCE / CAFFE_ENFORCE_EQ.
// NOTE(review): several local declarations (magic, num_items, num_labels,
// rows, cols, label_value, count, value, protos) are not visible in this
// view; their types are not asserted here.
46 void convert_dataset(
const char* image_filename,
const char* label_filename,
47 const char* db_path,
const int data_limit) {
// Open both inputs as binary streams and fail early if either cannot be opened.
49 std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
50 std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
51 CAFFE_ENFORCE(image_file,
"Unable to open file ", image_filename);
52 CAFFE_ENFORCE(label_file,
"Unable to open file ", label_filename);
// Every header field is read as 4 raw bytes and then byte-swapped, i.e. the
// files store these integers in the opposite byte order from the host
// (presumably big-endian on disk -- confirm).
// NOTE(review): none of the read() calls below is checked for a short read.
60 image_file.read(reinterpret_cast<char*>(&magic), 4);
61 magic = swap_endian(magic);
// 529205256 == 0x1F8B0808; 0x1F 0x8B are the leading bytes of a gzip stream,
// which is why this value is reported as "forgot to unzip".
62 if (magic == 529205256) {
64 "It seems that you forgot to unzip the mnist dataset. You should " 65 "first unzip them using e.g. gunzip on Linux.";
// Expected magic numbers: 2051 for the image file, 2049 for the label file.
67 CAFFE_ENFORCE_EQ(magic, 2051,
"Incorrect image file magic.");
68 label_file.read(reinterpret_cast<char*>(&magic), 4);
69 magic = swap_endian(magic);
70 CAFFE_ENFORCE_EQ(magic, 2049,
"Incorrect label file magic.");
// Item counts from both headers must agree.
71 image_file.read(reinterpret_cast<char*>(&num_items), 4);
72 num_items = swap_endian(num_items);
73 label_file.read(reinterpret_cast<char*>(&num_labels), 4);
74 num_labels = swap_endian(num_labels);
75 CAFFE_ENFORCE_EQ(num_items, num_labels);
// Image dimensions; only the image file carries them.
76 image_file.read(reinterpret_cast<char*>(&rows), 4);
77 rows = swap_endian(rows);
78 image_file.read(reinterpret_cast<char*>(&cols), 4);
79 cols = swap_endian(cols);
// Create the output db (type taken from the `db` flag) and start a transaction.
82 std::unique_ptr<db::DB> mnist_db(db::CreateDB(caffe2::FLAGS_db, db_path, db::NEW));
83 std::unique_ptr<db::Transaction> transaction(mnist_db->NewTransaction());
// Buffer for one image: rows * cols bytes, reused for every item.
86 std::vector<char> pixels(rows * cols);
// Key buffer: "%08d" yields 8 digits plus the terminating NUL for item ids
// up to 99,999,999, which fits in 10 chars; snprintf truncates beyond that.
88 const int kMaxKeyLength = 10;
89 char key_cstr[kMaxKeyLength];
// The record holds two tensors, added in this order: image data, then label.
93 TensorProto* data = protos.add_protos();
94 TensorProto* label = protos.add_protos();
95 data->set_data_type(TensorProto::BYTE);
// NOTE(review): only part of the dims setup is visible here (the lines
// between this `if` and the two add_dims calls are not shown); the exact dim
// order for each value of channel_first cannot be confirmed from this view.
96 if (caffe2::FLAGS_channel_first) {
101 data->add_dims(rows);
102 data->add_dims(cols);
// The label tensor holds a single int32; the placeholder 0 added here is
// overwritten for every item via set_int32_data(0, ...).
105 label->set_data_type(TensorProto::INT32);
106 label->add_int32_data(0);
108 LOG(INFO) <<
"A total of " << num_items <<
" items.";
109 LOG(INFO) <<
"Rows: " << rows <<
" Cols: " << cols;
// Main loop: one iteration per item.
// NOTE(review): item_id is a signed int while num_items comes from a 4-byte
// header field -- if num_items is unsigned this is a mixed-sign comparison;
// confirm against its declaration.
110 for (
int item_id = 0; item_id < num_items; ++item_id) {
// Read one image (rows * cols bytes) and its one-byte label.
111 image_file.read(pixels.data(), rows * cols);
112 label_file.read(&label_value, 1);
// NOTE(review): the set_byte_data call below does not depend on `i`, so each
// iteration appears to store the same rows * cols bytes again; a single call
// looks sufficient -- confirm no other statement belongs to this loop.
113 for (
int i = 0; i < rows * cols; ++i) {
114 data->set_byte_data(pixels.data(), rows * cols);
116 label->set_int32_data(0, static_cast<int>(label_value));
// Key is the zero-padded 8-digit item index; value is the serialized record.
117 snprintf(key_cstr, kMaxKeyLength,
"%08d", item_id);
118 protos.SerializeToString(&value);
119 string keystr(key_cstr);
122 transaction->Put(keystr, value);
// Commit in batches of 1000 items.
123 if (++count % 1000 == 0) {
124 transaction->Commit();
// Report when the optional item limit has been reached.
// NOTE(review): the statement that actually leaves the loop is not visible here.
126 if (data_limit > 0 && count == data_limit) {
127 LOG(INFO) <<
"Reached data limit of " << data_limit <<
", stop.";
// Program entry point: forwards the image_file, label_file, output_file and
// data_limit flag values to caffe2::convert_dataset.
// NOTE(review): only the convert_dataset call is visible in this view; any
// flag parsing / initialisation before it and the return statement after it
// are not shown here.
134 int main(
int argc,
char** argv) {
136 caffe2::convert_dataset(caffe2::FLAGS_image_file.c_str(), caffe2::FLAGS_label_file.c_str(),
137 caffe2::FLAGS_output_file.c_str(), caffe2::FLAGS_data_limit);
bool GlobalInit(int *pargc, char ***pargv)
Initialize the global environment of caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...