Skip to content

Commit afa19b0

Browse files
committed
combine op_dialect_version_map_, import_handler_map_ into onnx_ops_map_
Signed-off-by: Soren Lassen <[email protected]>
1 parent 571b72b commit afa19b0

File tree

1 file changed

+72
-64
lines changed

1 file changed

+72
-64
lines changed

Diff for: src/Builder/FrontendDialectTransformer.cpp

+72-64
Original file line numberDiff line numberDiff line change
@@ -102,18 +102,25 @@ class FrontendGenImpl {
102102
ModuleOp module_;
103103
OpBuilder builder_;
104104

105-
// onnxop: list of versions for dialect
106-
std::unordered_map<std::string, std::vector<int>> op_dialect_version_map_;
105+
using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);
106+
107+
struct VersionedHandler {
108+
int version;
109+
ImportHandlerType handler;
110+
};
111+
112+
using ONNXOpVersions = SmallVector<VersionedHandler, 1>;
113+
114+
// Maps NodeProto::op_type() to sorted vector of (version, handler) pairs.
115+
// TODO: Key by (domain, op_type) pair so we don't rely on names being unique
116+
// across all domains.
117+
std::unordered_map<std::string, ONNXOpVersions> onnx_ops_map_;
107118

108119
// mapping between string name and symbol
109120
ValueSymbolMapping frontend_symbols_;
110121

111122
ModelInputShaper modelInputShaper_;
112123

113-
using ImportHandlerType = void (FrontendGenImpl::*)(const onnx::NodeProto &);
114-
115-
std::unordered_map<std::string, ImportHandlerType> import_handler_map_;
116-
117124
// The total number of elements in all initializers. This value is a rough
118125
// counter of the number of parameters in a model.
119126
int64_t num_of_parameters_ = 0;
@@ -682,45 +689,6 @@ class FrontendGenImpl {
682689
node.op_type(), version, node.domain());
683690
}
684691

685-
std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
686-
int64_t version = GetDomainVersion(node.domain());
687-
if (version == 0)
688-
return "";
689-
690-
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
691-
<< node.op_type() << " (" << node.name() << ")"
692-
<< ", Opset: " << version << "\n");
693-
694-
auto opset_list_it = op_dialect_version_map_.find(node.op_type());
695-
696-
// Custom ops may not be present in op_dialect_version_map_. If no version
697-
// info is found, treat as unversioned (no renaming).
698-
if (opset_list_it == op_dialect_version_map_.end())
699-
return "";
700-
701-
auto opset_list = opset_list_it->second;
702-
703-
// A new opset is added to onnx-mlir when it becomes imcompactible.
704-
// But the lowest opset in op_dialect_version_map_ is an exception.
705-
// It is the current opset when onnx-mlir project is started.
706-
// All opset lower than the last opset should use the last opset(version)
707-
if (node.domain().compare("ai.onnx.ml") != 0 &&
708-
version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET)
709-
llvm::outs() << "Warning: ONNX " << node.op_type()
710-
<< " in your model is using Opset " << version
711-
<< ", which is quite old. Please consider regenerating your "
712-
"model with a newer Opset.\n";
713-
714-
for (int i = opset_list.size() - 1; i > 0; i--) {
715-
if (version < opset_list[i - 1]) {
716-
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
717-
<< opset_list[i] << "\n");
718-
return "V" + std::to_string(opset_list[i]);
719-
}
720-
}
721-
return "";
722-
}
723-
724692
func::FuncOp CreateFuncOp(
725693
std::string namePrefix, TypeRange operandTypes, TypeRange resultTypes) {
726694
auto funcType = builder_.getFunctionType(operandTypes, resultTypes);
@@ -912,16 +880,58 @@ class FrontendGenImpl {
912880
}
913881
}
914882

883+
bool TryImportONNXNode(const onnx::NodeProto &node) {
884+
int64_t version = GetDomainVersion(node.domain());
885+
if (version == 0) {
886+
// Unknown domain.
887+
return false;
888+
}
889+
890+
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
891+
<< node.op_type() << " (" << node.name() << ")"
892+
<< ", Opset: " << version << "\n");
893+
894+
auto versions_it = onnx_ops_map_.find(node.op_type());
895+
if (versions_it == onnx_ops_map_.end()) {
896+
// Unknown op_type.
897+
llvm::outs() << "Warning: ONNX " << node.op_type() << " from domain '"
898+
<< node.domain() << ","
899+
<< " in your model is unsupported.\n";
900+
return false;
901+
}
902+
903+
const ONNXOpVersions &opVersions = versions_it->second;
904+
905+
// A new opset is added to onnx-mlir when it becomes imcompatible.
906+
// But the lowest opset in op_dialect_version_map_ is an exception.
907+
// It is the current opset when onnx-mlir project is started.
908+
// All opset lower than the last opset should use the last opset(version)
909+
if (node.domain().compare("ai.onnx.ml") != 0 &&
910+
version < opVersions.back().version &&
911+
version < MINIMUM_SUPPORTED_OPSET)
912+
llvm::outs() << "Warning: ONNX " << node.op_type()
913+
<< " in your model is using Opset " << version
914+
<< ", which is quite old. Please consider regenerating your "
915+
"model with a newer Opset.\n";
916+
917+
ImportHandlerType handler = opVersions.front().handler;
918+
for (int i = opVersions.size() - 1; i > 0; --i) {
919+
if (version < opVersions[i - 1].version) {
920+
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
921+
<< opVersions[i].version << "\n");
922+
handler = opVersions[i].handler;
923+
}
924+
}
925+
(this->*handler)(node);
926+
return true;
927+
}
928+
915929
void ImportNode(const onnx::NodeProto &node) {
916-
std::string versionStr = GetImportVersionOfNode(node);
917-
918-
// look up handler for the opName. If not found, create a node
919-
// for a custom op, and issue a warning.
920-
std::string versionedName = node.op_type() + versionStr;
921-
auto handler = import_handler_map_.find(versionedName);
922-
if (handler != import_handler_map_.end()) {
923-
(this->*(handler->second))(node);
924-
} else {
930+
bool imported = TryImportONNXNode(node);
931+
if (!imported) {
932+
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing Custom op "
933+
<< node.op_type() << " (" << node.name() << ")"
934+
<< ", domain: '" << node.domain() << "'\n");
925935
ImportCustomNode(node);
926936
}
927937
}
@@ -932,18 +942,16 @@ class FrontendGenImpl {
932942
if constexpr (std::is_base_of_v<ONNXOperationTrait<T>, T>) {
933943
StringRef name = T::getONNXName();
934944
int version = T::getONNXSinceVersion();
935-
op_dialect_version_map_[name.str()].push_back(version);
936-
937-
StringRef versionedName = T::getOperationName();
938-
bool hadOnnxPrefix = versionedName.consume_front("onnx.");
939-
assert(hadOnnxPrefix);
940-
import_handler_map_[versionedName.str()] =
941-
&FrontendGenImpl::buildOperation<T>;
945+
ImportHandlerType handler = &FrontendGenImpl::buildOperation<T>;
946+
ONNXOpVersions &opVersions = onnx_ops_map_[name.str()];
947+
// Insert in descending version order:
948+
auto it = opVersions.begin();
949+
while (it != opVersions.end() && it->version > version) {
950+
++it; // Skip past larger versions.
951+
}
952+
opVersions.insert(it, {version, handler});
942953
}
943954
});
944-
for (auto &[name, versions] : op_dialect_version_map_) {
945-
std::sort(versions.begin(), versions.end(), std::greater<int>());
946-
}
947955
}
948956

949957
/*!

0 commit comments

Comments
 (0)