Caffe2 - C++ API
A deep learning, cross-platform ML framework
backend.h
1 #pragma once
2 
#include "caffe2/onnx/backend_rep.h"
#include "caffe2/onnx/device.h"
#include "caffe2/proto/caffe2.pb.h"
#include "onnx/onnx_pb.h"

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
12 
13 namespace caffe2 {
14 namespace onnx {
15 
16 using ::ONNX_NAMESPACE::AttributeProto;
17 using ::ONNX_NAMESPACE::GraphProto;
18 using ::ONNX_NAMESPACE::ModelProto;
19 using ::ONNX_NAMESPACE::NodeProto;
20 using ::ONNX_NAMESPACE::TensorProto;
21 
22 // \brief This struct holds the converted ops after the onnx->c2 conversion.
23 // Notice that for RNN ops, it may create ops in init_net. Hence we have the
24 // `init_ops` field.
25 struct Caffe2Ops {
26  ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> init_ops;
27  ::google::protobuf::RepeatedPtrField<caffe2::OperatorDef> ops;
28  ::google::protobuf::RepeatedPtrField<std::string> interface_blobs;
29 };
30 
31 // A convenient class to query attributes of a NodeProto. Note that the
32 // NodeProto can not be modified during the query of OnnxAttributes object
34  public:
35  OnnxAttributes(const NodeProto& node);
36 
37  bool HasAttribute(const std::string& key) const {
38  return onnx_attrs_.count(key);
39  }
40 
41  AttributeProto* AddRewrittenAttibute(const std::string& key) {
42  auto tmp = rewritten_onnx_attrs_.emplace(key, AttributeProto());
43  auto& attr = tmp.first->second;
44  attr.set_name(key);
45  return &attr;
46  }
47 
48  ::google::protobuf::RepeatedPtrField<caffe2::Argument> OnnxAttrToCaffe2Arg(
49  std::function<std::string(const std::string&)> mapper) const;
50 
51  // Get attribute given attribute name, specialied on data type T. Note that
52  // the return value is copied
53  template <typename T>
54  T get(const std::string& key) const;
55 
56  template <typename T>
57  T get(const std::string& key, const T& default_value) const {
58  if (onnx_attrs_.count(key)) {
59  return get<T>(key);
60  } else {
61  return default_value;
62  }
63  }
64 
65  const AttributeProto* remove(const std::string& key) {
66  const AttributeProto* result = nullptr;
67  auto iter = onnx_attrs_.find(key);
68  if (iter != onnx_attrs_.end()) {
69  result = iter->second;
70  onnx_attrs_.erase(iter);
71  }
72  return result;
73  }
74 
75  private:
76  std::unordered_map<std::string, const AttributeProto*> onnx_attrs_;
77  std::unordered_map<std::string, AttributeProto> rewritten_onnx_attrs_;
78 };
79 
80 template <>
81 int64_t OnnxAttributes::get(const std::string& key) const;
82 template <>
83 float OnnxAttributes::get(const std::string& key) const;
84 
85 template <>
86 ::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get(
87  const std::string& key) const;
88 
89 template <>
90 ::google::protobuf::RepeatedField<::google::protobuf::int64>
91 OnnxAttributes::get(const std::string& key) const;
92 
93 template <>
94 const TensorProto* OnnxAttributes::get(const std::string& key) const;
95 
96 // convenient class for onnx node
97 struct OnnxNode {
98  OnnxNode(const NodeProto& node_in) : node(node_in), attributes(node_in) {}
99 
100  const NodeProto& node;
101 
102  OnnxAttributes attributes;
103 };
104 
106  public:
107  Caffe2BackendRep* Prepare(
108  const std::string& onnx_model_str,
109  const std::string& device,
110  const std::vector<Caffe2Ops>& extras);
111 
112  bool SupportOp(const std::string tyep) const;
113 
114  Caffe2Ops ConvertNode(const std::string& node_str, int opset_version);
115 
116  private:
117  using SpecialOpConverter = Caffe2Ops (Caffe2Backend::*)(OnnxNode*, int);
118 
119  void OnnxToCaffe2(
120  caffe2::NetDef* init_net,
121  caffe2::NetDef* pred_net,
122  const ModelProto& onnx_model,
123  const std::string& device,
124  int opset_version,
125  bool include_initializers,
126  const std::vector<Caffe2Ops>& extras);
127 
128  Caffe2Ops OnnxNodeToCaffe2Ops(
129  const ModelProto& init_model,
130  const ModelProto& pred_model,
131  OnnxNode* onnx_node,
132  int opset_version);
133 
134  std::unordered_set<std::string> AllNamesInGraph(const GraphProto& graph);
135 
136  void BuildTensorFillingOp(
137  caffe2::OperatorDef* c2_op,
138  const TensorProto& onnx_tensor,
139  const std::string& name = "");
140 
141  Caffe2Ops CommonOnnxNodeToCaffe2Ops(OnnxNode* onnx_node, int opset_version);
142 
143  Caffe2Ops CreateConstant(OnnxNode* onnx_node, int opset_version);
144 
145  Caffe2Ops CreateConvPoolOpBase(OnnxNode* onnx_node, int opset_version);
146 
147  Caffe2Ops CreateReshape(OnnxNode* onnx_node, int opset_version);
148 
149  Caffe2Ops CreateGather(OnnxNode* onnx_node, int opset_version);
150 
151  Caffe2Ops CreateGemm(OnnxNode* onnx_node, int opset_version);
152 
153  Caffe2Ops CreatePad(OnnxNode* onnx_node, int opset_version);
154 
155  Caffe2Ops CreateConcat(OnnxNode* onnx_node, int opset_version);
156 
157  Caffe2Ops CreateLogSoftmax(OnnxNode* onnx_node, int opset_version);
158 
159  Caffe2Ops CreateSlice(OnnxNode* onnx_node, int opset_version);
160 
161  Caffe2Ops CreateReciprocal(OnnxNode* onnx_node, int opset_version);
162 
163  Caffe2Ops CreateBatchNormalization(OnnxNode* onnx_node, int opset_version);
164 
165  Caffe2Ops CreateMatMul(OnnxNode* onnx_node, int opset_version);
166 
167  // LUT related getters
168  const std::unordered_map<std::string, std::string>& get_renamed_operators()
169  const;
170  const std::unordered_set<std::string>& get_rnn_operators() const;
171  const std::unordered_map<std::string, int>& get_broken_operators() const;
172  const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
173  const std::
174  unordered_map<std::string, std::unordered_map<std::string, std::string>>&
175  get_per_op_renamed_attrs() const;
176  const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>&
177  get_special_operators() const;
178 };
179 
180 } // namespace onnx
181 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...