1 #include "caffe2/core/logging.h" 2 #include "caffe2/core/operator.h" 3 #include "caffe2/onnx/backend.h" 4 #include "caffe2/onnx/device.h" 5 #include "caffe2/onnx/helper.h" 6 #include "caffe2/utils/map_utils.h" 7 #include "caffe2/utils/proto_utils.h" 10 #include "onnx/checker.h" 11 #include "onnx/optimizer/optimize.h" 14 #include "google/protobuf/io/coded_stream.h" 15 #include "google/protobuf/io/zero_copy_stream_impl_lite.h" 20 #include <unordered_map> 21 #include <unordered_set> 28 constexpr
// Highest ONNX operator-set version this importer was written against;
// Prepare() warns (but still tries) when a model declares a newer one.
static int kKnownOpsetVersion = 6;
// Returns true when two doubles differ by less than 1e-15, i.e. they are
// equal up to floating-point noise. Used to decide whether Gemm's
// alpha/beta are exactly 1 so the extra Scale op can be skipped.
// (Reconstructed: the closing brace was lost in the garbled source.)
bool AlmostEqual(double a, double b) {
  constexpr static double kEps = 1e-15;
  return (std::fabs(a - b) < kEps);
}
36 bool TryConvertingTensorRawValues(
37 const TensorProto& onnx_tensor,
38 ::google::protobuf::RepeatedField<T>* field) {
39 if (!onnx_tensor.has_raw_data()) {
43 size_t raw_size = onnx_tensor.raw_data().size();
44 CAFFE_ENFORCE_EQ(raw_size %
sizeof(T), 0);
46 size_t num_elements = raw_size /
sizeof(T);
47 const void* src_ptr =
static_cast<const void*
>(onnx_tensor.raw_data().data());
48 field->Resize(num_elements, 0);
49 void* target_ptr =
static_cast<void*
>(field->mutable_data());
50 memcpy(target_ptr, src_ptr, raw_size);
55 bool IsOperator(
const std::string& op_type) {
58 static std::set<std::string>* ops_ =
59 new std::set<std::string>(caffe2::GetRegisteredOperators());
60 return ops_->count(caffe2::OpRegistryKey(op_type,
"DEFAULT"));
63 caffe2::DeviceOption GetDeviceOption(
const Device& onnx_device) {
64 static const std::unordered_map<DeviceType, caffe2::DeviceType> m = {
65 {DeviceType::CPU, caffe2::DeviceType::CPU},
66 {DeviceType::CUDA, caffe2::DeviceType::CUDA}};
67 caffe2::DeviceOption d;
68 d.set_device_type(static_cast<int32_t>(m.at(onnx_device.type)));
69 d.set_cuda_gpu_id(onnx_device.device_id);
74 ModelProto OptimizeOnnx(
const ModelProto& input,
bool init) {
75 std::vector<std::string> passes{
"fuse_consecutive_transposes",
76 "eliminate_nop_transpose",
77 "fuse_transpose_into_gemm"};
80 passes.emplace_back(
"split_init");
82 passes.emplace_back(
"split_predict");
84 return ::ONNX_NAMESPACE::optimization::Optimize(input, passes);
// Map lookup helper: returns the mapped value for `key`, or `default_value`
// when absent. NOTE(review): the function-name line, the `key` parameter
// line, and both return statements were dropped by the extraction that
// produced this file — presumably `const U& get_default(..., const T& key,
// ...)` returning `default_value` here and `it->second` after; confirm
// against the original caffe2/utils/map_utils.h-style helper.
88 template <
class T,
class U>
90 const std::unordered_map<T, U>& map,
92 const U& default_value) {
93 const auto it = map.find(key);
94 if (it == map.end()) {
101 void UpdateNames(
const caffe2::OperatorDef& op) {
102 for (
const auto& n : op.input()) {
103 DummyName::AddName(n);
105 for (
const auto& n : op.output()) {
106 DummyName::AddName(n);
111 caffe2::OperatorDef* c2_op,
112 const std::string& op_type,
113 const std::vector<std::string>& inputs,
114 const std::vector<std::string>& outputs,
115 const std::vector<caffe2::Argument>& args) {
117 c2_op->set_type(op_type);
118 for (
const auto& input : inputs) {
119 c2_op->add_input(input);
121 for (
const auto& output : outputs) {
122 c2_op->add_output(output);
124 for (
const auto& arg : args) {
125 auto* tmp = c2_op->add_arg();
131 caffe2::OperatorDef* c2_op,
132 const std::string& op_type,
133 const std::vector<std::string>& inputs,
134 const std::vector<std::string>& outputs) {
135 std::vector<caffe2::Argument> empty;
136 BuildOperator(c2_op, op_type, inputs, outputs, empty);
139 void CopyOnnxAttrValueToCaffe2Arg(
140 caffe2::Argument* arg,
141 const AttributeProto& attr) {
143 arg->set_f(attr.f());
144 }
else if (attr.has_i()) {
145 arg->set_i(attr.i());
146 }
else if (attr.has_s()) {
147 arg->set_s(attr.s());
148 }
else if (attr.has_t()) {
151 attr.t().SerializeToString(&buffer);
153 }
else if (attr.floats_size()) {
154 arg->mutable_floats()->CopyFrom(attr.floats());
155 }
else if (attr.ints_size()) {
156 arg->mutable_ints()->CopyFrom(attr.ints());
157 }
else if (attr.strings_size()) {
158 arg->mutable_strings()->CopyFrom(attr.strings());
160 CAFFE_THROW(
"Unsupported ONNX attribute: ", attr.name());
165 OnnxAttributes::OnnxAttributes(
const NodeProto& node) {
166 for (
const auto& attr : node.attribute()) {
167 onnx_attrs_.emplace(attr.name(), &attr);
172 int64_t OnnxAttributes::get(
const std::string& key)
const {
174 const auto it = onnx_attrs_.find(key);
175 if (it != onnx_attrs_.end()) {
176 const AttributeProto& attr = *it->second;
183 float OnnxAttributes::get(
const std::string& key)
const {
185 const auto it = onnx_attrs_.find(key);
186 if (it != onnx_attrs_.end()) {
187 const AttributeProto& attr = *it->second;
194 ::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get(
195 const std::string& key)
const {
196 ::google::protobuf::RepeatedPtrField<std::string> value;
197 const auto it = onnx_attrs_.find(key);
198 if (it != onnx_attrs_.end()) {
199 const AttributeProto& attr = *it->second;
200 value.CopyFrom(attr.strings());
206 ::google::protobuf::RepeatedField<::google::protobuf::int64>
207 OnnxAttributes::get(
const std::string& key)
const {
208 ::google::protobuf::RepeatedField<::google::protobuf::int64> value;
209 const auto it = onnx_attrs_.find(key);
210 if (it != onnx_attrs_.end()) {
211 const AttributeProto& attr = *it->second;
212 value.CopyFrom(attr.ints());
218 const TensorProto* OnnxAttributes::get(
const std::string& key)
const {
219 const TensorProto* value =
nullptr;
220 const auto it = onnx_attrs_.find(key);
221 if (it != onnx_attrs_.end()) {
222 const AttributeProto& attr = *it->second;
228 ::google::protobuf::RepeatedPtrField<caffe2::Argument>
229 OnnxAttributes::OnnxAttrToCaffe2Arg(
230 std::function<std::string(
const std::string&)> mapper)
const {
231 ::google::protobuf::RepeatedPtrField<caffe2::Argument> args;
232 for (
const auto& kv : onnx_attrs_) {
235 const auto& attr = rewritten_onnx_attrs_.count(kv.first)
236 ? rewritten_onnx_attrs_.at(kv.first)
238 auto* arg = args.Add();
239 arg->set_name(mapper(attr.name()));
240 CopyOnnxAttrValueToCaffe2Arg(arg, attr);
242 for (
const auto& kv : rewritten_onnx_attrs_) {
245 if (!onnx_attrs_.count(kv.first)) {
246 const auto& attr = kv.second;
247 auto* arg = args.Add();
248 arg->set_name(mapper(attr.name()));
249 CopyOnnxAttrValueToCaffe2Arg(arg, attr);
256 const std::unordered_map<std::string, int>&
257 Caffe2Backend::get_broken_operators()
const {
258 const static std::unordered_map<std::string, int> kBrokenOperators{};
259 return kBrokenOperators;
264 const std::unordered_set<std::string>& Caffe2Backend::get_rnn_operators()
266 const static std::unordered_set<std::string> kRNNOperators{
267 "LSTM",
"GRU",
"RNN"};
268 return kRNNOperators;
275 const std::unordered_map<std::string, std::string>&
276 Caffe2Backend::get_renamed_operators()
const {
277 const static std::unordered_map<std::string, std::string> kRenamedOperators{
278 {
"Caffe2ConvTranspose",
"ConvTranspose"},
279 {
"GlobalMaxPool",
"MaxPool"},
280 {
"GlobalAveragePool",
"AveragePool"},
283 {
"BatchNormalization",
"SpatialBN"},
284 {
"InstanceNormalization",
"InstanceNorm"},
285 {
"MatMul",
"BatchMatMul"},
286 {
"Upsample",
"ResizeNearest"},
287 {
"Identity",
"Copy"},
288 {
"InstanceNormalization",
"InstanceNorm"},
292 {
"Unsqueeze",
"ExpandDims"}};
293 return kRenamedOperators;
296 const std::unordered_map<std::string, std::string>&
297 Caffe2Backend::get_renamed_attrs()
const {
298 const static std::unordered_map<std::string, std::string> kRenamedAttrs{
299 {
"kernel_shape",
"kernels"}};
300 return kRenamedAttrs;
304 unordered_map<std::string, std::unordered_map<std::string, std::string>>&
305 Caffe2Backend::get_per_op_renamed_attrs()
const {
307 unordered_map<std::string, std::unordered_map<std::string, std::string>>
308 kPerOpRenamedAttrs = {{
"Squeeze", {{
"axes",
"dims"}}},
309 {
"Unsqueeze", {{
"axes",
"dims"}}},
310 {
"Transpose", {{
"perm",
"axes"}}},
311 {
"Upsample", {{
"mode",
""}}},
312 {
"ConvTranspose", {{
"output_padding",
"adjs"}}},
313 {
"Selu", {{
"gamma",
"scale"}}}};
315 return kPerOpRenamedAttrs;
321 const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>&
322 Caffe2Backend::get_special_operators()
const {
324 unordered_map<std::string, Caffe2Backend::SpecialOpConverter>
325 kSpecialOperators = {
326 {
"Constant", &Caffe2Backend::CreateConstant},
327 {
"Conv", &Caffe2Backend::CreateConvPoolOpBase},
328 {
"AveragePool", &Caffe2Backend::CreateConvPoolOpBase},
329 {
"GlobalAveragePool", &Caffe2Backend::CreateConvPoolOpBase},
330 {
"GlobalMaxPool", &Caffe2Backend::CreateConvPoolOpBase},
331 {
"MaxPool", &Caffe2Backend::CreateConvPoolOpBase},
332 {
"Reshape", &Caffe2Backend::CreateReshape},
333 {
"Gather", &Caffe2Backend::CreateGather},
334 {
"Gemm", &Caffe2Backend::CreateGemm},
335 {
"Pad", &Caffe2Backend::CreatePad},
336 {
"Concat", &Caffe2Backend::CreateConcat},
337 {
"LogSoftmax", &Caffe2Backend::CreateLogSoftmax},
338 {
"Slice", &Caffe2Backend::CreateSlice},
339 {
"Reciprocal", &Caffe2Backend::CreateReciprocal},
340 {
"BatchNormalization", &Caffe2Backend::CreateBatchNormalization},
341 {
"MatMul", &Caffe2Backend::CreateMatMul}};
342 return kSpecialOperators;
349 Caffe2Ops Caffe2Backend::CreateConstant(
352 CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1);
355 auto* c2_op = ret.ops.Add();
356 const auto* value = onnx_node->attributes.get<
const TensorProto*>(
"value");
357 BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0));
392 Caffe2Ops Caffe2Backend::CreateConvPoolOpBase(
395 const auto& node = onnx_node->node;
396 auto& attributes = onnx_node->attributes;
397 if (node.op_type().find(
"Global") == 0) {
398 auto* attr = attributes.AddRewrittenAttibute(
"global_pooling");
402 if (attributes.HasAttribute(
"kernel_shape") &&
403 attributes.HasAttribute(
"pads")) {
406 .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
410 .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
412 if (kernel_shape.size() == pads.size()) {
414 auto* attr = attributes.AddRewrittenAttibute(
"pads");
415 attr->mutable_ints()->CopyFrom(pads);
416 attr->mutable_ints()->MergeFrom(pads);
420 return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
423 Caffe2Ops Caffe2Backend::CreateReshape(OnnxNode* onnx_node,
int opset_version) {
424 auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
425 CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
426 auto* op = c2_op.ops.Mutable(0);
427 op->add_output(DummyName::NewDummyName());
432 Caffe2Ops Caffe2Backend::CreateReciprocal(
435 const auto& node = onnx_node->node;
436 if (node.input_size() != 1 || node.output_size() != 1) {
437 CAFFE_THROW(
"Caffe2 Reciprocal should have 1 input and 1 output");
441 auto* c2_op = ret.ops.Add();
443 caffe2::Argument exponent;
444 exponent.set_name(
"exponent");
445 exponent.set_f(-1.0);
446 BuildOperator(c2_op,
"Pow", {node.input(0)}, {node.output(0)}, {exponent});
450 Caffe2Ops Caffe2Backend::CreateGather(OnnxNode* onnx_node,
int opset_version) {
451 const auto& node = onnx_node->node;
452 if (node.input_size() < 2 || node.output_size() < 1) {
453 CAFFE_THROW(
"Caffe2 Gather should have 2 inputs and 1 output");
457 auto* c2_op = ret.ops.Add();
459 std::vector<std::string> inputs;
460 inputs.emplace_back(node.input(0));
461 inputs.emplace_back(node.input(1));
462 std::vector<std::string> outputs;
463 outputs.emplace_back(node.output(0));
465 auto axis = onnx_node->attributes.get<int64_t>(
"axis", 0L);
467 BuildOperator(c2_op,
"Gather", inputs, outputs);
468 }
else if (axis == 1) {
469 BuildOperator(c2_op,
"BatchGather", inputs, outputs);
472 "Caffe2 only supports Gather with axis being 1 or 2, ",
480 Caffe2Ops Caffe2Backend::CreateGemm(OnnxNode* onnx_node,
int opset_version) {
481 const auto& node = onnx_node->node;
482 if (node.input_size() < 3 || node.output_size() < 1) {
483 CAFFE_THROW(
"Caffe2 Gemm should have 3 inputs and 1 output");
487 auto input_a = node.input(0);
488 auto input_b = node.input(1);
489 auto input_c = node.input(2);
490 auto output = node.output(0);
492 auto alpha = onnx_node->attributes.get<
float>(
"alpha", 1.0);
493 auto beta = onnx_node->attributes.get<
float>(
"beta", 1.0);
494 if (!AlmostEqual(alpha, 1)) {
495 auto scaled_a = DummyName::NewDummyName();
496 caffe2::Argument scale;
497 scale.set_name(
"scale");
500 auto* c2_op = ret.ops.Add();
501 BuildOperator(c2_op,
"Scale", {input_a}, {scaled_a}, {scale});
504 if (!AlmostEqual(beta, 1)) {
505 auto scaled_c = DummyName::NewDummyName();
506 caffe2::Argument scale;
507 scale.set_name(
"scale");
510 auto* c2_op = ret.ops.Add();
511 BuildOperator(c2_op,
"Scale", {input_c}, {scaled_c}, {scale});
515 auto trans_a = onnx_node->attributes.get<int64_t>(
"transA", 0L);
516 auto trans_b = onnx_node->attributes.get<int64_t>(
"transB", 0L);
517 auto broadcast = onnx_node->attributes.get<int64_t>(
"broadcast", 0L);
518 if ((!trans_a) && trans_b && broadcast) {
519 auto* c2_op = ret.ops.Add();
520 BuildOperator(c2_op,
"FC", {input_a, input_b, input_c}, {output});
522 auto ab = DummyName::NewDummyName();
523 caffe2::Argument arg_trans_a;
524 arg_trans_a.set_name(
"trans_a");
525 arg_trans_a.set_i(trans_a);
526 caffe2::Argument arg_trans_b;
527 arg_trans_b.set_name(
"trans_b");
528 arg_trans_b.set_i(trans_b);
529 caffe2::Argument arg_broadcast;
530 arg_broadcast.set_name(
"broadcast");
531 arg_broadcast.set_i(broadcast);
533 auto* c2_op = ret.ops.Add();
535 c2_op,
"MatMul", {input_a, input_b}, {ab}, {arg_trans_a, arg_trans_b});
536 c2_op = ret.ops.Add();
537 BuildOperator(c2_op,
"Add", {ab, input_c}, {output}, {arg_broadcast});
543 Caffe2Ops Caffe2Backend::CreatePad(OnnxNode* onnx_node,
int opset_version) {
544 const auto& node = onnx_node->node;
545 auto& attributes = onnx_node->attributes;
546 ::google::protobuf::RepeatedField<::google::protobuf::int64> pads;
547 std::string pad_name = opset_version < 2 ?
"paddings" :
"pads";
549 .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>(
552 std::stringstream ss;
554 for (
const auto& i : pads) {
561 for (
const auto i : pads) {
563 CAFFE_THROW(
"ONNX does not support negative pads in Pad, but get ", str);
569 if (!(pads.size() == 8 &&
570 (pads.Get(0) + pads.Get(1) + pads.Get(4) + pads.Get(5) == 0))) {
572 "Caffe2 only supports padding 2D Tensor, whereas padding is ", str);
576 auto* attr = attributes.AddRewrittenAttibute(pad_name);
577 attr->add_ints(pads.Get(2));
578 attr->add_ints(pads.Get(3));
579 attr->add_ints(pads.Get(6));
580 attr->add_ints(pads.Get(7));
582 return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
588 Caffe2Ops Caffe2Backend::CreateConcat(OnnxNode* onnx_node,
int opset_version) {
589 auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
590 CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
591 auto* op = c2_op.ops.Mutable(0);
592 op->add_output(DummyName::NewDummyName());
597 Caffe2Ops Caffe2Backend::CreateLogSoftmax(
600 const auto& node = onnx_node->node;
601 if (node.input_size() < 1 || node.output_size() < 1) {
602 CAFFE_THROW(
"LogSoftmax should have 1 input and 1 output");
604 auto axis = onnx_node->attributes.get<int64_t>(
"axis", 1L);
605 caffe2::Argument arg_axis;
606 arg_axis.set_name(
"axis");
607 arg_axis.set_i(axis);
608 auto softmax_a = DummyName::NewDummyName();
611 auto* c2_op = ret.ops.Add();
612 BuildOperator(c2_op,
"Softmax", {node.input(0)}, {softmax_a}, {arg_axis});
613 c2_op = ret.ops.Add();
614 BuildOperator(c2_op,
"Log", {softmax_a}, {node.output(0)});
619 Caffe2Ops Caffe2Backend::CreateSlice(OnnxNode* onnx_node,
int opset_version) {
620 auto op_tmp = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
621 CAFFE_ENFORCE_EQ(op_tmp.ops.size(), 1);
622 auto* op = op_tmp.ops.Mutable(0);
623 std::unordered_map<std::string, caffe2::Argument*> args;
624 for (
auto& arg : *op->mutable_arg()) {
625 args.emplace(arg.name(), &arg);
628 caffe2::Argument starts_vals;
629 starts_vals.set_name(
"values");
630 auto pos = args.find(
"starts");
631 if (pos != args.end()) {
632 for (
auto i : pos->second->ints()) {
633 starts_vals.add_ints(i);
638 caffe2::Argument ends_vals;
639 ends_vals.set_name(
"values");
640 pos = args.find(
"ends");
641 if (pos != args.end()) {
642 for (
auto i : pos->second->ints()) {
643 ends_vals.add_ints(i < 0 ? i - 1 : i);
648 caffe2::Argument axes_vals;
649 axes_vals.set_name(
"values");
650 pos = args.find(
"axes");
651 if (pos != args.end()) {
652 for (
auto i : pos->second->ints()) {
653 axes_vals.add_ints(i);
657 auto ndim = starts_vals.ints_size();
658 for (int64_t i = 0; i < ndim; ++i) {
659 axes_vals.add_ints(i);
663 CAFFE_ENFORCE_GE(op->input_size(), 1);
664 auto data = op->input(0);
665 auto shape_tensor = DummyName::NewDummyName();
668 auto* c2_op = ret.ops.Add();
669 BuildOperator(c2_op,
"Shape", {data}, {shape_tensor});
671 auto axes_tensor = DummyName::NewDummyName();
672 c2_op = ret.ops.Add();
674 caffe2::Argument shape;
675 shape.set_name(
"shape");
676 shape.add_ints(axes_vals.ints_size());
678 c2_op,
"GivenTensorIntFill", {}, {axes_tensor}, {shape, axes_vals});
681 auto starts_vals_tensor = DummyName::NewDummyName();
682 auto starts_tensor = DummyName::NewDummyName();
683 auto casted_starts_tensor = DummyName::NewDummyName();
684 c2_op = ret.ops.Add();
686 caffe2::Argument shape_starts;
687 shape_starts.set_name(
"shape");
688 shape_starts.add_ints(starts_vals.ints_size());
691 "GivenTensorInt64Fill",
693 {starts_vals_tensor},
694 {shape_starts, starts_vals});
697 caffe2::Argument dtype;
698 dtype.set_name(
"dtype");
699 dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64));
700 caffe2::Argument constant;
701 constant.set_name(
"value");
703 c2_op = ret.ops.Add();
710 c2_op = ret.ops.Add();
714 {starts_tensor, axes_tensor, starts_vals_tensor},
719 to.set_i(static_cast<int64_t>(caffe2::TensorProto::INT32));
720 c2_op = ret.ops.Add();
721 BuildOperator(c2_op,
"Cast", {starts_tensor}, {casted_starts_tensor}, {to});
723 auto ends_vals_tensor = DummyName::NewDummyName();
724 auto ends_tensor = DummyName::NewDummyName();
725 auto casted_ends_tensor = DummyName::NewDummyName();
726 c2_op = ret.ops.Add();
728 caffe2::Argument shape_ends;
729 shape_ends.set_name(
"shape");
730 shape_ends.add_ints(ends_vals.ints_size());
733 "GivenTensorInt64Fill",
736 {shape_ends, ends_vals});
740 c2_op = ret.ops.Add();
742 c2_op,
"ConstantFill", {shape_tensor}, {ends_tensor}, {dtype, constant});
743 c2_op = ret.ops.Add();
747 {ends_tensor, axes_tensor, ends_vals_tensor},
750 c2_op = ret.ops.Add();
751 BuildOperator(c2_op,
"Cast", {ends_tensor}, {casted_ends_tensor}, {to});
754 c2_op = ret.ops.Add();
755 c2_op->CopyFrom(*op);
756 c2_op->mutable_input()->Clear();
757 c2_op->add_input(data);
758 c2_op->add_input(casted_starts_tensor);
759 c2_op->add_input(casted_ends_tensor);
760 c2_op->mutable_arg()->Clear();
761 for (
const auto& kv : args) {
762 c2_op->add_arg()->CopyFrom(*kv.second);
768 Caffe2Ops Caffe2Backend::CreateBatchNormalization(
771 const auto& node = onnx_node->node;
772 if (opset_version < 6) {
773 auto& attributes = onnx_node->attributes;
774 attributes.remove(
"consumed_inputs");
777 return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
780 Caffe2Ops Caffe2Backend::CreateMatMul(OnnxNode* onnx_node,
int opset_version) {
781 const auto& node = onnx_node->node;
782 if (node.input_size() != 2) {
783 CAFFE_THROW(
"MatMul should have 2 inputs");
786 auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
787 CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1);
788 auto* op = c2_op.ops.Mutable(0);
789 auto* broadcast_arg = op->add_arg();
790 broadcast_arg->set_name(
"broadcast");
791 broadcast_arg->set_i(1);
799 std::unordered_set<std::string> Caffe2Backend::AllNamesInGraph(
800 const GraphProto& graph) {
801 std::unordered_set<std::string> names;
803 for (
const auto& input : graph.input()) {
804 names.emplace(input.name());
806 for (
const auto& output : graph.output()) {
807 names.emplace(output.name());
809 for (
const auto& node : graph.node()) {
810 for (
const auto& n : node.input()) {
813 for (
const auto& n : node.output()) {
831 Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops(
835 auto* c2_op = ret.ops.Add();
837 const auto& node = onnx_node->node;
838 c2_op->mutable_input()->MergeFrom(node.input());
839 c2_op->mutable_output()->MergeFrom(node.output());
840 c2_op->set_name(node.name());
842 const auto onnx_op_type = node.op_type();
843 auto broken_version = caffe2::get_default(
844 get_broken_operators(), onnx_op_type, std::numeric_limits<int>::max());
845 if (broken_version <= opset_version) {
847 "Don't know how to translate op ",
849 " in ONNX operator set v",
851 " (I only support prior to v",
855 caffe2::get_default(get_renamed_operators(), onnx_op_type, onnx_op_type));
856 if (!IsOperator(c2_op->type())) {
858 "Don't know how to translate op ", onnx_op_type);
861 auto mapper = [&,
this](
const std::string& k) {
862 const auto it = get_per_op_renamed_attrs().find(onnx_op_type);
863 if (it != get_per_op_renamed_attrs().end()) {
864 const auto it_op = it->second.find(k);
865 if (it_op != it->second.end()) {
866 return it_op->second;
869 const auto it_global = get_renamed_attrs().find(k);
870 if (it_global != get_renamed_attrs().end()) {
871 return it_global->second;
875 c2_op->mutable_arg()->MergeFrom(
876 onnx_node->attributes.OnnxAttrToCaffe2Arg(mapper));
881 Caffe2Ops Caffe2Backend::ConvertNode(
882 const std::string& node_str,
884 ::google::protobuf::RepeatedPtrField<NodeProto> nodes;
885 auto* n = nodes.Add();
886 ParseProtoFromLargeString(node_str, n);
887 ModelProto init_model;
888 ModelProto pred_model;
889 OnnxNode onnx_node = OnnxNode(nodes.Get(0));
890 return OnnxNodeToCaffe2Ops(init_model, pred_model, &onnx_node, opset_version);
893 Caffe2Ops Caffe2Backend::OnnxNodeToCaffe2Ops(
894 const ModelProto& init_model,
895 const ModelProto& pred_model,
898 if (get_special_operators().count(onnx_node->node.op_type())) {
899 return (this->*get_special_operators().at(onnx_node->node.op_type()))(
900 onnx_node, opset_version);
902 return CommonOnnxNodeToCaffe2Ops(onnx_node, opset_version);
// Top-level conversion of an ONNX ModelProto into a pair of Caffe2 nets:
// init_net (weight fills) and pred_net (computation). RNN ops are spliced
// in from preconverted `extras`; everything else goes through
// OnnxNodeToCaffe2Ops. NOTE(review): this block shows BOTH an
// OptimizeOnnx-based init/pred split (orig lines 917-918) and a plain-copy
// fallback (orig lines 920-922) with no surviving conditional between them
// — presumably an #ifdef/if on optimizer availability was dropped by the
// extraction; confirm against upstream. The `opset_version` parameter line
// (orig 911) was also dropped. Left byte-identical.
906 void Caffe2Backend::OnnxToCaffe2(
907 caffe2::NetDef* init_net,
908 caffe2::NetDef* pred_net,
909 const ModelProto& onnx_model,
910 const std::string& device,
912 bool include_initializers,
913 const std::vector<Caffe2Ops>& extras) {
914 auto device_option = GetDeviceOption(Device(device));
917 ModelProto init_model = OptimizeOnnx(onnx_model,
true);
918 ModelProto pred_model = OptimizeOnnx(onnx_model,
false);
920 ModelProto init_model = ModelProto();
921 ModelProto pred_model = onnx_model;
922 pred_model.mutable_graph()->mutable_initializer()->Clear();
925 init_net->set_name(onnx_model.graph().name() +
"_init");
926 pred_net->set_name(onnx_model.graph().name() +
"_predict");
// Turn each graph initializer into a tensor-filling op in init_net.
929 if (include_initializers) {
930 for (
const auto& tp : onnx_model.graph().initializer()) {
931 auto* c2_op = init_net->add_op();
932 BuildTensorFillingOp(c2_op, tp);
// Seed the dummy-name generator with every name in both graphs.
936 auto name_set = AllNamesInGraph(init_model.graph());
937 auto name_set_pred = AllNamesInGraph(pred_model.graph());
938 name_set.insert(name_set_pred.begin(), name_set_pred.end());
939 DummyName::Reset(name_set);
941 size_t idx_extra = 0;
942 auto converter = [&](
const ModelProto& model, caffe2::NetDef* net)
mutable {
943 net->mutable_device_option()->CopyFrom(device_option);
944 for (
const auto& node : model.graph().node()) {
945 auto* init_net_tmp = include_initializers ? init_net : net;
// RNN ops cannot be converted here; splice in the next preconverted
// Caffe2Ops from `extras` instead.
951 if (get_rnn_operators().count(node.op_type())) {
952 if (idx_extra < extras.size()) {
953 const auto& c2ops = extras[idx_extra++];
954 for (
const auto& op : c2ops.init_ops) {
957 init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops);
958 for (
const auto& op : c2ops.ops) {
961 net->mutable_op()->MergeFrom(c2ops.ops);
962 for (
const auto& input : c2ops.interface_blobs) {
963 DummyName::AddName(input);
965 net->mutable_external_input()->MergeFrom(c2ops.interface_blobs);
968 "Don't know how to convert ",
970 " without enough extra preconverted string");
973 auto onnx_node = OnnxNode(node);
974 auto c2ops = OnnxNodeToCaffe2Ops(
975 init_model, pred_model, &onnx_node, opset_version);
976 init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops);
977 net->mutable_op()->MergeFrom(c2ops.ops);
978 net->mutable_external_input()->MergeFrom(c2ops.interface_blobs);
// Mirror the graph's declared inputs/outputs on the Caffe2 net.
982 for (
const auto& value : model.graph().output()) {
983 net->add_external_output(value.name());
985 for (
const auto& value : model.graph().input()) {
986 net->add_external_input(value.name());
990 converter(init_model, init_net);
991 converter(pred_model, pred_net);
// Parses and checks a serialized ONNX model, determines its default-domain
// opset version (warning when newer than kKnownOpsetVersion, rejecting
// IR>=3 models that declare none), converts it into the returned
// Caffe2BackendRep, and records which graph inputs lack initializers so
// the caller knows what to feed at run time. NOTE(review): several lines
// were dropped by the extraction (the warning-stream setup around orig
// 1011-1021, the CAFFE_THROW at orig 1026, and the OnnxToCaffe2 call that
// must populate rep between orig 1028 and 1045); left byte-identical —
// restore from upstream. Returns a raw owning pointer; presumably the
// Python binding takes ownership — confirm.
994 Caffe2BackendRep* Caffe2Backend::Prepare(
995 const std::string& onnx_model_str,
996 const std::string& device,
997 const std::vector<Caffe2Ops>& extras) {
998 Caffe2BackendRep* rep =
new Caffe2BackendRep();
999 ModelProto onnx_model;
1000 ParseProtoFromLargeString(onnx_model_str, &onnx_model);
// Validate the model before attempting conversion.
1003 ::ONNX_NAMESPACE::checker::check_model(onnx_model);
1006 int opset_version = -1;
1007 for (
const auto& imp : onnx_model.opset_import()) {
// Only the default (empty-domain) operator set determines the version.
1008 if ((!imp.has_domain()) || imp.domain().empty()) {
1009 opset_version = imp.version();
1010 if (opset_version > kKnownOpsetVersion) {
1012 <<
"This version of onnx-caffe2 targets ONNX operator set version " 1013 << kKnownOpsetVersion
1014 <<
", but the model we are trying to import uses version " 1015 << opset_version <<
". We will try to import it anyway, " 1016 <<
"but if the model uses operators which had BC-breaking changes " 1017 "in the intervening versions, import will fail." 1021 std::cout <<
"Unrecognized operator set " << opset_version << std::endl;
1024 if (opset_version < 0) {
1025 if (onnx_model.ir_version() >= 0x00000003) {
1027 "Model with IR version >= 3 did not specify ONNX operator set " 1028 "version (onnx-caffe2 requires it)");
// Graph inputs without an initializer must be supplied by the caller.
1045 auto& uninitialized_inputs = rep->uninitialized_inputs();
1046 std::unordered_set<std::string> initialized_inputs;
1047 for (
const auto& tp : onnx_model.graph().initializer()) {
1048 initialized_inputs.emplace(tp.name());
1050 for (
const auto& input : onnx_model.graph().input()) {
1051 if (!initialized_inputs.count(input.name())) {
1052 uninitialized_inputs.emplace_back(input.name());
// Fills `c2_op` with a GivenTensor*Fill op reproducing `onnx_tensor`'s
// contents under blob name `name` (falling back to the tensor's own name).
// Dispatches on data_type; raw_data is preferred via
// TryConvertingTensorRawValues, with the typed data fields as fallback.
// Widening conversions: DOUBLE is narrowed into the float values list,
// UINT32 is widened into the int64 list, and BOOL/UINT8/INT8/UINT16/
// INT16/INT32 all land in the int list. NOTE(review): a few lines were
// dropped by the extraction (e.g. the set_type call around orig 1117 and
// the CAFFE_THROW head at orig 1136); left byte-identical — restore from
// upstream.
1059 void Caffe2Backend::BuildTensorFillingOp(
1060 caffe2::OperatorDef* c2_op,
1061 const TensorProto& onnx_tensor,
1062 const std::string& name) {
1063 auto fill_name = name.empty() ? onnx_tensor.name() : name;
1064 CAFFE_ENFORCE(!fill_name.empty());
1066 if (onnx_tensor.has_segment()) {
1067 CAFFE_THROW(
"Currently not supporting loading segments.");
1070 auto* c2_values = c2_op->add_arg();
1071 c2_values->set_name(
"values");
1073 if (onnx_tensor.data_type() == TensorProto::FLOAT) {
1074 c2_op->set_type(
"GivenTensorFill");
1075 auto* floats = c2_values->mutable_floats();
1076 if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
1077 floats->CopyFrom(onnx_tensor.float_data());
1079 }
else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
1080 c2_op->set_type(
"GivenTensorDoubleFill");
1081 ::google::protobuf::RepeatedField<double> tmp;
1082 const ::google::protobuf::RepeatedField<double>* src = &tmp;
1083 if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
1084 src = &onnx_tensor.double_data();
// Note: doubles are narrowed into the float values list.
1086 for (
const auto i : *src) {
1087 c2_values->add_floats(i);
1090 }
else if (onnx_tensor.data_type() == TensorProto::INT64) {
1091 c2_op->set_type(
"GivenTensorInt64Fill");
1092 auto* ints = c2_values->mutable_ints();
1093 if (!TryConvertingTensorRawValues<::google::protobuf::int64>(
1094 onnx_tensor, ints)) {
1095 ints->CopyFrom(onnx_tensor.int64_data());
1097 }
else if (onnx_tensor.data_type() == TensorProto::UINT32) {
// uint32 values fit losslessly in int64.
1098 c2_op->set_type(
"GivenTensorInt64Fill");
1099 ::google::protobuf::RepeatedField<::google::protobuf::uint64> tmp;
1100 const ::google::protobuf::RepeatedField<::google::protobuf::uint64>* src =
1102 if (!TryConvertingTensorRawValues<::google::protobuf::uint64>(
1103 onnx_tensor, &tmp)) {
1104 src = &onnx_tensor.uint64_data();
1106 for (
const auto i : *src) {
1107 c2_values->add_ints(i);
// All small integer types (and bool) are stored in int32_data per the
// ONNX spec and land in the int values list here.
1111 onnx_tensor.data_type() == TensorProto::BOOL ||
1112 onnx_tensor.data_type() == TensorProto::UINT8 ||
1113 onnx_tensor.data_type() == TensorProto::INT8 ||
1114 onnx_tensor.data_type() == TensorProto::UINT16 ||
1115 onnx_tensor.data_type() == TensorProto::INT16 ||
1116 onnx_tensor.data_type() == TensorProto::INT32) {
1118 onnx_tensor.data_type() == TensorProto::BOOL ?
"GivenTensorBoolFill" 1119 :
"GivenTensorIntFill");
1120 ::google::protobuf::RepeatedField<::google::protobuf::int32> tmp;
1121 const ::google::protobuf::RepeatedField<::google::protobuf::int32>* src =
1123 if (!TryConvertingTensorRawValues<::google::protobuf::int32>(
1124 onnx_tensor, &tmp)) {
1125 src = &onnx_tensor.int32_data();
1127 for (
const auto i : *src) {
1128 c2_values->add_ints(i);
1131 }
else if (onnx_tensor.data_type() == TensorProto::STRING) {
1132 c2_op->set_type(
"GivenTensorStringFill");
1133 auto* strings = c2_values->mutable_strings();
1134 strings->CopyFrom(onnx_tensor.string_data());
1137 "unrecognized tensor type: ",
1138 TensorProto::DataType_Name(onnx_tensor.data_type()));
// Finally attach the shape argument and the output blob.
1141 auto* c2_shape = c2_op->add_arg();
1142 c2_shape->set_name(
"shape");
1143 for (
const auto d : onnx_tensor.dims()) {
1144 c2_shape->add_ints(d);
1146 c2_op->add_output(fill_name);
1149 bool Caffe2Backend::SupportOp(
const std::string type)
const {
1150 return get_special_operators().count(type);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...