Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.cc
1 #include "caffe2/utils/proto_utils.h"
2 
#include <fcntl.h>
#include <cerrno>
#include <fstream>
#include <type_traits>
6 
7 #include <google/protobuf/io/coded_stream.h>
8 
9 #ifndef CAFFE2_USE_LITE_PROTO
10 #include <google/protobuf/text_format.h>
11 #include <google/protobuf/io/zero_copy_stream_impl.h>
12 #else
13 #include <google/protobuf/io/zero_copy_stream_impl_lite.h>
14 #endif // !CAFFE2_USE_LITE_PROTO
15 
16 #include "caffe2/core/logging.h"
17 
18 using ::google::protobuf::MessageLite;
19 
20 namespace caffe {
21 
22 // Caffe wrapper functions for protobuf's GetEmptyStringAlreadyInited() function
23 // used to avoid duplicated global variable in the case when protobuf
24 // is built with hidden visibility.
25 const ::std::string& GetEmptyStringAlreadyInited() {
26  return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
27 }
28 
29 } // namespace caffe
30 
31 namespace caffe2 {
32 
33 // Caffe2 wrapper functions for protobuf's GetEmptyStringAlreadyInited() function
34 // used to avoid duplicated global variable in the case when protobuf
35 // is built with hidden visibility.
36 const ::std::string& GetEmptyStringAlreadyInited() {
37  return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
38 }
39 
40 void ShutdownProtobufLibrary() {
41  ::google::protobuf::ShutdownProtobufLibrary();
42 }
43 
44 std::string DeviceTypeName(const int32_t& d) {
45  switch (d) {
46  case CPU:
47  return "CPU";
48  case CUDA:
49  return "CUDA";
50  case OPENGL:
51  return "OPENGL";
52  case MKLDNN:
53  return "MKLDNN";
54  default:
55  CAFFE_THROW(
56  "Unknown device: ",
57  d,
58  ". If you have recently updated the caffe2.proto file to add a new "
59  "device type, did you forget to update the TensorDeviceTypeName() "
60  "function to reflect such recent changes?");
61  // The below code won't run but is needed to suppress some compiler
62  // warnings.
63  return "";
64  }
65 };
66 
67 bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
68  return (
69  lhs.device_type() == rhs.device_type() &&
70  lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
71  lhs.node_name() == rhs.node_name() &&
72  lhs.numa_node_id() == rhs.numa_node_id());
73 }
74 
75 bool ReadStringFromFile(const char* filename, string* str) {
76  std::ifstream ifs(filename, std::ios::in);
77  if (!ifs) {
78  VLOG(1) << "File cannot be opened: " << filename
79  << " error: " << ifs.rdstate();
80  return false;
81  }
82  ifs.seekg(0, std::ios::end);
83  size_t n = ifs.tellg();
84  str->resize(n);
85  ifs.seekg(0);
86  ifs.read(&(*str)[0], n);
87  return true;
88 }
89 
90 bool WriteStringToFile(const string& str, const char* filename) {
91  std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
92  if (!ofs.is_open()) {
93  VLOG(1) << "File cannot be created: " << filename
94  << " error: " << ofs.rdstate();
95  return false;
96  }
97  ofs << str;
98  return true;
99 }
100 
101 // IO-specific proto functions: we will deal with the protocol buffer lite and
102 // full versions differently.
103 
104 #ifdef CAFFE2_USE_LITE_PROTO
105 
106 // Lite runtime.
107 
108 namespace {
109 class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
110  public:
111  explicit IfstreamInputStream(const string& filename)
112  : ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
113  ~IfstreamInputStream() { ifs_.close(); }
114 
115  int Read(void* buffer, int size) {
116  if (!ifs_) {
117  return -1;
118  }
119  ifs_.read(static_cast<char*>(buffer), size);
120  return ifs_.gcount();
121  }
122 
123  private:
124  std::ifstream ifs_;
125 };
126 } // namespace
127 
128 string ProtoDebugString(const MessageLite& proto) {
129  return proto.SerializeAsString();
130 }
131 
132 bool ParseProtoFromLargeString(const string& str, MessageLite* proto) {
133  ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
134  ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
135  // Set PlanDef message size limit to 1G.
136  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
137  return proto->ParseFromCodedStream(&coded_stream);
138 }
139 
140 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
141  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
142  new IfstreamInputStream(filename));
143  stream.SetOwnsCopyingStream(true);
144  // Total bytes hard limit / warning limit are set to 1GB and 512MB
145  // respectively.
146  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
147  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
148  return proto->ParseFromCodedStream(&coded_stream);
149 }
150 
151 void WriteProtoToBinaryFile(
152  const MessageLite& /*proto*/,
153  const char* /*filename*/) {
154  LOG(FATAL) << "Not implemented yet.";
155 }
156 
157 #else // CAFFE2_USE_LITE_PROTO
158 
159 // Full protocol buffer.
160 
161 using ::google::protobuf::io::FileInputStream;
162 using ::google::protobuf::io::FileOutputStream;
163 using ::google::protobuf::io::ZeroCopyInputStream;
164 using ::google::protobuf::io::CodedInputStream;
165 using ::google::protobuf::io::ZeroCopyOutputStream;
166 using ::google::protobuf::io::CodedOutputStream;
167 using ::google::protobuf::Message;
168 
169 namespace TextFormat {
170 bool ParseFromString(const string& spec, Message* proto) {
171  return ::google::protobuf::TextFormat::ParseFromString(spec, proto);
172 }
173 } // namespace TextFormat
174 
175 string ProtoDebugString(const Message& proto) {
176  return proto.ShortDebugString();
177 }
178 
179 bool ParseProtoFromLargeString(const string& str, Message* proto) {
180  ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
181  ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
182  // Set PlanDef message size limit to 1G.
183  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
184  return proto->ParseFromCodedStream(&coded_stream);
185 }
186 
187 bool ReadProtoFromTextFile(const char* filename, Message* proto) {
188  int fd = open(filename, O_RDONLY);
189  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
190  FileInputStream* input = new FileInputStream(fd);
191  bool success = google::protobuf::TextFormat::Parse(input, proto);
192  delete input;
193  close(fd);
194  return success;
195 }
196 
197 void WriteProtoToTextFile(const Message& proto, const char* filename) {
198  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
199  FileOutputStream* output = new FileOutputStream(fd);
200  CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
201  delete output;
202  close(fd);
203 }
204 
205 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) {
206 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified
207  int fd = open(filename, O_RDONLY | O_BINARY);
208 #else
209  int fd = open(filename, O_RDONLY);
210 #endif
211  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
212  std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
213  std::unique_ptr<CodedInputStream> coded_input(
214  new CodedInputStream(raw_input.get()));
215  // A hack to manually allow using very large protocol buffers.
216  coded_input->SetTotalBytesLimit(1073741824, 536870912);
217  bool success = proto->ParseFromCodedStream(coded_input.get());
218  coded_input.reset();
219  raw_input.reset();
220  close(fd);
221  return success;
222 }
223 
224 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) {
225  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
226  CAFFE_ENFORCE_NE(
227  fd, -1, "File cannot be created: ", filename, " error number: ", errno);
228  std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
229  std::unique_ptr<CodedOutputStream> coded_output(
230  new CodedOutputStream(raw_output.get()));
231  CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
232  coded_output.reset();
233  raw_output.reset();
234  close(fd);
235 }
236 
237 #endif // CAFFE2_USE_LITE_PROTO
238 
239 
240 ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
241  for (auto& arg : def.arg()) {
242  if (arg_map_.count(arg.name())) {
243  if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) {
244  // If there are two arguments of the same name but different contents,
245  // we will throw an error.
246  CAFFE_THROW(
247  "Found argument of the same name ",
248  arg.name(),
249  "but with different contents.",
250  ProtoDebugString(def));
251  } else {
252  LOG(WARNING) << "Duplicated argument name [" << arg.name()
253  << "] found in operator def: "
254  << ProtoDebugString(def);
255  }
256  }
257  arg_map_[arg.name()] = arg;
258  }
259 }
260 
261 ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
262  for (auto& arg : netdef.arg()) {
263  CAFFE_ENFORCE(
264  arg_map_.count(arg.name()) == 0,
265  "Duplicated argument name [", arg.name(), "] found in net def: ",
266  ProtoDebugString(netdef));
267  arg_map_[arg.name()] = arg;
268  }
269 }
270 
271 bool ArgumentHelper::HasArgument(const string& name) const {
272  return arg_map_.count(name);
273 }
274 
namespace {
// Returns whether `value` is negative; selected for signed types.
template <typename T>
bool IsNegativeValue(const T& value, std::true_type /* is_signed */) {
  return value < T(0);
}

// Unsigned and non-arithmetic values are never negative.
template <typename T>
bool IsNegativeValue(const T& /* value */, std::false_type /* is_signed */) {
  return false;
}

// Helper function to verify that conversion between types won't lose any
// significant bit. InputType may be a (const) reference type, as produced by
// decltype() on a range-for variable; the signedness checks look through it.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  using Source = typename std::decay<InputType>::type;
  using Target = typename std::decay<TargetType>::type;
  // A negative value can round-trip unchanged through a same-width unsigned
  // type (int64_t(-1) -> size_t -> int64_t gives -1 again) even though the
  // converted value is wrong, so reject negative -> unsigned explicitly.
  if (std::is_unsigned<Target>::value &&
      IsNegativeValue<Source>(value, std::is_signed<Source>())) {
    return false;
  }
  return static_cast<InputType>(static_cast<TargetType>(value)) == value;
}
} // namespace
283 
284 bool operator==(const NetDef& l, const NetDef& r) {
285  return l.SerializeAsString() == r.SerializeAsString();
286 }
287 
288 std::ostream& operator<<(std::ostream& output, const NetDef& n) {
289  output << n.SerializeAsString();
290  return output;
291 }
292 
293 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \
294  T, fieldname, enforce_lossless_conversion) \
295  template <> \
296  T ArgumentHelper::GetSingleArgument<T>( \
297  const string& name, const T& default_value) const { \
298  if (arg_map_.count(name) == 0) { \
299  VLOG(1) << "Using default parameter value " << default_value \
300  << " for parameter " << name; \
301  return default_value; \
302  } \
303  CAFFE_ENFORCE( \
304  arg_map_.at(name).has_##fieldname(), \
305  "Argument ", \
306  name, \
307  " does not have the right field: expected field " #fieldname); \
308  auto value = arg_map_.at(name).fieldname(); \
309  if (enforce_lossless_conversion) { \
310  auto supportsConversion = \
311  SupportsLosslessConversion<decltype(value), T>(value); \
312  CAFFE_ENFORCE( \
313  supportsConversion, \
314  "Value", \
315  value, \
316  " of argument ", \
317  name, \
318  "cannot be represented correctly in a target type"); \
319  } \
320  return static_cast<T>(value); \
321  } \
322  template <> \
323  bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \
324  if (arg_map_.count(name) == 0) { \
325  return false; \
326  } \
327  return arg_map_.at(name).has_##fieldname(); \
328  }
329 
330 INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
331 INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
332 INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
333 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
334 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
335 INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
336 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
337 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
338 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
339 INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
340 INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
341 INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
342 #undef INSTANTIATE_GET_SINGLE_ARGUMENT
343 
344 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \
345  T, fieldname, enforce_lossless_conversion) \
346  template <> \
347  vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
348  const string& name, const std::vector<T>& default_value) const { \
349  if (arg_map_.count(name) == 0) { \
350  return default_value; \
351  } \
352  vector<T> values; \
353  for (const auto& v : arg_map_.at(name).fieldname()) { \
354  if (enforce_lossless_conversion) { \
355  auto supportsConversion = \
356  SupportsLosslessConversion<decltype(v), T>(v); \
357  CAFFE_ENFORCE( \
358  supportsConversion, \
359  "Value", \
360  v, \
361  " of argument ", \
362  name, \
363  "cannot be represented correctly in a target type"); \
364  } \
365  values.push_back(static_cast<T>(v)); \
366  } \
367  return values; \
368  }
369 
370 INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
371 INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
372 INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
373 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
374 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
375 INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
376 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
377 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
378 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
379 INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
380 INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
381 INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
382 #undef INSTANTIATE_GET_REPEATED_ARGUMENT
383 
384 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
385 template <> \
386 Argument MakeArgument(const string& name, const T& value) { \
387  Argument arg; \
388  arg.set_name(name); \
389  arg.set_##fieldname(value); \
390  return arg; \
391 }
392 
393 CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i)
394 CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f)
395 CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i)
396 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
397 CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
398 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT
399 
400 template <>
401 Argument MakeArgument(const string& name, const MessageLite& value) {
402  Argument arg;
403  arg.set_name(name);
404  arg.set_s(value.SerializeAsString());
405  return arg;
406 }
407 
408 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \
409 template <> \
410 Argument MakeArgument(const string& name, const vector<T>& value) { \
411  Argument arg; \
412  arg.set_name(name); \
413  for (const auto& v : value) { \
414  arg.add_##fieldname(v); \
415  } \
416  return arg; \
417 }
418 
419 CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats)
420 CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints)
421 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
422 CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings)
423 #undef CAFFE2_MAKE_REPEATED_ARGUMENT
424 
425 bool HasOutput(const OperatorDef& op, const std::string& output) {
426  for (const auto& outp : op.output()) {
427  if (outp == output) {
428  return true;
429  }
430  }
431  return false;
432 }
433 
434 bool HasInput(const OperatorDef& op, const std::string& input) {
435  for (const auto& inp : op.input()) {
436  if (inp == input) {
437  return true;
438  }
439  }
440  return false;
441 }
442 
443 const Argument& GetArgument(const OperatorDef& def, const string& name) {
444  for (const Argument& arg : def.arg()) {
445  if (arg.name() == name) {
446  return arg;
447  }
448  }
449  CAFFE_THROW(
450  "Argument named ",
451  name,
452  " does not exist in operator ",
453  ProtoDebugString(def));
454 }
455 
456 bool GetFlagArgument(
457  const OperatorDef& def,
458  const string& name,
459  bool def_value) {
460  for (const Argument& arg : def.arg()) {
461  if (arg.name() == name) {
462  CAFFE_ENFORCE(
463  arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
464  return arg.i();
465  }
466  }
467  return def_value;
468 }
469 
470 Argument* GetMutableArgument(
471  const string& name,
472  const bool create_if_missing,
473  OperatorDef* def) {
474  for (int i = 0; i < def->arg_size(); ++i) {
475  if (def->arg(i).name() == name) {
476  return def->mutable_arg(i);
477  }
478  }
479  // If no argument of the right name is found...
480  if (create_if_missing) {
481  Argument* arg = def->add_arg();
482  arg->set_name(name);
483  return arg;
484  } else {
485  return nullptr;
486  }
487 }
488 
489 } // namespace caffe2
Definition: types.h:72
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...