#include "caffe2/utils/proto_utils.h"

#include <type_traits>

#include <google/protobuf/io/coded_stream.h>

#ifndef CAFFE2_USE_LITE_PROTO
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
// NOTE(review): the line between these two includes is missing from this view;
// if it was `#else`, the lite header belongs to the CAFFE2_USE_LITE_PROTO branch.
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#endif // !CAFFE2_USE_LITE_PROTO

#include "caffe2/core/logging.h"

using ::google::protobuf::MessageLite;
25 const ::std::string& GetEmptyStringAlreadyInited() {
26 return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
36 const ::std::string& GetEmptyStringAlreadyInited() {
37 return ::google::protobuf::internal::GetEmptyStringAlreadyInited();
40 void ShutdownProtobufLibrary() {
41 ::google::protobuf::ShutdownProtobufLibrary();
44 std::string DeviceTypeName(
const int32_t& d) {
58 ". If you have recently updated the caffe2.proto file to add a new " 59 "device type, did you forget to update the TensorDeviceTypeName() " 60 "function to reflect such recent changes?");
67 bool IsSameDevice(
const DeviceOption& lhs,
const DeviceOption& rhs) {
69 lhs.device_type() == rhs.device_type() &&
70 lhs.cuda_gpu_id() == rhs.cuda_gpu_id() &&
71 lhs.node_name() == rhs.node_name() &&
72 lhs.numa_node_id() == rhs.numa_node_id());
75 bool ReadStringFromFile(
const char* filename,
string* str) {
76 std::ifstream ifs(filename, std::ios::in);
78 VLOG(1) <<
"File cannot be opened: " << filename
79 <<
" error: " << ifs.rdstate();
82 ifs.seekg(0, std::ios::end);
83 size_t n = ifs.tellg();
86 ifs.read(&(*str)[0], n);
90 bool WriteStringToFile(
const string& str,
const char* filename) {
91 std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
93 VLOG(1) <<
"File cannot be created: " << filename
94 <<
" error: " << ofs.rdstate();
104 #ifdef CAFFE2_USE_LITE_PROTO 109 class IfstreamInputStream :
public ::google::protobuf::io::CopyingInputStream {
111 explicit IfstreamInputStream(
const string& filename)
112 : ifs_(filename.c_str(),
std::ios::in |
std::ios::binary) {}
113 ~IfstreamInputStream() { ifs_.close(); }
115 int Read(
void* buffer,
int size) {
119 ifs_.read(static_cast<char*>(buffer), size);
120 return ifs_.gcount();
128 string ProtoDebugString(
const MessageLite& proto) {
129 return proto.SerializeAsString();
132 bool ParseProtoFromLargeString(
const string& str, MessageLite* proto) {
133 ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
134 ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
136 coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
137 return proto->ParseFromCodedStream(&coded_stream);
140 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto) {
141 ::google::protobuf::io::CopyingInputStreamAdaptor stream(
142 new IfstreamInputStream(filename));
143 stream.SetOwnsCopyingStream(
true);
146 ::google::protobuf::io::CodedInputStream coded_stream(&stream);
147 coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
148 return proto->ParseFromCodedStream(&coded_stream);
151 void WriteProtoToBinaryFile(
154 LOG(FATAL) <<
"Not implemented yet.";
157 #else // CAFFE2_USE_LITE_PROTO 161 using ::google::protobuf::io::FileInputStream;
162 using ::google::protobuf::io::FileOutputStream;
163 using ::google::protobuf::io::ZeroCopyInputStream;
164 using ::google::protobuf::io::CodedInputStream;
165 using ::google::protobuf::io::ZeroCopyOutputStream;
166 using ::google::protobuf::io::CodedOutputStream;
167 using ::google::protobuf::Message;
169 namespace TextFormat {
170 bool ParseFromString(
const string& spec, Message* proto) {
171 return ::google::protobuf::TextFormat::ParseFromString(spec, proto);
175 string ProtoDebugString(
const Message& proto) {
176 return proto.ShortDebugString();
179 bool ParseProtoFromLargeString(
const string& str, Message* proto) {
180 ::google::protobuf::io::ArrayInputStream input_stream(str.data(), str.size());
181 ::google::protobuf::io::CodedInputStream coded_stream(&input_stream);
183 coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
184 return proto->ParseFromCodedStream(&coded_stream);
187 bool ReadProtoFromTextFile(
const char* filename, Message* proto) {
188 int fd = open(filename, O_RDONLY);
189 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
190 FileInputStream* input =
new FileInputStream(fd);
191 bool success = google::protobuf::TextFormat::Parse(input, proto);
197 void WriteProtoToTextFile(
const Message& proto,
const char* filename) {
198 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
199 FileOutputStream* output =
new FileOutputStream(fd);
200 CAFFE_ENFORCE(google::protobuf::TextFormat::Print(proto, output));
205 bool ReadProtoFromBinaryFile(
const char* filename, MessageLite* proto) {
206 #if defined (_MSC_VER) // for MSC compiler binary flag needs to be specified 207 int fd = open(filename, O_RDONLY | O_BINARY);
209 int fd = open(filename, O_RDONLY);
211 CAFFE_ENFORCE_NE(fd, -1,
"File not found: ", filename);
212 std::unique_ptr<ZeroCopyInputStream> raw_input(
new FileInputStream(fd));
213 std::unique_ptr<CodedInputStream> coded_input(
214 new CodedInputStream(raw_input.get()));
216 coded_input->SetTotalBytesLimit(1073741824, 536870912);
217 bool success = proto->ParseFromCodedStream(coded_input.get());
224 void WriteProtoToBinaryFile(
const MessageLite& proto,
const char* filename) {
225 int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
227 fd, -1,
"File cannot be created: ", filename,
" error number: ", errno);
228 std::unique_ptr<ZeroCopyOutputStream> raw_output(
new FileOutputStream(fd));
229 std::unique_ptr<CodedOutputStream> coded_output(
230 new CodedOutputStream(raw_output.get()));
231 CAFFE_ENFORCE(proto.SerializeToCodedStream(coded_output.get()));
232 coded_output.reset();
237 #endif // CAFFE2_USE_LITE_PROTO 240 ArgumentHelper::ArgumentHelper(
const OperatorDef& def) {
241 for (
auto& arg : def.arg()) {
242 if (arg_map_.count(arg.name())) {
243 if (arg.SerializeAsString() != arg_map_[arg.name()].SerializeAsString()) {
247 "Found argument of the same name ",
249 "but with different contents.",
250 ProtoDebugString(def));
252 LOG(WARNING) <<
"Duplicated argument name [" << arg.name()
253 <<
"] found in operator def: " 254 << ProtoDebugString(def);
257 arg_map_[arg.name()] = arg;
261 ArgumentHelper::ArgumentHelper(
const NetDef& netdef) {
262 for (
auto& arg : netdef.arg()) {
264 arg_map_.count(arg.name()) == 0,
265 "Duplicated argument name [", arg.name(),
"] found in net def: ",
266 ProtoDebugString(netdef));
267 arg_map_[arg.name()] = arg;
271 bool ArgumentHelper::HasArgument(
const string& name)
const {
272 return arg_map_.count(name);
// Fallback for values without a meaningful sign (strings, protos, bool): a
// conversion can never flip their sign, so report "no flip".
template <typename TargetType, typename InputType>
bool ConversionFlipsSign(const InputType& /* value */, std::false_type) {
  return false;
}

// Arithmetic values: true when casting to TargetType changes whether the value
// is negative. Such casts (e.g. int64_t{-1} -> size_t) can round-trip
// bit-for-bit, so the round-trip comparison alone does not reject them.
template <typename TargetType, typename InputType>
bool ConversionFlipsSign(const InputType& value, std::true_type) {
  const bool was_negative = value < InputType(0);
  const bool is_negative = static_cast<TargetType>(value) < TargetType(0);
  return was_negative != is_negative;
}

// Returns true if `value` can be represented in TargetType without loss:
// converting to TargetType and back must reproduce `value`, and for arithmetic
// (non-bool) types the sign must be preserved as well.
// InputType may be a reference type (callers pass decltype of a range-for
// `const auto&` element); it is decayed before classifying it.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  using Value = typename std::decay<InputType>::type;
  using HasSign = std::integral_constant<
      bool,
      std::is_arithmetic<Value>::value && !std::is_same<Value, bool>::value &&
          std::is_arithmetic<TargetType>::value &&
          !std::is_same<TargetType, bool>::value>;
  return static_cast<InputType>(static_cast<TargetType>(value)) == value &&
      !ConversionFlipsSign<TargetType>(value, HasSign());
}
284 bool operator==(
const NetDef& l,
const NetDef& r) {
285 return l.SerializeAsString() == r.SerializeAsString();
288 std::ostream& operator<<(std::ostream& output,
const NetDef& n) {
289 output << n.SerializeAsString();
293 #define INSTANTIATE_GET_SINGLE_ARGUMENT( \ 294 T, fieldname, enforce_lossless_conversion) \ 296 T ArgumentHelper::GetSingleArgument<T>( \ 297 const string& name, const T& default_value) const { \ 298 if (arg_map_.count(name) == 0) { \ 299 VLOG(1) << "Using default parameter value " << default_value \ 300 << " for parameter " << name; \ 301 return default_value; \ 304 arg_map_.at(name).has_##fieldname(), \ 307 " does not have the right field: expected field " #fieldname); \ 308 auto value = arg_map_.at(name).fieldname(); \ 309 if (enforce_lossless_conversion) { \ 310 auto supportsConversion = \ 311 SupportsLosslessConversion<decltype(value), T>(value); \ 313 supportsConversion, \ 318 "cannot be represented correctly in a target type"); \ 320 return static_cast<T>(value); \ 323 bool ArgumentHelper::HasSingleArgumentOfType<T>(const string& name) const { \ 324 if (arg_map_.count(name) == 0) { \ 327 return arg_map_.at(name).has_##fieldname(); \ 330 INSTANTIATE_GET_SINGLE_ARGUMENT(
float, f,
false)
331 INSTANTIATE_GET_SINGLE_ARGUMENT(
double, f, false)
332 INSTANTIATE_GET_SINGLE_ARGUMENT(
bool, i, false)
333 INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
334 INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
335 INSTANTIATE_GET_SINGLE_ARGUMENT(
int, i, true)
336 INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
337 INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
338 INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
339 INSTANTIATE_GET_SINGLE_ARGUMENT(
size_t, i, true)
340 INSTANTIATE_GET_SINGLE_ARGUMENT(
string, s, false)
341 INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
342 #undef INSTANTIATE_GET_SINGLE_ARGUMENT 344 #define INSTANTIATE_GET_REPEATED_ARGUMENT( \ 345 T, fieldname, enforce_lossless_conversion) \ 347 vector<T> ArgumentHelper::GetRepeatedArgument<T>( \ 348 const string& name, const std::vector<T>& default_value) const { \ 349 if (arg_map_.count(name) == 0) { \ 350 return default_value; \ 353 for (const auto& v : arg_map_.at(name).fieldname()) { \ 354 if (enforce_lossless_conversion) { \ 355 auto supportsConversion = \ 356 SupportsLosslessConversion<decltype(v), T>(v); \ 358 supportsConversion, \ 363 "cannot be represented correctly in a target type"); \ 365 values.push_back(static_cast<T>(v)); \ 370 INSTANTIATE_GET_REPEATED_ARGUMENT(
float, floats,
false)
371 INSTANTIATE_GET_REPEATED_ARGUMENT(
double, floats, false)
372 INSTANTIATE_GET_REPEATED_ARGUMENT(
bool, ints, false)
373 INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
374 INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
375 INSTANTIATE_GET_REPEATED_ARGUMENT(
int, ints, true)
376 INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
377 INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
378 INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
379 INSTANTIATE_GET_REPEATED_ARGUMENT(
size_t, ints, true)
380 INSTANTIATE_GET_REPEATED_ARGUMENT(
string, strings, false)
381 INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
382 #undef INSTANTIATE_GET_REPEATED_ARGUMENT 384 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \ 386 Argument MakeArgument(const string& name, const T& value) { \ 388 arg.set_name(name); \ 389 arg.set_##fieldname(value); \ 393 CAFFE2_MAKE_SINGULAR_ARGUMENT(
bool, i)
394 CAFFE2_MAKE_SINGULAR_ARGUMENT(
float, f)
395 CAFFE2_MAKE_SINGULAR_ARGUMENT(
int, i)
396 CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
397 CAFFE2_MAKE_SINGULAR_ARGUMENT(
string, s)
398 #undef CAFFE2_MAKE_SINGULAR_ARGUMENT 401 Argument MakeArgument(
const string& name,
const MessageLite& value) {
404 arg.set_s(value.SerializeAsString());
408 #define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \ 410 Argument MakeArgument(const string& name, const vector<T>& value) { \ 412 arg.set_name(name); \ 413 for (const auto& v : value) { \ 414 arg.add_##fieldname(v); \ 419 CAFFE2_MAKE_REPEATED_ARGUMENT(
float, floats)
420 CAFFE2_MAKE_REPEATED_ARGUMENT(
int, ints)
421 CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
422 CAFFE2_MAKE_REPEATED_ARGUMENT(
string, strings)
423 #undef CAFFE2_MAKE_REPEATED_ARGUMENT 425 bool HasOutput(
const OperatorDef& op,
const std::string& output) {
426 for (
const auto& outp : op.output()) {
427 if (outp == output) {
434 bool HasInput(
const OperatorDef& op,
const std::string& input) {
435 for (
const auto& inp : op.input()) {
443 const Argument& GetArgument(
const OperatorDef& def,
const string& name) {
444 for (
const Argument& arg : def.arg()) {
445 if (arg.name() == name) {
452 " does not exist in operator ",
453 ProtoDebugString(def));
456 bool GetFlagArgument(
457 const OperatorDef& def,
460 for (
const Argument& arg : def.arg()) {
461 if (arg.name() == name) {
463 arg.has_i(),
"Can't parse argument as bool: ", ProtoDebugString(arg));
470 Argument* GetMutableArgument(
472 const bool create_if_missing,
474 for (
int i = 0; i < def->arg_size(); ++i) {
475 if (def->arg(i).name() == name) {
476 return def->mutable_arg(i);
480 if (create_if_missing) {
481 Argument* arg = def->add_arg();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...