3 #include "caffe2/onnx/backend_rep.h" 4 #include "caffe2/onnx/device.h" 5 #include "caffe2/proto/caffe2.pb.h" 6 #include "onnx/onnx_pb.h" 10 #include <unordered_map> 11 #include <unordered_set> 16 using ::ONNX_NAMESPACE::AttributeProto;
17 using ::ONNX_NAMESPACE::GraphProto;
18 using ::ONNX_NAMESPACE::ModelProto;
19 using ::ONNX_NAMESPACE::NodeProto;
20 using ::ONNX_NAMESPACE::TensorProto;
// Caffe2 operator definitions kept separately from `ops`; the name suggests
// initialisation-time operators — TODO confirm against the code that fills it.
26 ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> init_ops;
// Caffe2 operator definitions; presumably the operators of the converted
// computation itself — TODO confirm.
27 ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> ops;
// Blob names held as strings; presumably the blobs shared between the two
// operator lists above — verify against the producer.
28 ::google::protobuf::RepeatedPtrField<std::string> interface_blobs;
// Reports whether an attribute named `key` is currently registered in
// onnx_attrs_. The count() result is implicitly converted to bool, so any
// non-zero count yields true. Does not modify the object (const member).
37 bool HasAttribute(
const std::string& key)
const {
38 return onnx_attrs_.count(key);
// Inserts a default-constructed AttributeProto under `key` into
// rewritten_onnx_attrs_ and obtains a reference to the stored value.
// emplace() leaves an existing entry untouched, so when `key` is already
// present `attr` refers to the previously stored attribute.
// The lines that follow are not shown here; presumably the function returns
// the address of `attr` (return type is AttributeProto*) — TODO confirm.
// NOTE(review): the name is misspelled ("Attibute"); it is part of the public
// interface, so renaming would need a coordinated change at every call site.
41 AttributeProto* AddRewrittenAttibute(
const std::string& key) {
42 auto tmp = rewritten_onnx_attrs_.emplace(key, AttributeProto());
43 auto& attr = tmp.first->second;
// Declaration only: produces a list of caffe2::Argument messages.
// `mapper` is a callable taking a string and returning a string; presumably
// it translates an ONNX attribute name into the Caffe2 argument name — TODO
// confirm against the definition. Const member: does not modify the object.
48 ::google::protobuf::RepeatedPtrField<caffe2::Argument> OnnxAttrToCaffe2Arg(
49 std::function<std::string(
const std::string&)> mapper)
const;
// Declaration only: typed accessor returning the value associated with `key`
// as a T. No default is supplied; behaviour for a missing key is defined by
// the per-type definitions, which are not visible here — TODO confirm.
54 T
get(
const std::string& key)
const;
// Typed accessor with a fallback. The visible part tests whether `key` is
// present in onnx_attrs_; the remainder of the body is not shown, but
// presumably the stored value is returned when present and `default_value`
// otherwise — TODO confirm.
57 T
get(
const std::string& key,
const T& default_value)
const {
58 if (onnx_attrs_.count(key)) {
// Unregisters the attribute named `key` from onnx_attrs_.
// `result` starts as nullptr; when `key` is found, the mapped pointer is
// captured in `result` before the entry is erased. Only the map entry is
// removed — the pointed-to AttributeProto is not destroyed here.
// The closing lines are not shown; presumably `result` is returned, i.e.
// nullptr when `key` was absent — TODO confirm.
65 const AttributeProto*
remove(
const std::string& key) {
66 const AttributeProto* result =
nullptr;
67 auto iter = onnx_attrs_.find(key);
68 if (iter != onnx_attrs_.end()) {
69 result = iter->second;
70 onnx_attrs_.erase(iter);
// Attribute name -> pointer to a read-only AttributeProto. The map stores raw
// pointers only; the pointees are presumably owned by the source NodeProto
// and must outlive this map — TODO confirm ownership.
76 std::unordered_map<std::string, const AttributeProto*> onnx_attrs_;
// Attribute name -> AttributeProto stored by value; filled by
// AddRewrittenAttibute().
77 std::unordered_map<std::string, AttributeProto> rewritten_onnx_attrs_;
// Out-of-class declarations of OnnxAttributes::get(key), one per supported
// return type. The `template` introducers are not visible at this point;
// presumably each is an explicit specialisation — TODO confirm.

// Attribute value as a 64-bit signed integer.
81 int64_t OnnxAttributes::get(
const std::string& key)
const;
// Attribute value as a float.
83 float OnnxAttributes::get(
const std::string& key)
const;
// Attribute value as a repeated field of strings.
86 ::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get(
87 const std::string& key)
const;
// Attribute value as a repeated field of protobuf int64.
90 ::google::protobuf::RepeatedField<::google::protobuf::int64>
91 OnnxAttributes::get(
const std::string& key)
const;
// Attribute value as a pointer to a read-only TensorProto; whether nullptr
// can be returned is not visible here.
94 const TensorProto* OnnxAttributes::get(
const std::string& key)
const;
// Wraps an ONNX NodeProto: binds `node` to `node_in` and constructs
// `attributes` from the same proto.
// NOTE: `node` is a reference member, so no copy of the proto is made;
// `node_in` must outlive this OnnxNode or `node` dangles.
98 OnnxNode(
const NodeProto& node_in) : node(node_in), attributes(node_in) {}
// Non-owning reference to the NodeProto passed to the constructor.
100 const NodeProto& node;
108 const std::string& onnx_model_str,
109 const std::string& device,
110 const std::vector<Caffe2Ops>& extras);
// Declaration only: answers whether the operator type named by the argument
// is supported (semantics inferred from the name — TODO confirm against the
// definition). Const member: does not modify the object.
// NOTE(review): the parameter name `tyep` looks like a typo for `type`, and
// the string is taken by const value, which copies on every call; a
// `const std::string&` would avoid that but must be changed together with the
// out-of-line definition.
112 bool SupportOp(
const std::string tyep)
const;
// Declaration only: converts one node into a Caffe2Ops value.
// `node_str` is presumably a serialized ONNX NodeProto — TODO confirm.
// `opset_version` is the ONNX operator-set version to convert against
// (inferred from the name).
114 Caffe2Ops ConvertNode(
const std::string& node_str,
int opset_version);
120 caffe2::NetDef* init_net,
121 caffe2::NetDef* pred_net,
122 const ModelProto& onnx_model,
123 const std::string& device,
125 bool include_initializers,
126 const std::vector<Caffe2Ops>& extras);
129 const ModelProto& init_model,
130 const ModelProto& pred_model,
// Declaration only: returns a set of strings derived from `graph`.
// Presumably the names used anywhere in the graph (inputs, outputs, node
// inputs/outputs) — TODO confirm which names are included.
134 std::unordered_set<std::string> AllNamesInGraph(
const GraphProto& graph);
// Declaration only: populates the caller-provided operator `c2_op` from
// `onnx_tensor`.
//   c2_op       - output operator definition, written through the pointer.
//   onnx_tensor - source ONNX tensor (read-only).
//   name        - optional name, defaults to the empty string; presumably the
//                 tensor's own name is used when empty — TODO confirm.
136 void BuildTensorFillingOp(
137 caffe2::OperatorDef* c2_op,
138 const TensorProto& onnx_tensor,
139 const std::string& name =
"");
// Accessors for the backend's lookup tables. Each returns a reference to a
// table whose storage is not visible here; the meaning of keys and values
// below is inferred from the names — TODO confirm against the definitions.

// String -> string table; presumably ONNX operator name -> Caffe2 operator
// name.
// NOTE(review): the trailing `const;` of this declaration is not present on
// the following line, unlike its siblings — verify.
168 const std::unordered_map<std::string, std::string>& get_renamed_operators()
// Set of operator names treated as RNN operators.
170 const std::unordered_set<std::string>& get_rnn_operators()
const;
// Operator name -> int; the meaning of the int is not visible here.
171 const std::unordered_map<std::string, int>& get_broken_operators()
const;
// String -> string table; presumably ONNX attribute name -> Caffe2 argument
// name, applied regardless of operator.
172 const std::unordered_map<std::string, std::string>& get_renamed_attrs()
const;
// Operator name -> (string -> string) table; presumably per-operator
// attribute renames.
// NOTE(review): the return type here is an unqualified, non-const
// `unordered_map<...>&` while every sibling returns `const std::...&` —
// a `const std::` prefix appears to be missing; verify.
174 unordered_map<std::string, std::unordered_map<std::string, std::string>>&
175 get_per_op_renamed_attrs()
const;
// Operator name -> Caffe2Backend::SpecialOpConverter; presumably the
// operators that need a dedicated conversion routine.
176 const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>&
177 get_special_operators()
const;
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...