Skip to content

Commit 13f625a

Browse files
committed
more fixes
Signed-off-by: Soren Lassen <[email protected]>
1 parent 6fb1231 commit 13f625a

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

Diff for: src/Builder/FrontendDialectTransformer.cpp

+19-14
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ class FrontendGenImpl {
153153
}
154154
}
155155

156+
int64_t GetDomainVersion(const std::string &domain) {
157+
auto it = opset_map_.find(domain);
158+
if (it == opset_map_.end())
159+
return 0;
160+
return it->second;
161+
}
162+
156163
void BindOnnxName(const std::string &onnx_name, Value symbol) {
157164
frontend_symbols_.AddMapping(onnx_name, symbol);
158165
}
@@ -694,20 +701,21 @@ class FrontendGenImpl {
694701
}
695702

696703
const onnx::OpSchema *GetOpSchema(const onnx::NodeProto &node) {
697-
auto &domain = node.domain();
698-
auto version_it = opset_map_.find(domain);
699-
if (version_it == opset_map_.end())
704+
int64_t version = GetDomainVersion(node.domain());
705+
if (version == 0)
700706
return nullptr;
701-
auto version = version_it->second;
702-
return onnx::OpSchemaRegistry::Schema(node.op_type(), version, domain);
707+
return onnx::OpSchemaRegistry::Schema(
708+
node.op_type(), version, node.domain());
703709
}
704710

705711
std::string GetImportVersionOfNode(const onnx::NodeProto &node) {
706-
auto current_opset = opset_map_.find(node.domain())->second;
712+
int64_t version = GetDomainVersion(node.domain());
713+
if (version == 0)
714+
return "";
707715

708716
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": Importing ONNX"
709717
<< node.op_type() << " (" << node.name() << ")"
710-
<< ", Opset: " << current_opset << "\n");
718+
<< ", Opset: " << version << "\n");
711719

712720
auto opset_list_it = op_dialect_version_map_.find(node.op_type());
713721

@@ -723,23 +731,20 @@ class FrontendGenImpl {
723731
// It is the current opset when onnx-mlir project is started.
724732
// All opset lower than the last opset should use the last opset(version)
725733
if (node.domain().compare("ai.onnx.ml") != 0 &&
726-
current_opset < opset_list.back() &&
727-
current_opset < MINIMUM_SUPPORTED_OPSET)
734+
version < opset_list.back() && version < MINIMUM_SUPPORTED_OPSET)
728735
llvm::outs() << "Warning: ONNX " << node.op_type()
729-
<< " in your model is using Opset " << current_opset
736+
<< " in your model is using Opset " << version
730737
<< ", which is quite old. Please consider regenerating your "
731738
"model with a newer Opset.\n";
732739

733740
for (int i = opset_list.size() - 1; i > 0; i--) {
734-
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - testing Opset "
735-
<< opset_list[i - 1] << "\n");
736-
if (current_opset < opset_list[i - 1]) {
741+
if (version < opset_list[i - 1]) {
737742
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << ": - use Opset "
738743
<< opset_list[i] << "\n");
739744
return "V" + std::to_string(opset_list[i]);
740745
}
741746
}
742-
return std::string("");
747+
return "";
743748
}
744749

745750
func::FuncOp CreateFuncOp(

0 commit comments

Comments
 (0)