@@ -102,18 +102,25 @@ class FrontendGenImpl {
102
102
ModuleOp module_;
103
103
OpBuilder builder_;
104
104
105
- // onnxop: list of versions for dialect
106
- std::unordered_map<std::string, std::vector<int >> op_dialect_version_map_;
105
+ using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);
106
+
107
+ struct VersionedHandler {
108
+ int version;
109
+ ImportHandlerType handler;
110
+ };
111
+
112
+ using ONNXOpVersions = SmallVector<VersionedHandler, 1 >;
113
+
114
+ // Maps NodeProto::op_type() to sorted vector of (version, handler) pairs.
115
+ // TODO: Key by (domain, op_type) pair so we don't rely on names being unique
116
+ // across all domains.
117
+ std::unordered_map<std::string, ONNXOpVersions> onnx_ops_map_;
107
118
108
119
// mapping between string name and symbol
109
120
ValueSymbolMapping frontend_symbols_;
110
121
111
122
ModelInputShaper modelInputShaper_;
112
123
113
- using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);
114
-
115
- std::unordered_map<std::string, ImportHandlerType> import_handler_map_;
116
-
117
124
// The total number of elements in all initializers. This value is a rough
118
125
// counter of the number of parameters in a model.
119
126
int64_t num_of_parameters_ = 0 ;
@@ -682,45 +689,6 @@ class FrontendGenImpl {
682
689
node.op_type (), version, node.domain ());
683
690
}
684
691
685
- std::string GetImportVersionOfNode (const onnx::NodeProto &node) {
686
- int64_t version = GetDomainVersion (node.domain ());
687
- if (version == 0 )
688
- return " " ;
689
-
690
- LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing ONNX"
691
- << node.op_type () << " (" << node.name () << " )"
692
- << " , Opset: " << version << " \n " );
693
-
694
- auto opset_list_it = op_dialect_version_map_.find (node.op_type ());
695
-
696
- // Custom ops may not be present in op_dialect_version_map_. If no version
697
- // info is found, treat as unversioned (no renaming).
698
- if (opset_list_it == op_dialect_version_map_.end ())
699
- return " " ;
700
-
701
- auto opset_list = opset_list_it->second ;
702
-
703
- // A new opset is added to onnx-mlir when it becomes imcompactible.
704
- // But the lowest opset in op_dialect_version_map_ is an exception.
705
- // It is the current opset when onnx-mlir project is started.
706
- // All opset lower than the last opset should use the last opset(version)
707
- if (node.domain ().compare (" ai.onnx.ml" ) != 0 &&
708
- version < opset_list.back () && version < MINIMUM_SUPPORTED_OPSET)
709
- llvm::outs () << " Warning: ONNX " << node.op_type ()
710
- << " in your model is using Opset " << version
711
- << " , which is quite old. Please consider regenerating your "
712
- " model with a newer Opset.\n " ;
713
-
714
- for (int i = opset_list.size () - 1 ; i > 0 ; i--) {
715
- if (version < opset_list[i - 1 ]) {
716
- LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : - use Opset "
717
- << opset_list[i] << " \n " );
718
- return " V" + std::to_string (opset_list[i]);
719
- }
720
- }
721
- return " " ;
722
- }
723
-
724
692
func::FuncOp CreateFuncOp (
725
693
std::string namePrefix, TypeRange operandTypes, TypeRange resultTypes) {
726
694
auto funcType = builder_.getFunctionType (operandTypes, resultTypes);
@@ -912,16 +880,58 @@ class FrontendGenImpl {
912
880
}
913
881
}
914
882
883
+ bool TryImportONNXNode (const onnx::NodeProto &node) {
884
+ int64_t version = GetDomainVersion (node.domain ());
885
+ if (version == 0 ) {
886
+ // Unknown domain.
887
+ return false ;
888
+ }
889
+
890
+ LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing ONNX"
891
+ << node.op_type () << " (" << node.name () << " )"
892
+ << " , Opset: " << version << " \n " );
893
+
894
+ auto versions_it = onnx_ops_map_.find (node.op_type ());
895
+ if (versions_it == onnx_ops_map_.end ()) {
896
+ // Unknown op_type.
897
+ llvm::outs () << " Warning: ONNX " << node.op_type () << " from domain '"
898
+ << node.domain () << " ,"
899
+ << " in your model is unsupported.\n " ;
900
+ return false ;
901
+ }
902
+
903
+ const ONNXOpVersions &opVersions = versions_it->second ;
904
+
905
+ // A new opset is added to onnx-mlir when it becomes imcompatible.
906
+ // But the lowest opset in op_dialect_version_map_ is an exception.
907
+ // It is the current opset when onnx-mlir project is started.
908
+ // All opset lower than the last opset should use the last opset(version)
909
+ if (node.domain ().compare (" ai.onnx.ml" ) != 0 &&
910
+ version < opVersions.back ().version &&
911
+ version < MINIMUM_SUPPORTED_OPSET)
912
+ llvm::outs () << " Warning: ONNX " << node.op_type ()
913
+ << " in your model is using Opset " << version
914
+ << " , which is quite old. Please consider regenerating your "
915
+ " model with a newer Opset.\n " ;
916
+
917
+ ImportHandlerType handler = opVersions.front ().handler ;
918
+ for (int i = opVersions.size () - 1 ; i > 0 ; --i) {
919
+ if (version < opVersions[i - 1 ].version ) {
920
+ LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : - use Opset "
921
+ << opVersions[i].version << " \n " );
922
+ handler = opVersions[i].handler ;
923
+ }
924
+ }
925
+ (this ->*handler)(node);
926
+ return true ;
927
+ }
928
+
915
929
void ImportNode (const onnx::NodeProto &node) {
916
- std::string versionStr = GetImportVersionOfNode (node);
917
-
918
- // look up handler for the opName. If not found, create a node
919
- // for a custom op, and issue a warning.
920
- std::string versionedName = node.op_type () + versionStr;
921
- auto handler = import_handler_map_.find (versionedName);
922
- if (handler != import_handler_map_.end ()) {
923
- (this ->*(handler->second ))(node);
924
- } else {
930
+ bool imported = TryImportONNXNode (node);
931
+ if (!imported) {
932
+ LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " : Importing Custom op "
933
+ << node.op_type () << " (" << node.name () << " )"
934
+ << " , domain: '" << node.domain () << " '\n " );
925
935
ImportCustomNode (node);
926
936
}
927
937
}
@@ -932,18 +942,16 @@ class FrontendGenImpl {
932
942
if constexpr (std::is_base_of_v<ONNXOperationTrait<T>, T>) {
933
943
StringRef name = T::getONNXName ();
934
944
int version = T::getONNXSinceVersion ();
935
- op_dialect_version_map_[name.str ()].push_back (version);
936
-
937
- StringRef versionedName = T::getOperationName ();
938
- bool hadOnnxPrefix = versionedName.consume_front (" onnx." );
939
- assert (hadOnnxPrefix);
940
- import_handler_map_[versionedName.str ()] =
941
- &FrontendGenImpl::buildOperation<T>;
945
+ ImportHandlerType handler = &FrontendGenImpl::buildOperation<T>;
946
+ ONNXOpVersions &opVersions = onnx_ops_map_[name.str ()];
947
+ // Insert in descending version order:
948
+ auto it = opVersions.begin ();
949
+ while (it != opVersions.end () && it->version > version) {
950
+ ++it; // Skip past larger versions.
951
+ }
952
+ opVersions.insert (it, {version, handler});
942
953
}
943
954
});
944
- for (auto &[name, versions] : op_dialect_version_map_) {
945
- std::sort (versions.begin (), versions.end (), std::greater<int >());
946
- }
947
955
}
948
956
949
957
/* !
0 commit comments