1 #ifndef CAFFE2_UTILS_PROTO_UTILS_H_ 2 #define CAFFE2_UTILS_PROTO_UTILS_H_ 4 #ifdef CAFFE2_USE_LITE_PROTO 5 #include <google/protobuf/message_lite.h> 6 #else // CAFFE2_USE_LITE_PROTO 7 #include <google/protobuf/message.h> 8 #endif // !CAFFE2_USE_LITE_PROTO 10 #include "caffe2/core/logging.h" 11 #include "caffe2/proto/caffe2.pb.h" 16 using ::google::protobuf::MessageLite;
20 void ShutdownProtobufLibrary();
29 std::string DeviceTypeName(
const int32_t& d);
32 bool IsSameDevice(
const DeviceOption& lhs,
const DeviceOption& rhs);
35 bool ReadStringFromFile(
const char* filename,
string* str);
36 bool WriteStringToFile(
const string& str,
const char* filename);
39 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto);
40 inline bool ReadProtoFromBinaryFile(
const string filename, MessageLite* proto) {
41 return ReadProtoFromBinaryFile(filename.c_str(), proto);
44 void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename);
45 inline void WriteProtoToBinaryFile(
const MessageLite& proto,
46 const string& filename) {
47 return WriteProtoToBinaryFile(proto, filename.c_str());
// Lite-proto build (CAFFE2_USE_LITE_PROTO). The text-format helpers in this
// branch only log a FATAL error, per the messages below.
50 #ifdef CAFFE2_USE_LITE_PROTO 52 namespace TextFormat {
// Stub for text-format parsing: it logs a FATAL error telling the caller
// that text-format protobufs must not be used in the lite build.
// NOTE(review): the end of this function (any return statement, the closing
// brace, and the close of namespace TextFormat) is not visible in this view.
53 inline bool ParseFromString(
const string& spec, MessageLite* proto) {
54 LOG(FATAL) <<
"If you are running lite version, you should not be " 55 <<
"calling any text-format protobuffers.";
60 string ProtoDebugString(
const MessageLite& proto);
62 bool ParseProtoFromLargeString(
const string& str, MessageLite* proto);
// Lite-build stub of ReadProtoFromTextFile (C-string variant). It logs a
// FATAL error because text-format protobufs are unsupported in this build.
// NOTE(review): the parameter list and the end of this function are not
// visible in this view. The std::string overload calls it as
// (filename.c_str(), proto), so presumably (const char*, MessageLite*).
67 inline bool ReadProtoFromTextFile(
70 LOG(FATAL) <<
"If you are running lite version, you should not be " 71 <<
"calling any text-format protobuffers.";
74 inline bool ReadProtoFromTextFile(
const string filename, MessageLite* proto) {
75 return ReadProtoFromTextFile(filename.c_str(), proto);
// Lite-build stub of WriteProtoToTextFile (C-string variant). It logs a
// FATAL error because text-format protobufs are unsupported in this build.
// NOTE(review): the parameter list and closing brace are not visible in this
// view. The std::string overload calls it as (proto, filename.c_str()).
78 inline void WriteProtoToTextFile(
81 LOG(FATAL) <<
"If you are running lite version, you should not be " 82 <<
"calling any text-format protobuffers.";
84 inline void WriteProtoToTextFile(
const MessageLite& proto,
85 const string& filename) {
86 return WriteProtoToTextFile(proto, filename.c_str());
89 inline bool ReadProtoFromFile(
const char* filename, MessageLite* proto) {
90 return (ReadProtoFromBinaryFile(filename, proto) ||
91 ReadProtoFromTextFile(filename, proto));
94 inline bool ReadProtoFromFile(
const string& filename, MessageLite* proto) {
95 return ReadProtoFromFile(filename.c_str(), proto);
// Full-proto build: the declarations below use google::protobuf::Message.
98 #else // CAFFE2_USE_LITE_PROTO 100 using ::google::protobuf::Message;
102 namespace TextFormat {
// Parses text-format `spec` into `proto`; declared here, defined out of line.
// NOTE(review): the closing brace of namespace TextFormat is not visible in
// this view.
103 bool ParseFromString(
const string& spec, Message* proto);
106 string ProtoDebugString(
const Message& proto);
108 bool ParseProtoFromLargeString(
const string& str, Message* proto);
110 bool ReadProtoFromTextFile(
const char* filename, Message* proto);
111 inline bool ReadProtoFromTextFile(
const string filename, Message* proto) {
112 return ReadProtoFromTextFile(filename.c_str(), proto);
115 void WriteProtoToTextFile(
const Message& proto,
const char* filename);
116 inline void WriteProtoToTextFile(
const Message& proto,
const string& filename) {
117 return WriteProtoToTextFile(proto, filename.c_str());
121 inline bool ReadProtoFromFile(
const char* filename, Message* proto) {
122 return (ReadProtoFromBinaryFile(filename, proto) ||
123 ReadProtoFromTextFile(filename, proto));
126 inline bool ReadProtoFromFile(
const string& filename, Message* proto) {
127 return ReadProtoFromFile(filename.c_str(), proto);
// End of the lite/full-proto split, followed by CreateOperatorDef.
// CreateOperatorDef builds an OperatorDef from input names, output names and
// Argument messages. The container types default to std::initializer_list so
// brace-lists can be passed directly.
// NOTE(review): the `template <` opener, the leading parameters (presumably
// the op type and name), the declaration of `def`, the input/output loop
// bodies and the return are not visible in this view.
130 #endif // CAFFE2_USE_LITE_PROTO 133 class IterableInputs = std::initializer_list<string>,
134 class IterableOutputs = std::initializer_list<string>,
135 class IterableArgs = std::initializer_list<Argument>>
136 OperatorDef CreateOperatorDef(
139 const IterableInputs& inputs,
140 const IterableOutputs& outputs,
141 const IterableArgs& args,
// Defaults: a default-constructed DeviceOption and an empty engine name.
142 const DeviceOption& device_option = DeviceOption(),
143 const string& engine =
"") {
147 for (
const string& in : inputs) {
150 for (
const string& out : outputs) {
153 for (
const Argument& arg : args) {
// Each Argument is copied into a newly added arg slot of the def.
154 def.add_arg()->CopyFrom(arg);
// The device option is copied only when it has a device type set. So a
// default-constructed DeviceOption (presumably without device_type) leaves
// the op's device option untouched.
156 if (device_option.has_device_type()) {
157 def.mutable_device_option()->CopyFrom(device_option);
// The engine is set unconditionally, even when it is the empty default.
160 def.set_engine(engine);
// CreateOperatorDef overload without an argument list. It forwards to the
// overload taking `args`, passing an empty std::vector<Argument>.
// NOTE(review): the `template <` opener, the leading parameters and the
// remaining forwarded arguments are not visible in this view.
168 class IterableInputs = std::initializer_list<string>,
169 class IterableOutputs = std::initializer_list<string>>
170 inline OperatorDef CreateOperatorDef(
173 const IterableInputs& inputs,
174 const IterableOutputs& outputs,
175 const DeviceOption& device_option = DeviceOption(),
176 const string& engine =
"") {
177 return CreateOperatorDef(
182 std::vector<Argument>(),
187 bool HasOutput(
const OperatorDef& op,
const std::string& output);
188 bool HasInput(
const OperatorDef& op,
const std::string& input);
// Static wrapper: reports whether `def` carries an argument called `name`.
// NOTE(review): the body is not visible in this view; the sibling wrappers
// forward through a temporary ArgumentHelper(def).
200 template <
typename Def>
201 static bool HasArgument(
const Def& def,
const string& name) {
// Static wrapper: builds a temporary ArgumentHelper from `def` and forwards
// to its GetSingleArgument<T>(name, default_value).
// NOTE(review): the `def` and `name` parameter lines are not visible here.
205 template <
typename Def,
typename T>
206 static T GetSingleArgument(
209 const T& default_value) {
210 return ArgumentHelper(def).GetSingleArgument<T>(name, default_value);
// Static wrapper: reports whether argument `name` of `def` holds a single
// value of type T.
// NOTE(review): the body is not visible in this view.
213 template <
typename Def,
typename T>
214 static bool HasSingleArgumentOfType(
const Def& def,
const string& name) {
// Static wrapper: builds a temporary ArgumentHelper from `def` and forwards
// to its GetRepeatedArgument<T>(name, default_value). The default is an
// empty vector.
// NOTE(review): the `def` and `name` parameter lines are not visible here.
218 template <
typename Def,
typename T>
219 static vector<T> GetRepeatedArgument(
222 const std::vector<T>& default_value = std::vector<T>()) {
223 return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
// Static wrapper: returns argument `name` of `def` parsed as a MessageType.
// NOTE(review): the body is not visible in this view.
226 template <
typename Def,
typename MessageType>
227 static MessageType GetMessageArgument(
const Def& def,
const string& name) {
// Static wrapper: builds a temporary ArgumentHelper from `def` and forwards
// to its GetRepeatedMessageArgument<MessageType>(name).
// NOTE(review): the `def` parameter line is not visible in this view.
231 template <
typename Def,
typename MessageType>
232 static vector<MessageType> GetRepeatedMessageArgument(
234 const string& name) {
235 return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
239 explicit ArgumentHelper(
const NetDef& netdef);
240 bool HasArgument(
const string& name)
const;
242 template <
typename T>
243 T GetSingleArgument(
const string& name,
const T& default_value)
const;
244 template <
typename T>
245 bool HasSingleArgumentOfType(
const string& name)
const;
// Returns a repeated argument as vector<T>. `default_value` (empty by
// default) is presumably returned when the argument is absent — confirm.
// NOTE(review): the `name` parameter line is not visible in this view.
246 template <
typename T>
247 vector<T> GetRepeatedArgument(
249 const std::vector<T>& default_value = std::vector<T>())
const;
// Returns argument `name` parsed as a MessageType.
// It enforces that the argument exists. When the argument has a string
// payload (`s`), that payload is parsed into the message. Otherwise (presumably
// the else branch) a VLOG(1) notes that an empty message is returned.
// NOTE(review): the `message` declaration, the enforce wrapper around
// ParseFromString, the branch braces and the return are not visible here.
// NOTE(review): "Faild" in the error text is a typo for "Failed". It is a
// runtime string, so fix it in a code change, not a comment change.
// NOTE(review): arg_map_ is looked up up to three times (count, at, at);
// a single find() would suffice.
251 template <
typename MessageType>
252 MessageType GetMessageArgument(
const string& name)
const {
253 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
255 if (arg_map_.at(name).has_s()) {
257 message.ParseFromString(arg_map_.at(name).s()),
258 "Faild to parse content from the string");
260 VLOG(1) <<
"Return empty message for parameter " << name;
// Returns each entry of the repeated `strings` field of argument `name`,
// parsed as a MessageType. The result has one element per entry.
// It enforces that the argument exists.
// NOTE(review): `int i` is compared against the unsigned messages.size(),
// and arg_map_.at(name) is re-evaluated on every iteration. Binding a const
// reference to the Argument once would avoid the repeated lookups.
// NOTE(review): "Faild" typo in the runtime message. The enforce wrapper,
// the loop close and the return are not visible in this view.
265 template <
typename MessageType>
266 vector<MessageType> GetRepeatedMessageArgument(
const string& name)
const {
267 CAFFE_ENFORCE(arg_map_.count(name),
"Cannot find parameter named ", name);
268 vector<MessageType> messages(arg_map_.at(name).strings_size());
269 for (
int i = 0; i < messages.size(); ++i) {
271 messages[i].ParseFromString(arg_map_.at(name).strings(i)),
272 "Faild to parse content from the string");
// Arguments keyed by name; the accessors look entries up with count()/at().
278 CaffeMap<string, Argument> arg_map_;
281 const Argument& GetArgument(
const OperatorDef& def,
const string& name);
// Reads a boolean flag argument from `def`. `def_value` (false by default)
// is presumably returned when the argument is absent — confirm in the .cpp.
// NOTE(review): the `name` parameter line is not visible in this view.
282 bool GetFlagArgument(
283 const OperatorDef& def,
285 bool def_value =
false);
// Returns a mutable pointer to a named argument (AddArgument passes an
// OperatorDef*). With `create_if_missing` the argument is presumably added
// when absent.
// NOTE(review): the remaining parameters and the end of this declaration
// are not visible in this view.
287 Argument* GetMutableArgument(
289 const bool create_if_missing,
292 template <
typename T>
293 Argument MakeArgument(
const string& name,
const T& value);
295 template <
typename T>
296 inline void AddArgument(
const string& name,
const T& value, OperatorDef* def) {
297 GetMutableArgument(name,
true, def)->CopyFrom(MakeArgument(name, value));
300 bool inline operator==(
const DeviceOption& dl,
const DeviceOption& dr) {
301 return IsSameDevice(dl, dr);
// Hash functor members for caffe2::DeviceOption.
// NOTE(review): the enclosing namespace/struct lines and the closing braces
// are not visible in this view.
310 typedef caffe2::DeviceOption argument_type;
311 typedef std::size_t result_type;
// Hashes the protobuf-serialized bytes of `device_option`, enforcing that
// serialization succeeds.
// NOTE(review): hashing requires equal values to hash equally. operator==
// delegates to IsSameDevice(). If that compares only a subset of fields, two
// "equal" options could serialize (and therefore hash) differently — confirm.
312 result_type operator()(argument_type
const& device_option)
const {
313 std::string serialized;
314 CAFFE_ENFORCE(device_option.SerializeToString(&serialized));
315 return std::hash<std::string>{}(serialized);
320 #endif // CAFFE2_UTILS_PROTO_UTILS_H_
A helper class to index into arguments.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...