Caffe2 - C++ API
A deep learning, cross platform ML framework
proto_utils.h
1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_
2 #define CAFFE2_UTILS_PROTO_UTILS_H_
3 
4 #ifdef CAFFE2_USE_LITE_PROTO
5 #include <google/protobuf/message_lite.h>
6 #else // CAFFE2_USE_LITE_PROTO
7 #include <google/protobuf/message.h>
8 #endif // !CAFFE2_USE_LITE_PROTO
9 
10 #include "caffe2/core/logging.h"
11 #include "caffe2/proto/caffe2.pb.h"
12 
13 namespace caffe2 {
14 
15 using std::string;
16 using ::google::protobuf::MessageLite;
17 
18 // A wrapper function to shut down protobuf library (this is needed in ASAN
19 // testing and valgrind cases to avoid protobuf appearing to "leak" memory).
20 void ShutdownProtobufLibrary();
21 
22 // A wrapper function to return device name string for use in blob serialization
23 // / deserialization. This should have one to one correspondence with
24 // caffe2/proto/caffe2.proto: enum DeviceType.
25 //
26 // Note that we can't use DeviceType_Name, because that is only available in
27 // protobuf-full, and some platforms (like mobile) may want to use
28 // protobuf-lite instead.
29 std::string DeviceTypeName(const int32_t& d);
30 
31 // Returns if the two DeviceOptions are pointing to the same device.
32 bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs);
33 
34 // Common interfaces that reads file contents into a string.
35 bool ReadStringFromFile(const char* filename, string* str);
36 bool WriteStringToFile(const string& str, const char* filename);
37 
38 // Common interfaces that are supported by both lite and full protobuf.
39 bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto);
40 inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) {
41  return ReadProtoFromBinaryFile(filename.c_str(), proto);
42 }
43 
44 void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename);
45 inline void WriteProtoToBinaryFile(const MessageLite& proto,
46  const string& filename) {
47  return WriteProtoToBinaryFile(proto, filename.c_str());
48 }
49 
50 #ifdef CAFFE2_USE_LITE_PROTO
51 
52 namespace TextFormat {
53 inline bool ParseFromString(const string& spec, MessageLite* proto) {
54  LOG(FATAL) << "If you are running lite version, you should not be "
55  << "calling any text-format protobuffers.";
56 }
57 } // namespace TextFormat
58 
59 
60 string ProtoDebugString(const MessageLite& proto);
61 
62 bool ParseProtoFromLargeString(const string& str, MessageLite* proto);
63 
64 // Text format MessageLite wrappers: these functions do nothing but just
65 // allowing things to compile. It will produce a runtime error if you are using
66 // MessageLite but still want text support.
67 inline bool ReadProtoFromTextFile(
68  const char* /*filename*/,
69  MessageLite* /*proto*/) {
70  LOG(FATAL) << "If you are running lite version, you should not be "
71  << "calling any text-format protobuffers.";
72  return false; // Just to suppress compiler warning.
73 }
74 inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) {
75  return ReadProtoFromTextFile(filename.c_str(), proto);
76 }
77 
78 inline void WriteProtoToTextFile(
79  const MessageLite& /*proto*/,
80  const char* /*filename*/) {
81  LOG(FATAL) << "If you are running lite version, you should not be "
82  << "calling any text-format protobuffers.";
83 }
84 inline void WriteProtoToTextFile(const MessageLite& proto,
85  const string& filename) {
86  return WriteProtoToTextFile(proto, filename.c_str());
87 }
88 
89 inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) {
90  return (ReadProtoFromBinaryFile(filename, proto) ||
91  ReadProtoFromTextFile(filename, proto));
92 }
93 
94 inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) {
95  return ReadProtoFromFile(filename.c_str(), proto);
96 }
97 
98 #else // CAFFE2_USE_LITE_PROTO
99 
100 using ::google::protobuf::Message;
101 
102 namespace TextFormat {
103 bool ParseFromString(const string& spec, Message* proto);
104 } // namespace TextFormat
105 
106 string ProtoDebugString(const Message& proto);
107 
108 bool ParseProtoFromLargeString(const string& str, Message* proto);
109 
110 bool ReadProtoFromTextFile(const char* filename, Message* proto);
111 inline bool ReadProtoFromTextFile(const string filename, Message* proto) {
112  return ReadProtoFromTextFile(filename.c_str(), proto);
113 }
114 
115 void WriteProtoToTextFile(const Message& proto, const char* filename);
116 inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
117  return WriteProtoToTextFile(proto, filename.c_str());
118 }
119 
120 // Read Proto from a file, letting the code figure out if it is text or binary.
121 inline bool ReadProtoFromFile(const char* filename, Message* proto) {
122  return (ReadProtoFromBinaryFile(filename, proto) ||
123  ReadProtoFromTextFile(filename, proto));
124 }
125 
126 inline bool ReadProtoFromFile(const string& filename, Message* proto) {
127  return ReadProtoFromFile(filename.c_str(), proto);
128 }
129 
130 #endif // CAFFE2_USE_LITE_PROTO
131 
132 template <
133  class IterableInputs = std::initializer_list<string>,
134  class IterableOutputs = std::initializer_list<string>,
135  class IterableArgs = std::initializer_list<Argument>>
136 OperatorDef CreateOperatorDef(
137  const string& type,
138  const string& name,
139  const IterableInputs& inputs,
140  const IterableOutputs& outputs,
141  const IterableArgs& args,
142  const DeviceOption& device_option = DeviceOption(),
143  const string& engine = "") {
144  OperatorDef def;
145  def.set_type(type);
146  def.set_name(name);
147  for (const string& in : inputs) {
148  def.add_input(in);
149  }
150  for (const string& out : outputs) {
151  def.add_output(out);
152  }
153  for (const Argument& arg : args) {
154  def.add_arg()->CopyFrom(arg);
155  }
156  if (device_option.has_device_type()) {
157  def.mutable_device_option()->CopyFrom(device_option);
158  }
159  if (engine.size()) {
160  def.set_engine(engine);
161  }
162  return def;
163 }
164 
165 // A simplified version compared to the full CreateOperator, if you do not need
166 // to specify args.
167 template <
168  class IterableInputs = std::initializer_list<string>,
169  class IterableOutputs = std::initializer_list<string>>
170 inline OperatorDef CreateOperatorDef(
171  const string& type,
172  const string& name,
173  const IterableInputs& inputs,
174  const IterableOutputs& outputs,
175  const DeviceOption& device_option = DeviceOption(),
176  const string& engine = "") {
177  return CreateOperatorDef(
178  type,
179  name,
180  inputs,
181  outputs,
182  std::vector<Argument>(),
183  device_option,
184  engine);
185 }
186 
187 bool HasOutput(const OperatorDef& op, const std::string& output);
188 bool HasInput(const OperatorDef& op, const std::string& input);
189 
199  public:
200  template <typename Def>
201  static bool HasArgument(const Def& def, const string& name) {
202  return ArgumentHelper(def).HasArgument(name);
203  }
204 
205  template <typename Def, typename T>
206  static T GetSingleArgument(
207  const Def& def,
208  const string& name,
209  const T& default_value) {
210  return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
211  }
212 
213  template <typename Def, typename T>
214  static bool HasSingleArgumentOfType(const Def& def, const string& name) {
215  return ArgumentHelper(def).HasSingleArgumentOfType<T>(name);
216  }
217 
218  template <typename Def, typename T>
219  static vector<T> GetRepeatedArgument(
220  const Def& def,
221  const string& name,
222  const std::vector<T>& default_value = std::vector<T>()) {
223  return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
224  }
225 
226  template <typename Def, typename MessageType>
227  static MessageType GetMessageArgument(const Def& def, const string& name) {
228  return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
229  }
230 
231  template <typename Def, typename MessageType>
232  static vector<MessageType> GetRepeatedMessageArgument(
233  const Def& def,
234  const string& name) {
235  return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
236  }
237 
238  explicit ArgumentHelper(const OperatorDef& def);
239  explicit ArgumentHelper(const NetDef& netdef);
240  bool HasArgument(const string& name) const;
241 
242  template <typename T>
243  T GetSingleArgument(const string& name, const T& default_value) const;
244  template <typename T>
245  bool HasSingleArgumentOfType(const string& name) const;
246  template <typename T>
247  vector<T> GetRepeatedArgument(
248  const string& name,
249  const std::vector<T>& default_value = std::vector<T>()) const;
250 
251  template <typename MessageType>
252  MessageType GetMessageArgument(const string& name) const {
253  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
254  MessageType message;
255  if (arg_map_.at(name).has_s()) {
256  CAFFE_ENFORCE(
257  message.ParseFromString(arg_map_.at(name).s()),
258  "Faild to parse content from the string");
259  } else {
260  VLOG(1) << "Return empty message for parameter " << name;
261  }
262  return message;
263  }
264 
265  template <typename MessageType>
266  vector<MessageType> GetRepeatedMessageArgument(const string& name) const {
267  CAFFE_ENFORCE(arg_map_.count(name), "Cannot find parameter named ", name);
268  vector<MessageType> messages(arg_map_.at(name).strings_size());
269  for (int i = 0; i < messages.size(); ++i) {
270  CAFFE_ENFORCE(
271  messages[i].ParseFromString(arg_map_.at(name).strings(i)),
272  "Faild to parse content from the string");
273  }
274  return messages;
275  }
276 
277  private:
278  CaffeMap<string, Argument> arg_map_;
279 };
280 
281 const Argument& GetArgument(const OperatorDef& def, const string& name);
282 bool GetFlagArgument(
283  const OperatorDef& def,
284  const string& name,
285  bool def_value = false);
286 
287 Argument* GetMutableArgument(
288  const string& name,
289  const bool create_if_missing,
290  OperatorDef* def);
291 
292 template <typename T>
293 Argument MakeArgument(const string& name, const T& value);
294 
295 template <typename T>
296 inline void AddArgument(const string& name, const T& value, OperatorDef* def) {
297  GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
298 }
299 
300 bool inline operator==(const DeviceOption& dl, const DeviceOption& dr) {
301  return IsSameDevice(dl, dr);
302 }
303 
304 
305 } // namespace caffe2
306 
307 namespace std {
308 template <>
309 struct hash<caffe2::DeviceOption> {
310  typedef caffe2::DeviceOption argument_type;
311  typedef std::size_t result_type;
312  result_type operator()(argument_type const& device_option) const {
313  std::string serialized;
314  CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
315  return std::hash<std::string>{}(serialized);
316  }
317 };
318 } // namespace std
319 
320 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
Definition: types.h:72
A helper class to index into arguments.
Definition: proto_utils.h:198
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...