@@ -153,6 +153,13 @@ class FrontendGenImpl {
153
153
}
154
154
}
155
155
156
+ int64_t GetDomainVersion (const std::string &domain) {
157
+ auto it = opset_map_.find (domain);
158
+ if (it == opset_map_.end ())
159
+ return 0 ;
160
+ return it->second ;
161
+ }
162
+
156
163
void BindOnnxName (const std::string &onnx_name, Value symbol) {
157
164
frontend_symbols_.AddMapping (onnx_name, symbol);
158
165
}
@@ -694,20 +701,21 @@ class FrontendGenImpl {
694
701
}
695
702
696
703
const onnx::OpSchema *GetOpSchema (const onnx::NodeProto &node) {
697
- auto &domain = node.domain ();
698
- auto version_it = opset_map_.find (domain);
699
- if (version_it == opset_map_.end ())
704
+ int64_t version = GetDomainVersion (node.domain ());
705
+ if (version == 0 )
700
706
return nullptr ;
701
- auto version = version_it-> second ;
702
- return onnx::OpSchemaRegistry::Schema ( node.op_type (), version, domain);
707
+ return onnx::OpSchemaRegistry::Schema (
708
+ node.op_type (), version, node. domain () );
703
709
}
704
710
705
711
std::string GetImportVersionOfNode (const onnx::NodeProto &node) {
706
- auto current_opset = opset_map_.find (node.domain ())->second ;
712
+ int64_t version = GetDomainVersion (node.domain ());
713
+ if (version == 0 )
714
+ return " " ;
707
715
708
716
LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing ONNX"
709
717
<< node.op_type () << " (" << node.name () << " )"
710
- << " , Opset: " << current_opset << " \n " );
718
+ << " , Opset: " << version << " \n " );
711
719
712
720
auto opset_list_it = op_dialect_version_map_.find (node.op_type ());
713
721
@@ -723,23 +731,20 @@ class FrontendGenImpl {
723
731
// It is the current opset when onnx-mlir project is started.
724
732
// All opset lower than the last opset should use the last opset(version)
725
733
if (node.domain ().compare (" ai.onnx.ml" ) != 0 &&
726
- current_opset < opset_list.back () &&
727
- current_opset < MINIMUM_SUPPORTED_OPSET)
734
+ version < opset_list.back () && version < MINIMUM_SUPPORTED_OPSET)
728
735
llvm::outs () << " Warning: ONNX " << node.op_type ()
729
- << " in your model is using Opset " << current_opset
736
+ << " in your model is using Opset " << version
730
737
<< " , which is quite old. Please consider regenerating your "
731
738
" model with a newer Opset.\n " ;
732
739
733
740
for (int i = opset_list.size () - 1 ; i > 0 ; i--) {
734
- LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : - testing Opset "
735
- << opset_list[i - 1 ] << " \n " );
736
- if (current_opset < opset_list[i - 1 ]) {
741
+ if (version < opset_list[i - 1 ]) {
737
742
LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : - use Opset "
738
743
<< opset_list[i] << " \n " );
739
744
return " V" + std::to_string (opset_list[i]);
740
745
}
741
746
}
742
- return std::string ( " " ) ;
747
+ return " " ;
743
748
}
744
749
745
750
func::FuncOp CreateFuncOp (
0 commit comments