Caffe2 - C++ API
A deep learning, cross-platform ML framework
onnx_exporter.h
1 #pragma once
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/proto/caffe2.pb.h"
5 #include "onnx/onnx_pb.h"
6 
7 #include <string>
8 #include <unordered_map>
9 #include <vector>
10 
11 namespace caffe2 {
12 namespace onnx {
13 
14 namespace {
15 using ::ONNX_NAMESPACE::AttributeProto;
16 using ::ONNX_NAMESPACE::GraphProto;
17 using ::ONNX_NAMESPACE::ModelProto;
18 using ::ONNX_NAMESPACE::NodeProto;
19 using ::ONNX_NAMESPACE::TensorProto;
20 using ConvertedResult =
21  std::pair<std::vector<NodeProto>, std::vector<TensorProto>>;
22 } // namespace
23 
24 class OnnxExporter {
25  using SpecialOpConverter = ConvertedResult (OnnxExporter::*)(
26  const caffe2::OperatorDef&,
27  const std::unordered_map<std::string, caffe2::TensorShape>&);
28 
29  public:
30  ConvertedResult Caffe2OpToOnnxNodes(
31  const caffe2::OperatorDef& def,
32  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
33 
34  private:
35  ConvertedResult CommonCaffe2OpToOnnxNodes(const caffe2::OperatorDef& def);
36 
37  ConvertedResult CreateConvPoolNodes(
38  const caffe2::OperatorDef& def,
39  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
40 
41  ConvertedResult CreateGemmNodes(
42  const caffe2::OperatorDef& def,
43  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
44 
45  ConvertedResult CreateReshapeNodes(
46  const caffe2::OperatorDef& def,
47  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
48 
49  ConvertedResult CreateSliceNodes(
50  const caffe2::OperatorDef& def,
51  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
52 
53  ConvertedResult CreateChannelShuffleNodes(
54  const caffe2::OperatorDef& def,
55  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
56 
57  ConvertedResult CreateConcatNodes(
58  const caffe2::OperatorDef& def,
59  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
60 
61  ConvertedResult CreateLrnNodes(
62  const caffe2::OperatorDef& def,
63  const std::unordered_map<std::string, caffe2::TensorShape>& shapes);
64 
65  // \brief Check black listed arguemnts where we won't pass down when
66  // converting to ONNX node
67  bool IsBlackListed(const caffe2::Argument& arg);
68 
69  // \brief Convert Caffe2 argument to Onnx attribute
70  void CopyCaffe2ArgToOnnxAttr(
71  AttributeProto* attr,
72  const std::string& op_type,
73  const caffe2::Argument& arg);
74 
75  // LUT getters
76  const std::unordered_map<std::string, std::string>& get_renamed_operators()
77  const;
78  const std::unordered_map<std::string, std::string>& get_renamed_attrs() const;
79  const std::
80  unordered_map<std::string, std::unordered_map<std::string, std::string>>&
81  get_per_op_renamed_attrs() const;
82  const std::unordered_map<std::string, OnnxExporter::SpecialOpConverter>&
83  get_special_operators() const;
84 };
85 } // namespace onnx
86 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...