From 194a6e8f4d3ae361f5609e5167c0d58db8d1326b Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 16 Oct 2025 09:39:02 -0400 Subject: [PATCH 01/20] [FIRRTL] Add support for domain-connect driving instance-choice ports --- lib/Dialect/FIRRTL/FIRRTLOps.cpp | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index cca9cace386d..f9e1b2696a15 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -4278,6 +4278,14 @@ LogicalResult PropAssignOp::verify() { return success(); } +template +static FlatSymbolRefAttr getDomainTypeNameOfResult(T op, size_t i) { + auto info = op.getDomainInfo(); + if (info.empty()) + return {}; + return dyn_cast(info[i]); +} + static FlatSymbolRefAttr getDomainTypeName(Value value) { if (!isa(value.getType())) return {}; @@ -4297,13 +4305,10 @@ static FlatSymbolRefAttr getDomainTypeName(Value value) { if (auto result = dyn_cast(value)) { auto *op = result.getDefiningOp(); - if (auto instance = dyn_cast(op)) { - auto info = instance.getDomainInfo(); - if (info.empty()) - return {}; - auto attr = info[result.getResultNumber()]; - return dyn_cast(attr); - } + if (auto instance = dyn_cast(op)) + return getDomainTypeNameOfResult(instance, result.getResultNumber()); + if (auto instance = dyn_cast(op)) + return getDomainTypeNameOfResult(instance, result.getResultNumber()); return {}; } From 42dcdd9532786d9ed4bf827e2caab58201e36540 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 17 Oct 2025 12:41:32 -0400 Subject: [PATCH 02/20] [FIRRTL] getPortNameStr to getPortName, getPortName to getPortNameAttr --- .../Dialect/FIRRTL/FIRRTLDeclarations.td | 21 ++++++++++++------ lib/Dialect/FIRRTL/Export/FIREmitter.cpp | 4 ++-- lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp | 2 +- lib/Dialect/FIRRTL/FIRRTLOps.cpp | 22 +++++++++---------- lib/Dialect/FIRRTL/FIRRTLReductions.cpp | 6 ++--- lib/Dialect/FIRRTL/FIRRTLUtils.cpp | 3 +-- lib/Dialect/FIRRTL/Transforms/Dedup.cpp | 2 +- .../FIRRTL/Transforms/ExtractInstances.cpp | 4 ++-- .../FIRRTL/Transforms/FlattenMemory.cpp | 2 +- .../FIRRTL/Transforms/InferReadWrite.cpp | 2 +- lib/Dialect/FIRRTL/Transforms/InferResets.cpp | 2 +- .../FIRRTL/Transforms/LowerClasses.cpp | 2 +- .../FIRRTL/Transforms/LowerIntmodules.cpp | 2 +- .../FIRRTL/Transforms/LowerOpenAggs.cpp | 2 +- lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp | 6 ++--- lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp | 2 +- .../FIRRTL/Transforms/MemToRegOfVec.cpp | 2 +- 17 files changed, 46 insertions(+), 40 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td index b1f0ece66118..be2cb4652476 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td @@ -125,11 +125,12 @@ def InstanceOp : HardwareDeclOp<"instance", [ } /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo) { + StringAttr getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } - StringRef getPortNameStr(size_t resultNo) { - return getPortName(resultNo).getValue(); + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); } /// Hooks for port annotations. @@ -214,9 +215,14 @@ def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [ } /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo) { + StringAttr getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); + } + /// Return the default target attribute. FlatSymbolRefAttr getDefaultTargetAttr() { return llvm::cast(getModuleNamesAttr()[0]); @@ -313,9 +319,10 @@ def MemOp : HardwareDeclOp<"mem", [DeclareOpInterfaceMethods]> { size_t getMaskBits(); /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo); - StringRef getPortNameStr(size_t resultNo) { - return getPortName(resultNo).getValue(); + StringAttr getPortNameAttr(size_t resultNo); + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); } /// Return the port type for the specified result number. diff --git a/lib/Dialect/FIRRTL/Export/FIREmitter.cpp b/lib/Dialect/FIRRTL/Export/FIREmitter.cpp index e9fb68ad408a..324f00e5eb2a 100644 --- a/lib/Dialect/FIRRTL/Export/FIREmitter.cpp +++ b/lib/Dialect/FIRRTL/Export/FIREmitter.cpp @@ -1126,7 +1126,7 @@ void Emitter::emitStatement(InstanceOp op) { portName.push_back('.'); unsigned baseLen = portName.size(); for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - portName.append(legalize(op.getPortName(i))); + portName.append(legalize(op.getPortNameAttr(i))); addValueName(op.getResult(i), portName); portName.resize(baseLen); } @@ -1153,7 +1153,7 @@ void Emitter::emitStatement(InstanceChoiceOp op) { portName.push_back('.'); unsigned baseLen = portName.size(); for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - portName.append(legalize(op.getPortName(i))); + portName.append(legalize(op.getPortNameAttr(i))); addValueName(op.getResult(i), portName); portName.resize(baseLen); } diff --git a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp index d31008b383ca..26eb0506c607 100644 --- a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp @@ -29,7 +29,7 @@ using llvm::StringRef; static LogicalResult updateExpandedPort(StringRef field, AnnoTarget &ref) { if (auto mem = dyn_cast(ref.getOp())) for (size_t p = 0, pe = mem.getPortNames().size(); p < pe; ++p) - if (mem.getPortNameStr(p) == field) { + if (mem.getPortName(p) == field) { ref = PortAnnoTarget(mem, p); return success(); } diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index f9e1b2696a15..d772e391e876 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -2736,7 +2736,7 @@ InstanceOp InstanceOp::cloneWithInsertedPorts( } newPortDirections.push_back(getPortDirection(i)); - newPortNames.push_back(getPortName(i)); + newPortNames.push_back(getPortNameAttr(i)); newPortTypes.push_back(getType(i)); newPortAnnos.push_back(getPortAnnotation(i)); newDomainInfo.push_back(getDomainInfo()[i]); @@ -2887,7 +2887,7 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { base = "inst"; for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) { - setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str()); + setNameFn(getResult(i), (base + "_" + getPortName(i)).str()); } } @@ -3219,7 +3219,7 @@ InstanceChoiceOp InstanceChoiceOp::cloneWithInsertedPorts( } newPortDirections.push_back(getPortDirection(i)); - newPortNames.push_back(getPortName(i)); + newPortNames.push_back(getPortNameAttr(i)); newPortTypes.push_back(getType(i)); newPortAnnos.push_back(getPortAnnotations()[i]); newDomainInfo.push_back(getDomainInfo()[i]); @@ -3343,7 +3343,7 @@ LogicalResult MemOp::verify() { FIRRTLType oldDataType; for (size_t i = 0, e = getNumResults(); i != e; ++i) { - auto portName = getPortName(i); + auto portName = getPortNameAttr(i); // Get a bundle type representing this port, stripping an outer // flip if it exists. If this is not a bundle<> or @@ -3453,10 +3453,10 @@ LogicalResult MemOp::verify() { // Error if the type of the current port was not the same as the // last port, but skip checking the first port. if (oldDataType && oldDataType != dataType) { - emitOpError() << "port " << getPortName(i) - << " has a different type than port " << getPortName(i - 1) - << " (expected " << oldDataType << ", but got " << dataType - << ")"; + emitOpError() << "port " << getPortNameAttr(i) + << " has a different type than port " + << getPortNameAttr(i - 1) << " (expected " << oldDataType + << ", but got " << dataType << ")"; return failure(); } @@ -3539,7 +3539,7 @@ SmallVector MemOp::getPorts() { for (size_t i = 0, e = getNumResults(); i != e; ++i) { // Each port is a bundle. auto portType = type_cast(getResult(i).getType()); - result.push_back({getPortName(i), getMemPortKindFromType(portType)}); + result.push_back({getPortNameAttr(i), getMemPortKindFromType(portType)}); } return result; } @@ -3596,7 +3596,7 @@ FIRRTLBaseType MemOp::getDataType() { .getElementType(dataFieldName); } -StringAttr MemOp::getPortName(size_t resultNo) { +StringAttr MemOp::getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } @@ -3717,7 +3717,7 @@ void MemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { base = "mem"; for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) { - setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str()); + setNameFn(getResult(i), (base + "_" + getPortName(i)).str()); } } diff --git a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp index 8e9dcae73a3d..9aad07765317 100644 --- a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp @@ -414,7 +414,7 @@ struct InstanceStubber : public OpReduction { for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) { auto result = instOp.getResult(i); auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" + - instOp.getPortNameStr(i)); + instOp.getPortName(i)); auto wire = firrtl::WireOp::create(builder, result.getType(), name, firrtl::NameKindEnum::DroppableName, @@ -463,7 +463,7 @@ struct MemoryStubber : public OpReduction { for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) { auto result = memOp.getResult(i); auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" + - memOp.getPortNameStr(i)); + memOp.getPortName(i)); auto wire = firrtl::WireOp::create(builder, result.getType(), name, firrtl::NameKindEnum::DroppableName, @@ -1190,7 +1190,7 @@ struct EagerInliner : public OpReduction { for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) { auto result = instOp.getResult(i); auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" + - instOp.getPortNameStr(i)); + instOp.getPortName(i)); auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(), name, NameKindEnum::DroppableName, instOp.getPortAnnotation(i), StringAttr{}) diff --git a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp index 289e1a526cab..0614e6efb98b 100644 --- a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp @@ -571,8 +571,7 @@ static void getDeclName(Value value, SmallString<64> &string, bool nameSafe) { .Case([&](auto op) { string += op.getName(); string += nameSafe ? "_" : "."; - string += op.getPortName(cast(value).getResultNumber()) - .getValue(); + string += op.getPortName(cast(value).getResultNumber()); value = nullptr; }) .Case([&](auto op) { diff --git a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp index 81bfa3c64d4c..d57ffd3a4ef5 100644 --- a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp +++ b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp @@ -1596,7 +1596,7 @@ fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, continue; LLVM_DEBUG(llvm::dbgs() << "- Updating instance port \"" << instOp.getInstanceName() - << "." << instOp.getPortName(index).getValue() << "\" from " + << "." << instOp.getPortName(index) << "\" from " << oldType << " to " << newType << "\n"); // If the type changed we transform it back to the old type with an diff --git a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp index 72b9421a424d..5d2a166649b7 100644 --- a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp +++ b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp @@ -563,7 +563,7 @@ void ExtractInstancesPass::extractInstances() { for (unsigned portIdx = 0; portIdx < numInstPorts; ++portIdx) { // Assemble the new port name as "_", where the prefix is // provided by the extraction annotation. - auto name = inst.getPortNameStr(portIdx); + auto name = inst.getPortName(portIdx); auto nameAttr = StringAttr::get( &getContext(), prefix.empty() ? Twine(name) : Twine(prefix) + "_" + name); @@ -913,7 +913,7 @@ void ExtractInstancesPass::groupInstances() { StringRef prefix(instPrefixNamesPair[inst].first); unsigned portNum = inst.getNumResults(); for (unsigned portIdx = 0; portIdx < portNum; ++portIdx) { - auto name = inst.getPortNameStr(portIdx); + auto name = inst.getPortName(portIdx); auto nameAttr = builder.getStringAttr( prefix.empty() ? Twine(name) : Twine(prefix) + "_" + name); PortInfo port{nameAttr, diff --git a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp index 2da67137e816..933c0c2e39db 100644 --- a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp +++ b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp @@ -138,7 +138,7 @@ struct FlattenMemoryPass auto result = memOp.getResult(index); auto wire = WireOp::create(builder, result.getType(), (memOp.getName() + "_" + - memOp.getPortName(index).getValue()) + memOp.getPortName(index)) .str()) .getResult(); result.replaceAllUsesWith(wire); diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 594f65c8591d..1dc2096b189f 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -89,7 +89,7 @@ struct InferReadWritePass Attribute portAnno; portAnno = memOp.getPortAnnotation(portIt.index()); if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) { - resultNames.push_back(memOp.getPortName(portIt.index())); + resultNames.push_back(memOp.getPortNameAttr(portIt.index())); resultTypes.push_back(memOp.getResult(portIt.index()).getType()); portAnnotations.push_back(portAnno); continue; diff --git a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp index dfef3719f3b3..dceb7c8a123e 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp @@ -705,7 +705,7 @@ static bool getDeclName(Value value, SmallString<32> &string) { string += op.getName(); string += "."; string += - op.getPortName(cast(value).getResultNumber()).getValue(); + op.getPortName(cast(value).getResultNumber()); return true; }) .Case([&](auto op) { diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index 5722d989154a..af78bf5c804d 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -1439,7 +1439,7 @@ updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, // The path to the field is just this output's name. auto objectFieldPath = builder.getArrayAttr({FlatSymbolRefAttr::get( - firrtlInstance.getPortName(result.getResultNumber()))}); + firrtlInstance.getPortNameAttr(result.getResultNumber()))}); // Create the field access. auto objectField = ObjectFieldOp::create( diff --git a/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp b/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp index eb4442a73886..df64d628d161 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp @@ -108,7 +108,7 @@ void LowerIntmodulesPass::runOnOperation() { } outputs.push_back( OutputInfo{inst.getResult(idx), - BundleType::BundleElement(inst.getPortName(idx), + BundleType::BundleElement(inst.getPortNameAttr(idx), /*isFlip=*/false, ftype)}); } diff --git a/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp b/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp index bffbc3057569..375414ca866c 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp @@ -547,7 +547,7 @@ LogicalResult Visitor::visitDecl(InstanceOp op) { // If not identity, mark this port for eventual removal. portsToErase.set(newIndex); - auto portName = op.getPortName(index); + auto portName = op.getPortNameAttr(index); auto portDirection = op.getPortDirection(index); auto portDomain = op.getPortDomain(index); auto loc = op.getLoc(); diff --git a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp index ba54761c42f7..b62e03209389 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp @@ -997,7 +997,7 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) { } auto wire = builder->create( result.getType(), - (op.getName() + "_" + op.getPortName(index).getValue()).str()); + (op.getName() + "_" + op.getPortName(index)).str()); oldPorts.push_back(wire); result.replaceAllUsesWith(wire.getResult()); } @@ -1498,13 +1498,13 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector fieldTypes; if (!peelType(srcType, fieldTypes, mode)) { newDirs.push_back(op.getPortDirection(i)); - newNames.push_back(op.getPortName(i)); + newNames.push_back(op.getPortNameAttr(i)); newDomains.push_back(builder->getArrayAttr({})); resultTypes.push_back(srcType); newPortAnno.push_back(oldPortAnno[i]); } else { skip = false; - auto oldName = op.getPortNameStr(i); + auto oldName = op.getPortName(i); auto oldDir = op.getPortDirection(i); // Store the flat type for the new bundle type. for (const auto &field : fieldTypes) { diff --git a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp index d7eb00a47272..596a6c87b799 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp @@ -806,7 +806,7 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase { for (const auto &res : llvm::enumerate(mem.getResults())) { if (isa(mem.getResult(res.index()).getType())) continue; - resultNames.push_back(mem.getPortName(res.index())); + resultNames.push_back(mem.getPortNameAttr(res.index())); resultTypes.push_back(res.value().getType()); portAnnotations.push_back(mem.getPortAnnotation(res.index())); oldResults.push_back(res.value()); diff --git a/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp b/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp index ef2c0dcecf09..b671ebd31af6 100644 --- a/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp +++ b/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp @@ -386,7 +386,7 @@ struct MemToRegOfVecPass // simpler to delete the memOp. auto wire = WireOp::create( builder, result.getType(), - (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(), + (memOp.getName() + "_" + memOp.getPortName(index)).str(), memOp.getNameKind()); result.replaceAllUsesWith(wire.getResult()); result = wire.getResult(); From 3f9abdf1c65d9a4649314ff97cb176fd2f69a444 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 16 Oct 2025 11:28:00 -0400 Subject: [PATCH 03/20] [FIRRTL] Add missing port helpers to instance ops --- .../circt/Dialect/FIRRTL/FIRRTLDeclarations.td | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td index be2cb4652476..42705051d7bf 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td @@ -119,6 +119,11 @@ def InstanceOp : HardwareDeclOp<"instance", [ CArg<"hw::InnerSymAttr", "hw::InnerSymAttr()">:$innerSym)>]; let extraClassDeclaration = [{ + /// Return the number of ports for this instance. + size_t getNumPorts() { + return getNumResults(); + } + /// Return the port direction for the specified result number. Direction getPortDirection(size_t resultNo) { return direction::get(getPortDirections()[resultNo]); @@ -133,6 +138,10 @@ def InstanceOp : HardwareDeclOp<"instance", [ return getPortNameAttr(resultNo).getValue(); } + Location getPortLocation(size_t) { + return getLoc(); + } + /// Hooks for port annotations. ArrayAttr getPortAnnotation(unsigned portIdx); void setAllPortAnnotations(ArrayRef annotations); @@ -209,6 +218,11 @@ def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [ let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ + /// Return the number of ports for this instance. + size_t getNumPorts() { + return getNumResults(); + } + /// Return the port direction for the specified result number. Direction getPortDirection(size_t resultNo) { return direction::get(getPortDirections()[resultNo]); @@ -223,6 +237,10 @@ def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [ return getPortNameAttr(resultNo).getValue(); } + Location getPortLocation(size_t resultNo) { + return getLoc(); + } + /// Return the default target attribute. FlatSymbolRefAttr getDefaultTargetAttr() { return llvm::cast(getModuleNamesAttr()[0]); From f68722405c1acd7d93932913f9480f36e9e8761f Mon Sep 17 00:00:00 2001 From: Schuyler Eldridge Date: Fri, 8 Aug 2025 16:25:29 -0400 Subject: [PATCH 04/20] [FIRRTL] Add InferDomains pass Add a pass that does domain inference and checking. This is used to verify the legality of a FIRRTL circuit with respect to its domains. E.g., this pass is intended to be used for checking for illegal clock domain crossings. Signed-off-by: Schuyler Eldridge --- include/circt/Dialect/FIRRTL/Passes.td | 13 + lib/Dialect/FIRRTL/Transforms/CMakeLists.txt | 1 + .../FIRRTL/Transforms/InferDomains.cpp | 425 ++++++++++++++++++ test/Dialect/FIRRTL/infer-domains-errors.mlir | 63 +++ test/Dialect/FIRRTL/infer-domains.mlir | 234 ++++++++++ 5 files changed, 736 insertions(+) create mode 100644 lib/Dialect/FIRRTL/Transforms/InferDomains.cpp create mode 100644 test/Dialect/FIRRTL/infer-domains-errors.mlir create mode 100644 test/Dialect/FIRRTL/infer-domains.mlir diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 15cf686d618c..caf8eeb7dffb 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -909,6 +909,19 @@ def CheckLayers : Pass<"firrtl-check-layers", "firrtl::CircuitOp"> { }]; } +def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { + let summary = "Infer and type check all firrtl domains"; + let description = [{ + This pass does domain inference on a FIRRTL circuit. The end result of this + is either a corrrctly domain-checked FIRRTL circuit or failure with verbose + error messages indicating why the FIRRTL circuit has illegal domain + constructs. + + E.g., this pass can be used to check for illegal clock-domain-crossings if + clock domains are specified for signals in the design. + }]; +} + def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { let summary = "lower domain information to properties"; let description = [{ diff --git a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt index f8a1ec9e4045..d24f11c7563b 100755 --- a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt +++ b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms GrandCentral.cpp IMConstProp.cpp IMDeadCodeElim.cpp + InferDomains.cpp InferReadWrite.cpp InferResets.cpp InferWidths.cpp diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp new file mode 100644 index 000000000000..75cede2f77ed --- /dev/null +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -0,0 +1,425 @@ +//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements FIRRTL domain inference and checking with canonical +// domain representation. Domain sequences are canonicalized by sorting and +// removing duplicates, making domain order irrelevant and allowing duplicate +// domains to be treated as equivalent. The result of this pass is either a +// correctly domain-inferred circuit or pass failure if the circuit contains +// illegal domain crossings. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/FIRRTL/FIRRTLOps.h" +#include "circt/Dialect/FIRRTL/Passes.h" +#include "circt/Support/Debug.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "firrtl-infer-domains" + +namespace circt { +namespace firrtl { +#define GEN_PASS_DEF_INFERDOMAINS +#include "circt/Dialect/FIRRTL/Passes.h.inc" +} // namespace firrtl +} // namespace circt + +using namespace circt; +using namespace firrtl; + +namespace { + +/// Union-Find data structure for domain variables using LLVM's +/// EquivalenceClasses. Handles sequences of domains with canonical +/// representation where domain order is irrelevant and duplicates are removed. +/// This allows 'domains %A, %B' to be equivalent to 'domains %B, %A' and +/// 'domains %A, %A' to be equivalent to 'domains %A'. +class DomainUnionFind { +public: + using DomainSequence = llvm::SmallVector; + +private: + /// Canonicalize a domain sequence by sorting and removing duplicates. + /// This ensures that domain order doesn't matter and duplicate domains + /// are treated as equivalent. For example: + /// - 'domains %A, %B' and 'domains %B, %A' become the same canonical form + /// - 'domains %A, %A' becomes 'domains %A' + static DomainSequence canonicalizeDomains(ArrayRef domains) { + DomainSequence canonical(domains.begin(), domains.end()); + + // Sort domains deterministically. Use block arg order if all are block + // args. Fallback to opaque pointer comparison otherwise. Note: the opaque + // pointer comparison is only stable within a run of MLIR and should not be + // relied upon for determinism outside of a run. + llvm::sort(canonical, [](Value a, Value b) { + auto aArg = dyn_cast(a); + auto bArg = dyn_cast(b); + if (aArg && bArg) + return aArg.getArgNumber() < bArg.getArgNumber(); + return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + }); + + // Remove duplicates to handle cases like 'domains %A, %A'. + canonical.erase(std::unique(canonical.begin(), canonical.end()), + canonical.end()); + + return canonical; + } + +public: + /// Get or create a domain variable for a value. + Value getDomainVar(Value value) { + auto it = valueToVar.find(value); + if (it != valueToVar.end()) + return it->second; + + // Create a new representative value for this domain variable. + Value representative = value; + valueToVar[value] = representative; + equivalenceClasses.insert(representative); + return representative; + } + + /// Union two domain variables. Returns success if successful, failure if + /// there's a domain conflict. With canonical representation, conflicts only + /// occur when connecting values with truly different domain sets (not just + /// different ordering). + [[nodiscard]] LogicalResult unifyDomains(Value var1, Value var2) { + Value rep1 = equivalenceClasses.getLeaderValue(var1); + Value rep2 = equivalenceClasses.getLeaderValue(var2); + + if (rep1 == rep2) + return success(); + + // Check for domain conflicts using canonical representation. + const DomainSequence &domains1 = getDomains(rep1); + const DomainSequence &domains2 = getDomains(rep2); + + // Since domains are already canonical, we can compare them directly. + // Conflicts only occur when connecting truly different domain sets. For + // example: 'domains %A' vs 'domains %B' is a conflict, but 'domains %A, %B' + // vs 'domains %B, %A' is not (same canonical form). + if (!domains1.empty() && !domains2.empty() && domains1 != domains2) { + // Conflict: both have different concrete domain sequences + return failure(); + } + + // Merge the concrete domain sequences (both are already canonical) + const DomainSequence &mergedDomains = + !domains1.empty() ? domains1 : domains2; + + // Union the equivalence classes + equivalenceClasses.unionSets(rep1, rep2); + + // Set the merged domain sequence on the new representative + if (!mergedDomains.empty()) { + Value newRep = equivalenceClasses.getLeaderValue(rep1); + concreteDomains[newRep] = mergedDomains; + } + + return success(); + } + + /// Set the concrete domain sequence for a variable. + /// The domain sequence is automatically canonicalized to ensure consistent + /// representation regardless of input order or duplicates. + void setDomains(Value var, ArrayRef domains) { + Value rep = equivalenceClasses.getLeaderValue(var); + concreteDomains[rep] = canonicalizeDomains(domains); + } + + /// Get the concrete domain sequence for a variable. + /// Returns the canonical domain sequence (sorted, no duplicates). + const DomainSequence &getDomains(Value var) { + Value rep = equivalenceClasses.getLeaderValue(var); + auto it = concreteDomains.find(rep); + if (it != concreteDomains.end()) + return it->second; + + // Return empty sequence if no domains are set + static const DomainSequence emptySequence; + return emptySequence; + } + + /// Clear all state. + void clear() { + equivalenceClasses = llvm::EquivalenceClasses(); + concreteDomains.clear(); + valueToVar.clear(); + } + +private: + llvm::EquivalenceClasses equivalenceClasses; + llvm::DenseMap concreteDomains; + llvm::DenseMap valueToVar; +}; + +/// Domain inference and checking pass implementation. +/// Uses canonical domain representation to allow domain order independence +/// and duplicate domain handling. +class InferDomainsPass + : public circt::firrtl::impl::InferDomainsBase { + +public: + InferDomainsPass() = default; + + void runOnOperation() override; + + std::unique_ptr clonePass() const override { + return std::make_unique(); + } + +private: + /// Union-find structure for domain variables. + DomainUnionFind domainUF; + + /// Track if any domain crossing errors were found. + bool hasErrors = false; + + /// Process a module and infer domains. + LogicalResult processModule(FModuleOp module); + + /// Process domain constraints from operations. + void processOperation(Operation *op); + + /// Unify domains of two values. + [[nodiscard]] LogicalResult unifyDomains(Value lhs, Value rhs); + + /// Set explicit domain sequence for a value. + void setExplicitDomains(Value value, OperandRange domains); + + /// Helper to add domain sequence notes to diagnostics. + void addDomainSequenceNotes(InFlightDiagnostic &diag, + const DomainUnionFind::DomainSequence &domains, + StringRef prefix); + + /// Helper to emit domain crossing error with detailed notes. + /// Sets hasErrors flag. + void emitDomainCrossingError(Operation *op, Value lhs, Value rhs, + StringRef errorMessage, StringRef lhsLabel, + StringRef rhsLabel); +}; + +} // namespace + +LogicalResult InferDomainsPass::unifyDomains(Value lhs, Value rhs) { + Value lhsVar = domainUF.getDomainVar(lhs); + Value rhsVar = domainUF.getDomainVar(rhs); + + return domainUF.unifyDomains(lhsVar, rhsVar); +} + +void InferDomainsPass::setExplicitDomains(Value value, OperandRange domains) { + Value domainVar = domainUF.getDomainVar(value); + // Convert OperandRange to SmallVector for setDomains. + llvm::SmallVector domainVec(domains.begin(), domains.end()); + domainUF.setDomains(domainVar, domainVec); +} + +void InferDomainsPass::addDomainSequenceNotes( + InFlightDiagnostic &diag, const DomainUnionFind::DomainSequence &domains, + StringRef prefix) { + if (domains.empty()) + return; + + for (size_t i = 0; i < domains.size(); ++i) { + if (domains[i]) { + std::string message = prefix.str(); + if (domains.size() > 1) { + message += " (domain " + std::to_string(i + 1) + " of " + + std::to_string(domains.size()) + ")"; + } + message += " is in domain defined here"; + diag.attachNote(domains[i].getLoc()).append(message); + } + } +} + +void InferDomainsPass::emitDomainCrossingError(Operation *op, Value lhs, + Value rhs, + StringRef errorMessage, + StringRef lhsLabel, + StringRef rhsLabel) { + // Get the concrete domain sequences for error reporting + Value lhsVar = domainUF.getDomainVar(lhs); + Value rhsVar = domainUF.getDomainVar(rhs); + const auto &lhsDomains = domainUF.getDomains(lhsVar); + const auto &rhsDomains = domainUF.getDomains(rhsVar); + + auto diag = op->emitError(errorMessage); + addDomainSequenceNotes(diag, lhsDomains, lhsLabel); + addDomainSequenceNotes(diag, rhsDomains, rhsLabel); + hasErrors = true; +} + +void InferDomainsPass::processOperation(Operation *op) { + // Error on unhandled operations. + if (isa(op)) { + llvm::errs() << "InferDomains cannot yet handle " << op->getName() << "\n"; + return signalPassFailure(); + } + + // Handle unsafe domain casts + if (auto domainCastOp = dyn_cast(op)) { + if (!domainCastOp.getDomains().empty()) { + // Explicitly cast to specified domain sequence + setExplicitDomains(domainCastOp.getResult(), domainCastOp.getDomains()); + } + return; + } + + // For all operations (including connections), propagate domains from operands + // to results This is a conservative approach - all operands and results share + // domains + if (!op->getOperands().empty()) { + Value firstOperand = op->getOperand(0); + for (auto operand : op->getOperands()) { + if (failed(unifyDomains(firstOperand, operand))) { + emitDomainCrossingError(op, firstOperand, operand, + "illegal domain crossing in operation", + "first operand", "operand"); + } + } + for (auto result : op->getResults()) { + if (failed(unifyDomains(firstOperand, result))) { + emitDomainCrossingError(op, firstOperand, result, + "illegal domain crossing in operation", + "operand", "result"); + } + } + } +} + +LogicalResult InferDomainsPass::processModule(FModuleOp module) { + LLVM_DEBUG(llvm::dbgs() << "Processing module: " << module.getName() << "\n"); + + // Process module ports - domain ports define explicit domains + for (auto [index, port] : llvm::enumerate(module.getPorts())) { + Value portValue = module.getArgument(index); + + // This is a domain port. + if (isa(port.type)) { + Value domainVar = domainUF.getDomainVar(portValue); + domainUF.setDomains(domainVar, {portValue}); + continue; + } + + // This is a port with explicit domain information. + auto domains = cast(port.domains); + if (domains.empty()) + continue; + SmallVector domainPorts; + for (auto domain : domains) { + auto index = cast(domain).getUInt(); + domainPorts.push_back(module.getArgument(index)); + } + Value domainVar = domainUF.getDomainVar(portValue); + domainUF.setDomains(domainVar, domainPorts); + } + + // Process all operations in the module + module.walk([&](Operation *op) { processOperation(op); }); + + // Check if any errors were found during processing + if (hasErrors) + return failure(); + + // Update domain information for all non-domain ports + SmallVector newDomainInfo; + bool anyUpdated = false; + + for (auto [index, port] : llvm::enumerate(module.getPorts())) { + // Skip domain ports - they don't need domain information + if (isa(port.type)) { + newDomainInfo.push_back(port.domains + ? port.domains + : ArrayAttr::get(module.getContext(), {})); + continue; + } + + // Get the inferred domains for this port + Value portValue = module.getArgument(index); + const auto &inferredDomains = domainUF.getDomains(portValue); + + // Convert domain values to domain indices + SmallVector domainIndices; + for (Value domain : inferredDomains) { + if (auto blockArg = dyn_cast(domain)) { + // This is a reference to a domain port + domainIndices.push_back(IntegerAttr::get( + IntegerType::get(module.getContext(), 32, IntegerType::Unsigned), + blockArg.getArgNumber())); + } + } + + ArrayAttr newDomains = ArrayAttr::get(module.getContext(), domainIndices); + + // Check if this is different from the existing domain information + if (!port.domains || port.domains != newDomains) { + anyUpdated = true; + } + + newDomainInfo.push_back(newDomains); + } + + // Update the module's domain information if anything changed + if (anyUpdated) { + module->setAttr("domainInfo", + ArrayAttr::get(module.getContext(), newDomainInfo)); + } + + LLVM_DEBUG({ + for (auto [index, port] : llvm::enumerate(module.getPorts())) { + llvm::dbgs() << " - port: " << port.getName() << "\n" + << " domains:\n"; + auto domains = domainUF.getDomains(module.getArgument(index)); + if (domains.empty()) { + llvm::dbgs() << " - inferred\n"; + continue; + } + for (auto domain : domains) { + if (auto port = dyn_cast(domain)) { + llvm::dbgs() << " - " << module.getPortName(port.getArgNumber()) + << "\n"; + continue; + } + } + } + }); + + return success(); +} + +void InferDomainsPass::runOnOperation() { + LLVM_DEBUG(debugPassHeader(this) << "\n"); + + // Clear state from any previous runs + domainUF.clear(); + hasErrors = false; + + auto circuit = getOperation(); + + // Process each module in the circuit + for (auto module : circuit.getOps()) { + if (failed(processModule(module))) { + signalPassFailure(); + return; + } + } + + // Signal failure if any domain crossing errors were found + if (hasErrors) { + signalPassFailure(); + } + + LLVM_DEBUG(debugFooter() << "\n"); +} diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir new file mode 100644 index 000000000000..aa7aa2bdd773 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -0,0 +1,63 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --verify-diagnostics --split-input-file + +// Test case 1: Illegal domain crossing - both matchingconnect and connect should fail +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain {} + firrtl.module @IllegalDomainCrossing( + // expected-note@below {{operand is in domain defined here}} + in %A: !firrtl.domain of @ClockDomain, + // expected-note@below {{first operand is in domain defined here}} + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// ----- + +// Test case 2: Multiple domain crossings +firrtl.circuit "MultipleDomainCrossings" { + firrtl.domain @ClockDomain {} + firrtl.module @MultipleDomainCrossings( + // expected-note@below {{operand is in domain defined here}} + in %A: !firrtl.domain of @ClockDomain, + // expected-note@below {{first operand is in domain defined here}} + in %B: !firrtl.domain of @ClockDomain, + // expected-note@below {{first operand is in domain defined here}} + in %C: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%B], + out %c: !firrtl.uint<1> domains [%C] + ) { + // expected-error@below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // expected-error@below {{illegal domain crossing in operation}} + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + } +} + +// ----- + +// Test case 3: Domain sequence mismatch - different lengths +firrtl.circuit "SequenceLengthMismatch" { + firrtl.domain @ClockDomain {} + firrtl.module @SequenceLengthMismatch( + // expected-note@below {{operand (domain 1 of 2) is in domain defined here}} + in %A: !firrtl.domain of @ClockDomain, + // expected-note@below {{first operand is in domain defined here}} + // expected-note@below {{operand (domain 2 of 2) is in domain defined here}} + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error@below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir new file mode 100644 index 000000000000..c065c3aa7bf8 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -0,0 +1,234 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --split-input-file | FileCheck %s + +// Test case 1: Legal domain usage - no crossing +firrtl.circuit "LegalDomains" { + firrtl.domain @ClockDomain {} + firrtl.module @LegalDomains( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // Connecting within the same domain is legal. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "LegalDomains" + +// ----- + +// Test case 2: Domain inference through connections +firrtl.circuit "DomainInference" { + firrtl.domain @ClockDomain {} + firrtl.module @DomainInference( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> + ) { + %b = firrtl.wire : !firrtl.uint<1> // No explicit domain + + // This should infer that %b is in domain %A. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // This should be legal since %b is now inferred to be in domain %A. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "DomainInference" +// CHECK: out %c: !firrtl.uint<1> domains [%A] + +// ----- + +// Test case 3: Unsafe domain cast +firrtl.circuit "UnsafeDomainCast" { + firrtl.domain @ClockDomain {} + firrtl.module @UnsafeDomainCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> domains [%B] + ) { + // Unsafe cast from domain A to domain B. + %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> + + // This should be legal since we explicitly cast. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" + +// ----- + +// Test case 4: Domain sequence matching - legal case +firrtl.circuit "LegalSequences" { + firrtl.domain @ClockDomain {} + firrtl.module @LegalSequences( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%A, %B] + ) { + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "LegalSequences" + +// ----- + +// Test case 5: Domain sequence order equivalence - should be legal +firrtl.circuit "SequenceOrderEquivalence" { + firrtl.domain @ClockDomain {} + firrtl.module @SequenceOrderEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B, %A] + ) { + // This should be legal since domain order doesn't matter in canonical representation + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "SequenceOrderEquivalence" + +// ----- + +// Test case 6: Domain sequence inference +firrtl.circuit "SequenceInference" { + firrtl.domain @ClockDomain {} + firrtl.module @SequenceInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %d: !firrtl.uint<1> + ) { + %c = firrtl.wire : !firrtl.uint<1> + + // %c should infer domain sequence [%A, %B] + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + + // This should be legal since %c has inferred [%A, %B] + firrtl.matchingconnect %d, %c : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "SequenceInference" +// CHECK: out %d: !firrtl.uint<1> domains [%A, %B] + +// ----- + +// Test case 7: Domain duplicate equivalence - should be legal +firrtl.circuit "DuplicateDomainEquivalence" { + firrtl.domain @ClockDomain {} + firrtl.module @DuplicateDomainEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // This should be legal since duplicate domains are canonicalized + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "DuplicateDomainEquivalence" + +// ----- + +// Test case 8: Unsafe domain cast with sequences +firrtl.circuit "UnsafeSequenceCast" { + firrtl.domain @ClockDomain {} + firrtl.module @UnsafeSequenceCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %C: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %c: !firrtl.uint<1> domains [%C] + ) { + %0 = firrtl.unsafe_domain_cast %a domains %C : !firrtl.uint<1> + firrtl.matchingconnect %c, %0 : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "UnsafeSequenceCast" +// CHECK: out %c: !firrtl.uint<1> domains [%C] + +// ----- + +// Test case 9: Multiple port domain inference +firrtl.circuit "MultiplePortInference" { + firrtl.domain @ClockDomain {} + firrtl.module @MultiplePortInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %inputA: !firrtl.uint<1> domains [%A], + in %inputB: !firrtl.uint<1> domains [%B], + in %inputAB: !firrtl.uint<1> domains [%A, %B], + out %outputA: !firrtl.uint<1>, + out %outputB: !firrtl.uint<1>, + out %outputAB: !firrtl.uint<1> + ) { + firrtl.matchingconnect %outputA, %inputA : !firrtl.uint<1> + firrtl.matchingconnect %outputB, %inputB : !firrtl.uint<1> + firrtl.matchingconnect %outputAB, %inputAB : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "MultiplePortInference" +// CHECK: out %outputA: !firrtl.uint<1> domains [%A] +// CHECK: out %outputB: !firrtl.uint<1> domains [%B] +// CHECK: out %outputAB: !firrtl.uint<1> domains [%A, %B] + +// ----- + +// Test case 10: Different port types domain inference +firrtl.circuit "DifferentPortTypes" { + firrtl.domain @ClockDomain {} + firrtl.module @DifferentPortTypes( + in %A: !firrtl.domain of @ClockDomain, + in %uint_input: !firrtl.uint<8> domains [%A], + in %sint_input: !firrtl.sint<4> domains [%A], + out %uint_output: !firrtl.uint<8>, + out %sint_output: !firrtl.sint<4> + ) { + firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> + firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> + } +} +// CHECK-LABEL: firrtl.circuit "DifferentPortTypes" +// CHECK: out %uint_output: !firrtl.uint<8> domains [%A] +// CHECK: out %sint_output: !firrtl.sint<4> domains [%A] + +// ----- + +// Test case 11: Domain inference through wires +firrtl.circuit "DomainInferenceThroughWires" { + firrtl.domain @ClockDomain {} + firrtl.module @DomainInferenceThroughWires( + in %A: !firrtl.domain of @ClockDomain, + in %input: !firrtl.uint<1> domains [%A], + out %output: !firrtl.uint<1> + ) { + %wire1 = firrtl.wire : !firrtl.uint<1> + %wire2 = firrtl.wire : !firrtl.uint<1> + + firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> + firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> + firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "DomainInferenceThroughWires" +// CHECK: out %output: !firrtl.uint<1> domains [%A] + +// ----- + +// Test case 12: Register inference +firrtl.circuit "RegisterInference" { + firrtl.domain @ClockDomain {} + firrtl.module @RegisterInference( + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + in %d: !firrtl.uint<1>, + out %q: !firrtl.uint<1> + ) { + %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> + firrtl.matchingconnect %r, %d : !firrtl.uint<1> + firrtl.matchingconnect %q, %r : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "RegisterInference" +// CHECK: in %d: !firrtl.uint<1> domains [%A] +// CHECK: out %q: !firrtl.uint<1> domains [%A] From 93f1893c5bd26d2860544a4fc893e08520cfe89f Mon Sep 17 00:00:00 2001 From: Robert Young Date: Tue, 23 Sep 2025 23:15:20 -0400 Subject: [PATCH 05/20] Check in --- lib/Dialect/FIRRTL/FIRRTLOps.cpp | 28 + .../FIRRTL/Transforms/InferDomains.cpp | 976 ++++++++++++------ 2 files changed, 701 insertions(+), 303 deletions(-) diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index d772e391e876..ce699a5b9b01 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -4343,6 +4343,34 @@ LogicalResult DomainDefineOp::verify() { return success(); } +LogicalResult DomainDefineOp::verify() { + if (failed(checkConnectFlow(*this))) + return failure(); + + for (auto *user : getDest().getUsers()) { + auto connection = dyn_cast(user); + if (!connection || connection == *this || connection.getDest() != getDest()) + continue; + return emitError("destination domains cannot be reused by multiple " + "operations, it can only capture a unique dataflow"); + } + + return success(); +} + +LogicalResult DomainDefineOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // if (failed(verifyPortSymbolUses(*this, symbolTable))) + // return failure(); + + // auto circuitOp = getOperation()->getParentOfType(); + // for (auto layer : getLayers()) { + // if (!symbolTable.lookupSymbolIn(circuitOp, cast(layer))) + // return emitOpError() << "enables undefined layer '" << layer << "'"; + // } + + return success(); +} + void WhenOp::createElseRegion() { assert(!hasElseRegion() && "already has an else region"); getElseRegion().push_back(new Block()); diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 75cede2f77ed..42e38814b619 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -15,14 +15,17 @@ // //===----------------------------------------------------------------------===// +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/FIRRTL/FIRRTLOps.h" #include "circt/Dialect/FIRRTL/Passes.h" #include "circt/Support/Debug.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/TrailingObjects.h" #define DEBUG_TYPE "firrtl-infer-domains" @@ -35,390 +38,757 @@ namespace firrtl { using namespace circt; using namespace firrtl; +using llvm::TrailingObjects; namespace { -/// Union-Find data structure for domain variables using LLVM's -/// EquivalenceClasses. Handles sequences of domains with canonical -/// representation where domain order is irrelevant and duplicates are removed. -/// This allows 'domains %A, %B' to be equivalent to 'domains %B, %A' and -/// 'domains %A, %A' to be equivalent to 'domains %A'. -class DomainUnionFind { -public: - using DomainSequence = llvm::SmallVector; - -private: - /// Canonicalize a domain sequence by sorting and removing duplicates. - /// This ensures that domain order doesn't matter and duplicate domains - /// are treated as equivalent. For example: - /// - 'domains %A, %B' and 'domains %B, %A' become the same canonical form - /// - 'domains %A, %A' becomes 'domains %A' - static DomainSequence canonicalizeDomains(ArrayRef domains) { - DomainSequence canonical(domains.begin(), domains.end()); - - // Sort domains deterministically. Use block arg order if all are block - // args. Fallback to opaque pointer comparison otherwise. Note: the opaque - // pointer comparison is only stable within a run of MLIR and should not be - // relied upon for determinism outside of a run. - llvm::sort(canonical, [](Value a, Value b) { - auto aArg = dyn_cast(a); - auto bArg = dyn_cast(b); - if (aArg && bArg) - return aArg.getArgNumber() < bArg.getArgNumber(); - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - }); - - // Remove duplicates to handle cases like 'domains %A, %A'. - canonical.erase(std::unique(canonical.begin(), canonical.end()), - canonical.end()); - - return canonical; - } +using InstanceIterator = InstanceGraphNode::UseIterator; +using InstanceRange = llvm::iterator_range; -public: - /// Get or create a domain variable for a value. - Value getDomainVar(Value value) { - auto it = valueToVar.find(value); - if (it != valueToVar.end()) - return it->second; - - // Create a new representative value for this domain variable. - Value representative = value; - valueToVar[value] = representative; - equivalenceClasses.insert(representative); - return representative; - } +/// From a domain info attribute, get the domain-type of a domain value at +/// index i. +StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { + if (info.empty()) + return nullptr; + auto ref = cast(info[i]); + return ref.getAttr(); +} - /// Union two domain variables. Returns success if successful, failure if - /// there's a domain conflict. With canonical representation, conflicts only - /// occur when connecting values with truly different domain sets (not just - /// different ordering). - [[nodiscard]] LogicalResult unifyDomains(Value var1, Value var2) { - Value rep1 = equivalenceClasses.getLeaderValue(var1); - Value rep2 = equivalenceClasses.getLeaderValue(var2); - - if (rep1 == rep2) - return success(); - - // Check for domain conflicts using canonical representation. - const DomainSequence &domains1 = getDomains(rep1); - const DomainSequence &domains2 = getDomains(rep2); - - // Since domains are already canonical, we can compare them directly. - // Conflicts only occur when connecting truly different domain sets. For - // example: 'domains %A' vs 'domains %B' is a conflict, but 'domains %A, %B' - // vs 'domains %B, %A' is not (same canonical form). - if (!domains1.empty() && !domains2.empty() && domains1 != domains2) { - // Conflict: both have different concrete domain sequences - return failure(); - } +/// From a domain info attribute, get the row of associated domains for a +/// hardware value at index i. +ArrayAttr getPortDomainAssociation(ArrayAttr info, size_t i) { + if (info.empty()) + return info; + return cast(info[i]); +} - // Merge the concrete domain sequences (both are already canonical) - const DomainSequence &mergedDomains = - !domains1.empty() ? domains1 : domains2; +/// Each declared domain in the circuit is assigned an index, based on the order +/// in which it appears. Domain associations for hardware values are represented +/// as a list of domains, sorted by the index of the domain type. +using DomainIndex = size_t; + +/// Information about the domains in the circuit. Able to map domains to their +/// domain-index, which in this pass is the canonical way to reference the type +/// of a domain. +struct CircuitDomainInfo { + static CircuitDomainInfo get(CircuitOp circuit) { + CircuitDomainInfo info; + info.processCircuit(circuit); + return info; + } - // Union the equivalence classes - equivalenceClasses.unionSets(rep1, rep2); + ArrayRef getDomains() const { return domainTable; } + size_t getNumDomains() const { return domainTable.size(); } + DomainOp getDomain(DomainIndex id) const { return domainTable[id]; } - // Set the merged domain sequence on the new representative - if (!mergedDomains.empty()) { - Value newRep = equivalenceClasses.getLeaderValue(rep1); - concreteDomains[newRep] = mergedDomains; + DomainIndex getDomainIndex(DomainOp op) const { + return indexTable.at(op.getNameAttr()); + } + DomainIndex getDomainIndex(StringAttr name) const { + return indexTable.at(name); + } + DomainIndex getDomainIndex(FlatSymbolRefAttr ref) const { + return getDomainIndex(ref.getAttr()); + } + DomainIndex getDomainIndex(ArrayAttr info, size_t i) const { + auto name = getDomainPortTypeName(info, i); + return getDomainIndex(name); + } + DomainIndex getDomainIndex(Value value) const { + assert(isa(value.getType())); + if (auto arg = dyn_cast(value)) { + auto *block = arg.getOwner(); + auto *owner = block->getParentOp(); + auto module = cast(owner); + auto info = module.getDomainInfoAttr(); + auto i = arg.getArgNumber(); + return getDomainIndex(info, i); } - return success(); + auto result = dyn_cast(value); + auto *owner = result.getOwner(); + auto instance = cast(owner); + auto info = instance.getDomainInfoAttr(); + auto i = result.getResultNumber(); + return getDomainIndex(info, i); } - /// Set the concrete domain sequence for a variable. - /// The domain sequence is automatically canonicalized to ensure consistent - /// representation regardless of input order or duplicates. - void setDomains(Value var, ArrayRef domains) { - Value rep = equivalenceClasses.getLeaderValue(var); - concreteDomains[rep] = canonicalizeDomains(domains); + void clear() { + domainTable.clear(); + indexTable.clear(); } - /// Get the concrete domain sequence for a variable. - /// Returns the canonical domain sequence (sorted, no duplicates). - const DomainSequence &getDomains(Value var) { - Value rep = equivalenceClasses.getLeaderValue(var); - auto it = concreteDomains.find(rep); - if (it != concreteDomains.end()) - return it->second; - - // Return empty sequence if no domains are set - static const DomainSequence emptySequence; - return emptySequence; + void processCircuit(CircuitOp circuit) { + clear(); + for (auto decl : circuit.getOps()) + processDomain(decl); + + for (auto [i, domain] : llvm::enumerate(domainTable)) + llvm::errs() << "domain " << i << " = " << domain << "\n"; } - /// Clear all state. - void clear() { - equivalenceClasses = llvm::EquivalenceClasses(); - concreteDomains.clear(); - valueToVar.clear(); + void processDomain(DomainOp op) { + auto index = domainTable.size(); + auto name = op.getNameAttr(); + domainTable.push_back(op); + indexTable.insert({name, index}); } -private: - llvm::EquivalenceClasses equivalenceClasses; - llvm::DenseMap concreteDomains; - llvm::DenseMap valueToVar; + SmallVector domainTable; + DenseMap indexTable; }; -/// Domain inference and checking pass implementation. -/// Uses canonical domain representation to allow domain order independence -/// and duplicate domain handling. -class InferDomainsPass - : public circt::firrtl::impl::InferDomainsBase { +/// The different sorts of terms in the unification engine. +enum class TermKind { + Variable, + Value, + Row, +}; -public: - InferDomainsPass() = default; +/// A term in the unification engine. +struct Term { + constexpr Term(TermKind kind) : kind(kind) {} + TermKind kind; +}; - void runOnOperation() override; +/// Helper to define a term kind. +template +struct TermBase : Term { + static bool classof(const Term *term) { return term->kind == K; } + TermBase() : Term(K) {} +}; - std::unique_ptr clonePass() const override { - return std::make_unique(); - } +/// An unknown value. +struct VariableTerm : public TermBase { + VariableTerm() : leader(nullptr) {} + VariableTerm(Term *leader) : leader(leader) {} + Term *leader; +}; -private: - /// Union-find structure for domain variables. - DomainUnionFind domainUF; +/// A concrete value defined in the IR. +struct ValueTerm : public TermBase { + ValueTerm(Value value) : value(value) {} + Value getValue() const { return value; } + Value value; +}; - /// Track if any domain crossing errors were found. - bool hasErrors = false; +/// A row of domains. +struct RowTerm : public TermBase { + RowTerm(ArrayRef elements) : elements(elements) {} + ArrayRef elements; +}; - /// Process a module and infer domains. - LogicalResult processModule(FModuleOp module); +template +T &operator<<(T &out, const Term &term); - /// Process domain constraints from operations. - void processOperation(Operation *op); +template +T &operator<<(T &out, const VariableTerm &term) { + return out << "var@" << (void *)&term << "{leader=" << term.leader << "}"; +} - /// Unify domains of two values. - [[nodiscard]] LogicalResult unifyDomains(Value lhs, Value rhs); +template +T &operator<<(T &out, const ValueTerm &term) { + return out << "value@" << (void *)&term << "{" << term.value << "}"; +} - /// Set explicit domain sequence for a value. - void setExplicitDomains(Value value, OperandRange domains); +template +T &operator<<(T &out, const RowTerm &term) { + out << "row@" << (void *)&term << "{"; + bool first = true; + for (auto *element : term.elements) { + if (!first) + out << ", "; + out << element; + first = false; + } + out << "}"; + return out; +} - /// Helper to add domain sequence notes to diagnostics. - void addDomainSequenceNotes(InFlightDiagnostic &diag, - const DomainUnionFind::DomainSequence &domains, - StringRef prefix); +template +T &operator<<(T &out, const Term &term) { + if (auto *var = dyn_cast(&term)) + return out << *var; + if (auto *val = dyn_cast(&term)) + return out << *val; + if (auto *row = dyn_cast(&term)) + return out << *row; + assert(0); + llvm_unreachable("unknown term"); + return out; +} - /// Helper to emit domain crossing error with detailed notes. - /// Sets hasErrors flag. - void emitDomainCrossingError(Operation *op, Value lhs, Value rhs, - StringRef errorMessage, StringRef lhsLabel, - StringRef rhsLabel); -}; +template +T &operator<<(T &out, const Term *term) { + if (!term) + return out << "null"; + return out << *term; +} -} // namespace +Term *find(Term *x) { + if (auto *var = dyn_cast(x)) { + if (var->leader == nullptr) + return var; -LogicalResult InferDomainsPass::unifyDomains(Value lhs, Value rhs) { - Value lhsVar = domainUF.getDomainVar(lhs); - Value rhsVar = domainUF.getDomainVar(rhs); + auto *leader = find(var->leader); + if (leader != var->leader) + var->leader = leader; + return leader; + } - return domainUF.unifyDomains(lhsVar, rhsVar); + return x; } -void InferDomainsPass::setExplicitDomains(Value value, OperandRange domains) { - Value domainVar = domainUF.getDomainVar(value); - // Convert OperandRange to SmallVector for setDomains. - llvm::SmallVector domainVec(domains.begin(), domains.end()); - domainUF.setDomains(domainVar, domainVec); +LogicalResult unify(Term *lhs, Term *rhs); + +LogicalResult unify(VariableTerm *x, Term *y) { + x->leader = y; + return success(); } -void InferDomainsPass::addDomainSequenceNotes( - InFlightDiagnostic &diag, const DomainUnionFind::DomainSequence &domains, - StringRef prefix) { - if (domains.empty()) - return; - - for (size_t i = 0; i < domains.size(); ++i) { - if (domains[i]) { - std::string message = prefix.str(); - if (domains.size() > 1) { - message += " (domain " + std::to_string(i + 1) + " of " + - std::to_string(domains.size()) + ")"; - } - message += " is in domain defined here"; - diag.attachNote(domains[i].getLoc()).append(message); - } +LogicalResult unify(ValueTerm *xv, Term *y) { + if (auto *yv = dyn_cast(y)) { + yv->leader = xv; + return success(); + } + if (auto *yv = dyn_cast(y)) { + return success(xv == yv); } + return failure(); } -void InferDomainsPass::emitDomainCrossingError(Operation *op, Value lhs, - Value rhs, - StringRef errorMessage, - StringRef lhsLabel, - StringRef rhsLabel) { - // Get the concrete domain sequences for error reporting - Value lhsVar = domainUF.getDomainVar(lhs); - Value rhsVar = domainUF.getDomainVar(rhs); - const auto &lhsDomains = domainUF.getDomains(lhsVar); - const auto &rhsDomains = domainUF.getDomains(rhsVar); - - auto diag = op->emitError(errorMessage); - addDomainSequenceNotes(diag, lhsDomains, lhsLabel); - addDomainSequenceNotes(diag, rhsDomains, rhsLabel); - hasErrors = true; -} - -void InferDomainsPass::processOperation(Operation *op) { - // Error on unhandled operations. - if (isa(op)) { - llvm::errs() << "InferDomains cannot yet handle " << op->getName() << "\n"; - return signalPassFailure(); +LogicalResult unify(RowTerm *lhsRow, Term *rhs) { + if (auto rhsVar = dyn_cast(rhs)) { + rhsVar->leader = lhsRow; + return success(); } - - // Handle unsafe domain casts - if (auto domainCastOp = dyn_cast(op)) { - if (!domainCastOp.getDomains().empty()) { - // Explicitly cast to specified domain sequence - setExplicitDomains(domainCastOp.getResult(), domainCastOp.getDomains()); + if (auto rhsRow = dyn_cast(rhs)) { + assert(lhsRow->elements.size() == rhsRow->elements.size()); + for (auto [x, y] : llvm::zip(lhsRow->elements, rhsRow->elements)) { + if (failed(unify(x, y))) + return failure(); } - return; + return success(); } - // For all operations (including connections), propagate domains from operands - // to results This is a conservative approach - all operands and results share - // domains - if (!op->getOperands().empty()) { - Value firstOperand = op->getOperand(0); - for (auto operand : op->getOperands()) { - if (failed(unifyDomains(firstOperand, operand))) { - emitDomainCrossingError(op, firstOperand, operand, - "illegal domain crossing in operation", - "first operand", "operand"); - } - } - for (auto result : op->getResults()) { - if (failed(unifyDomains(firstOperand, result))) { - emitDomainCrossingError(op, firstOperand, result, - "illegal domain crossing in operation", - "operand", "result"); - } - } + return failure(); +} + +LogicalResult unify(Term *lhs, Term *rhs) { + llvm::errs() << "unify x=" << *lhs << " y=" << *rhs << "\n"; + lhs = find(lhs); + rhs = find(rhs); + if (lhs == rhs) + return success(); + if (auto *lhsVar = dyn_cast(lhs)) + return unify(lhsVar, rhs); + if (auto *lhsVal = dyn_cast(lhs)) + return unify(lhsVal, rhs); + if (auto *lhsRow = dyn_cast(lhs)) + return unify(lhsRow, rhs); + return failure(); +} + +class InferModuleDomains { +public: + /// Run infer-domains on a module. + static LogicalResult run(const CircuitDomainInfo &, FModuleOp); + +private: + /// Initialize module-level state. + InferModuleDomains(const CircuitDomainInfo &); + + /// Execute on the given module. + LogicalResult operator()(FModuleOp); + + /// Record the domain associations of hardware ports, and record the + /// underlying value of output domain ports. + LogicalResult processPorts(FModuleOp); + + /// Record the domain associations of hardware, and record the underlying + /// value of domains, defined within the body of the module. + LogicalResult processBody(FModuleOp); + + /// Record the domain associations of any operands or results. + LogicalResult processOp(Operation *); + LogicalResult processOp(InstanceOp); + LogicalResult processOp(UnsafeDomainCastOp); + + /// Unify the associated domain rows of two terms. + LogicalResult unifyAssociations(Value, Value); + + /// Record a mapping from domain in the IR to its corresponding term. + void setTermForDomain(Value, Term *); + + /// Get the corresponding term for a domain in the IR. + Term *getTermForDomain(Value); + + /// Get the corresponding term for a domain in the IR, or null if unset. + Term *getOptTermForDomain(Value) const; + + /// Record a mapping from a hardware value in the IR to a term which + /// represents the row of domains it is associated with. + void setDomainAssociation(Value, Term *); + + /// Get the associated domain row, forced to be at least a row. + RowTerm *getDomainAssociationAsRow(Value); + + Term *getDomainAssociation(Value); + + /// Get the term which represents the row of domains associated with a + /// hardware value in the design. + Term *getOptDomainAssociation(Value) const; + + /// Allocate a row, where each domain is a variable. + RowTerm *allocateRow(); + + /// Allocate a row. + RowTerm *allocateRow(ArrayRef); + + /// Allocate a term. + template + T *allocate(Args &&...); + + /// Allocate an array of terms. If any terms were left null, automatically + /// replace them with a new variable. + ArrayRef allocateArray(ArrayRef); + + /// Information about the domains in a circuit. + const CircuitDomainInfo &circuitInfo; + + /// Term allocator. + llvm::BumpPtrAllocator allocator; + + /// Map from domains in the IR to their underlying term. + DenseMap termTable; + + /// A map from hardware values to their associated row of domains, as a term. + DenseMap associationTable; + + /// A boolean tracking if a non-fatal error occurred, or not. + bool ok = true; +}; + +LogicalResult InferModuleDomains::run(const CircuitDomainInfo &circuitInfo, + FModuleOp module) { + return InferModuleDomains(circuitInfo)(module); +} + +InferModuleDomains::InferModuleDomains(const CircuitDomainInfo &circuitInfo) + : circuitInfo(circuitInfo) {} + +LogicalResult InferModuleDomains::operator()(FModuleOp module) { + if (failed(processPorts(module))) + return failure(); + + if (failed(processBody(module))) + return failure(); + + for (auto association : associationTable) { + llvm::errs() << "association:\n"; + llvm::errs() << " " << association.first << "\n"; + llvm::errs() << " " << association.second << "\n"; } + + if (failed(updatePorts(module))) + return failure(); + + return llvm::success(ok); } -LogicalResult InferDomainsPass::processModule(FModuleOp module) { - LLVM_DEBUG(llvm::dbgs() << "Processing module: " << module.getName() << "\n"); +LogicalResult InferModuleDomains::processPorts(FModuleOp module) { + auto portDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); - // Process module ports - domain ports define explicit domains - for (auto [index, port] : llvm::enumerate(module.getPorts())) { - Value portValue = module.getArgument(index); + // Process module ports - domain ports define explicit domains. + DenseMap domainIndexTable; + for (size_t i = 0; i < numPorts; ++i) { + Value port = module.getArgument(i); // This is a domain port. - if (isa(port.type)) { - Value domainVar = domainUF.getDomainVar(portValue); - domainUF.setDomains(domainVar, {portValue}); + if (isa(port.getType())) { + auto index = circuitInfo.getDomainIndex(portDomainInfo, i); + domainIndexTable[i] = index; + if (module.getPortDirection(i) == Direction::In) { + setTermForDomain(port, allocate(port)); + } continue; } - // This is a port with explicit domain information. - auto domains = cast(port.domains); - if (domains.empty()) + // This is a port, which may have explicit domain information. + auto portDomains = getPortDomainAssociation(portDomainInfo, i); + if (portDomains.empty()) continue; - SmallVector domainPorts; - for (auto domain : domains) { - auto index = cast(domain).getUInt(); - domainPorts.push_back(module.getArgument(index)); + + SmallVector elements(circuitInfo.getNumDomains()); + for (auto domainPortIndexAttr : portDomains.getAsRange()) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainIndex = domainIndexTable[domainPortIndex]; + auto *term = getTermForDomain(module.getArgument(domainPortIndex)); + elements[domainIndex] = term; } - Value domainVar = domainUF.getDomainVar(portValue); - domainUF.setDomains(domainVar, domainPorts); + auto *row = allocateRow(elements); + setDomainAssociation(port, row); } - // Process all operations in the module - module.walk([&](Operation *op) { processOperation(op); }); + return success(); +} - // Check if any errors were found during processing - if (hasErrors) - return failure(); +LogicalResult InferModuleDomains::processBody(FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(processOp(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} - // Update domain information for all non-domain ports - SmallVector newDomainInfo; - bool anyUpdated = false; +LogicalResult InferModuleDomains::processOp(Operation *op) { + llvm::errs() << "process op: " << *op << "\n"; - for (auto [index, port] : llvm::enumerate(module.getPorts())) { - // Skip domain ports - they don't need domain information - if (isa(port.type)) { - newDomainInfo.push_back(port.domains - ? port.domains - : ArrayAttr::get(module.getContext(), {})); - continue; - } + if (auto instance = dyn_cast(op)) + return processOp(instance); + if (auto cast = dyn_cast(op)) + return processOp(cast); + + // For all operations (including connections), propagate domains from operands + // to results This is a conservative approach - all operands and results share + // the same domain associations. + Value lhs; + for (auto rhs : op->getOperands()) { + if (failed(unifyAssociations(lhs, rhs))) + return failure(); + lhs = rhs; + } + for (auto rhs : op->getResults()) { + if (failed(unifyAssociations(lhs, rhs))) + return failure(); + lhs = rhs; + } + return success(); +} + +LogicalResult InferModuleDomains::processOp(InstanceOp op) { + DenseMap portDomainIndexTable; + auto domainInfo = op.getDomainInfoAttr(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + Value port = op.getResult(i); - // Get the inferred domains for this port - Value portValue = module.getArgument(index); - const auto &inferredDomains = domainUF.getDomains(portValue); - - // Convert domain values to domain indices - SmallVector domainIndices; - for (Value domain : inferredDomains) { - if (auto blockArg = dyn_cast(domain)) { - // This is a reference to a domain port - domainIndices.push_back(IntegerAttr::get( - IntegerType::get(module.getContext(), 32, IntegerType::Unsigned), - blockArg.getArgNumber())); + // This is a domain port. + if (isa(port.getType())) { + auto index = circuitInfo.getDomainIndex(domainInfo, i); + portDomainIndexTable[i] = index; + if (op.getPortDirection(i) == Direction::Out) { + setTermForDomain(port, allocate(port)); + } else { + setTermForDomain(port, allocate()); } + continue; } - ArrayAttr newDomains = ArrayAttr::get(module.getContext(), domainIndices); - - // Check if this is different from the existing domain information - if (!port.domains || port.domains != newDomains) { - anyUpdated = true; + // This is a port, which may have explicit domain information. + SmallVector associations(circuitInfo.getNumDomains()); + auto domains = cast(domainInfo).getAsRange(); + for (auto domainPortIndexAttr : domains) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainIndex = portDomainIndexTable[domainPortIndex]; + auto *term = getTermForDomain(op.getResult(domainPortIndex)); + associations[domainIndex] = term; } - newDomainInfo.push_back(newDomains); + // Since we are processing bottom-up, we must have complete domain info + // for each port on the instance. + for (auto *domain : associations) + assert(domain && "must have complete domain information."); + + setDomainAssociation(port, allocateRow(associations)); } - // Update the module's domain information if anything changed - if (anyUpdated) { - module->setAttr("domainInfo", - ArrayAttr::get(module.getContext(), newDomainInfo)); + return success(); +} + +LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(op.getInput(), op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto index = circuitInfo.getDomainIndex(domain); + elements[index] = getTermForDomain(domain); } - LLVM_DEBUG({ - for (auto [index, port] : llvm::enumerate(module.getPorts())) { - llvm::dbgs() << " - port: " << port.getName() << "\n" - << " domains:\n"; - auto domains = domainUF.getDomains(module.getArgument(index)); - if (domains.empty()) { - llvm::dbgs() << " - inferred\n"; - continue; - } - for (auto domain : domains) { - if (auto port = dyn_cast(domain)) { - llvm::dbgs() << " - " << module.getPortName(port.getArgNumber()) - << "\n"; - continue; - } - } + auto *row = allocateRow(elements); + setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { + llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n"; + + if (!lhs || !rhs) + return success(); + + auto *lhsTerm = getOptDomainAssociation(lhs); + auto *rhsTerm = getOptDomainAssociation(rhs); + + if (lhsTerm) { + if (rhsTerm) { + return unify(lhsTerm, rhsTerm); } - }); + setDomainAssociation(rhs, lhsTerm); + return success(); + } + + if (rhsTerm) { + setDomainAssociation(lhs, rhsTerm); + return success(); + } + auto *var = allocate(); + setDomainAssociation(lhs, var); + setDomainAssociation(rhs, var); return success(); } -void InferDomainsPass::runOnOperation() { - LLVM_DEBUG(debugPassHeader(this) << "\n"); +Term *InferModuleDomains::getTermForDomain(Value value) { + assert(isa(value.getType())); + if (auto *term = getOptTermForDomain(value)) + return term; + auto *term = allocate(); + setDomainAssociation(value, term); + return term; +} - // Clear state from any previous runs - domainUF.clear(); - hasErrors = false; +Term *InferModuleDomains::getOptTermForDomain(Value value) const { + assert(isa(value.getType())); + auto it = termTable.find(value); + if (it == termTable.end()) + return nullptr; + return find(it->second); +} - auto circuit = getOperation(); +void InferModuleDomains::setTermForDomain(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + assert(!termTable.contains(value)); + termTable.insert({value, term}); +} - // Process each module in the circuit - for (auto module : circuit.getOps()) { - if (failed(processModule(module))) { - signalPassFailure(); - return; - } +RowTerm *InferModuleDomains::getDomainAssociationAsRow(Value value) { + assert(isa(value.getType())); + auto *term = getOptDomainAssociation(value); + + // If the term is unknown, allocate a fresh row and set the association. + if (!term) { + auto *row = allocateRow(); + setDomainAssociation(value, row); + return row; } - // Signal failure if any domain crossing errors were found - if (hasErrors) { - signalPassFailure(); + // If the term is already a row, return it. + auto *row = dyn_cast(term); + if (row) + return row; + + // Otherwise, unify the term with a fresh row of domains. + row = allocateRow(); + auto result = unify(row, term); + assert(result.succeeded()); + return row; +} + +Term *InferModuleDomains::getDomainAssociation(Value value) { + auto *term = getOptDomainAssociation(value); + if (term) + return term; + term = allocate(); + setDomainAssociation(value, term); + return term; +} + +Term *InferModuleDomains::getOptDomainAssociation(Value value) const { + assert(isa(value.getType())); + auto it = associationTable.find(value); + if (it == associationTable.end()) + return nullptr; + return find(it->second); +} + +void InferModuleDomains::setDomainAssociation(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + term = find(term); + associationTable.insert({value, term}); + llvm::errs() << " set domain association: " << value << " -> " << term << "\n"; +} + +RowTerm *InferModuleDomains::allocateRow() { + SmallVector elements; + elements.resize(circuitInfo.getNumDomains()); + return allocateRow(elements); +} + +RowTerm *InferModuleDomains::allocateRow(ArrayRef elements) { + auto ds = allocateArray(elements); + return allocate(ds); +} + +template +T *InferModuleDomains::allocate(Args &&...args) { + static_assert(std::is_base_of_v, "T must be a term"); + return new (allocator) T(std::forward(args)...); +} + +ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { + auto size = elements.size(); + if (size == 0) + return {}; + + auto result = allocator.Allocate(size); + llvm::uninitialized_copy(elements, result); + for (size_t i = 0; i < size; ++i) + if (!result[i]) + result[i] = allocate(); + + return ArrayRef(result, elements.size()); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Domain inference and checking pass implementation. +/// Uses canonical domain representation to allow domain order independence +/// and duplicate domain handling. +class InferDomainsPass + : public circt::firrtl::impl::InferDomainsBase { + +public: + InferDomainsPass() = default; + + /// Copy the pass by allocating fresh state. + InferDomainsPass(const InferDomainsPass &) : InferDomainsPass() {} + + void runOnOperation() override; + +private: + /// Process a module and infer domains. + LogicalResult processModule(const CircuitDomainInfo &, FModuleOp, + InstanceRange); +}; + +} // namespace + +LogicalResult +InferDomainsPass::processModule(const CircuitDomainInfo &circuitInfo, + FModuleOp module, InstanceRange instances) { + LLVM_DEBUG(llvm::dbgs() << "Processing module: " << module.getName() << "\n"); + return InferModuleDomains::run(circuitInfo, module); + + // Insert domain ports if needed. + // TODO: + + // // Update domain information for all non-domain ports + // SmallVector newDomainInfo; + // bool anyUpdated = false; + + // for (auto [index, port] : llvm::enumerate(module.getPorts())) { + // // Skip domain ports - they don't need domain information + // if (isa(port.type)) { + // newDomainInfo.push_back(port.domains + // ? port.domains + // : ArrayAttr::get(module.getContext(), {})); + // continue; + // } + + // // Get the inferred domains for this port + // Value portValue = module.getArgument(index); + // const auto &domains = domainUF.getDomains(portValue); + + // // Convert domain values to domain indices + // SmallVector domainIndices; + // for (auto *term : domains) { + // auto *value = dyn_cast(find(term)); + // if (!value) + // continue; + // // TODO + + // auto ir = value->value; + // if (auto blockArg = dyn_cast(ir)) { + // // This is a reference to a domain port + // domainIndices.push_back(IntegerAttr::get( + // IntegerType::get(module.getContext(), 32, IntegerType::Unsigned), + // blockArg.getArgNumber())); + // } + // } + + // ArrayAttr newDomains = ArrayAttr::get(module.getContext(), + // domainIndices); + + // // Check if this is different from the existing domain information + // if (!port.domains || port.domains != newDomains) { + // anyUpdated = true; + // } + + // newDomainInfo.push_back(newDomains); +} + +// Update the module's domain information if anything changed +// if (anyUpdated) { +// module->setAttr("domainInfo", +// ArrayAttr::get(module.getContext(), newDomainInfo)); +// } + +// LLVM_DEBUG({ +// for (auto [index, port] : llvm::enumerate(module.getPorts())) { +// llvm::dbgs() << " - port: " << port.getName() << "\n" +// << " domains:\n"; +// auto domains = domainUF.getDomains(module.getArgument(index)); +// if (domains.empty()) { +// llvm::dbgs() << " - inferred\n"; +// continue; +// } +// for (auto term : domains) { +// auto leader = find(term); +// Value domain; +// if (auto value = dyn_cast(leader)) +// domain = value->value; +// if (auto port = dyn_cast(domain)) { +// llvm::dbgs() << " - " << +// module.getPortName(port.getArgNumber()) +// << "\n"; +// continue; +// } +// } +// } +// }); + +// return success(); +// } + +void InferDomainsPass::runOnOperation() { + LLVM_DEBUG(debugPassHeader(this) << "\n"); + auto circuit = getOperation(); + auto &instanceGraph = getAnalysis(); + + auto circuitInfo = CircuitDomainInfo::get(circuit); + + // Process each module in the circuit. + DenseSet visited; + for (auto *root : instanceGraph) { + for (auto *node : llvm::post_order_ext(root, visited)) { + if (auto module = dyn_cast(node->getModule())) + if (failed(processModule(circuitInfo, module, node->uses()))) { + signalPassFailure(); + return; + } + } } LLVM_DEBUG(debugFooter() << "\n"); From 382a11fc6bc41ed08c5cae22626883ecb5d1a7d3 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 24 Sep 2025 11:45:10 -0400 Subject: [PATCH 06/20] Got port domain association inference and module generalization working --- .../FIRRTL/Transforms/InferDomains.cpp | 182 +++++++++++++++++- 1 file changed, 177 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 42e38814b619..a04bd20ba032 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -44,6 +44,7 @@ namespace { using InstanceIterator = InstanceGraphNode::UseIterator; using InstanceRange = llvm::iterator_range; +using PortInsertions = SmallVector>; /// From a domain info attribute, get the domain-type of a domain value at /// index i. @@ -289,6 +290,12 @@ LogicalResult unify(Term *lhs, Term *rhs) { return failure(); } +void solve(Term *lhs, Term *rhs) { + auto result = unify(lhs, rhs); + (void)result; + assert(result.succeeded()); +} + class InferModuleDomains { public: /// Run infer-domains on a module. @@ -314,6 +321,19 @@ class InferModuleDomains { LogicalResult processOp(InstanceOp); LogicalResult processOp(UnsafeDomainCastOp); + LogicalResult updateModule(FModuleOp); + + /// After generalizing the module, all domains should be solved. Reflect the + /// solved domain associations into the port domain info attribute. + LogicalResult updatePortDomainAssociations(FModuleOp); + + /// Add domain ports for any uninferred domains associated to hardware. + /// Returns the inserted ports, which will be used later to generalize the + /// instances of this module. + PortInsertions generalizeModule(FModuleOp); + + void generalizeInstance(InstanceOp, const PortInsertions &); + /// Unify the associated domain rows of two terms. LogicalResult unifyAssociations(Value, Value); @@ -390,7 +410,7 @@ LogicalResult InferModuleDomains::operator()(FModuleOp module) { llvm::errs() << " " << association.second << "\n"; } - if (failed(updatePorts(module))) + if (failed(updateModule(module))) return failure(); return llvm::success(ok); @@ -528,6 +548,153 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { return success(); } +LogicalResult InferModuleDomains::updateModule(FModuleOp op) { + auto insertions = generalizeModule(op); + + if (failed(updatePortDomainAssociations(op))) + return failure(); + + return success(); +} + +LogicalResult +InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { + // At this point, all domain variables mentioned in ports have been + // solved by generalizing the module (adding input domain ports). Now, we have + // to form the new port domain information for the module by examining the + // solutions to the associated domains of each port. + auto *context = module.getContext(); + auto builder = OpBuilder::atBlockBegin(module.getBodyBlock()); + auto oldDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + SmallVector domainInfo(numPorts); + + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + // By default, copy the old domain info over. + domainInfo[i] = oldDomainInfo[i]; + + // If the port is an output domain, we may need to drive the output with + // a value. If we don't know what value to drive to the port, error. + if (isa(type)) { + if (module.getPortDirection(i) == Direction::Out) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + + if (!driven) { + auto *term = getTermForDomain(port); + term = find(term); + if (auto *val = dyn_cast(term)) { + auto loc = port.getLoc(); + auto type = getDomainPortTypeName(oldDomainInfo, i); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value, type); + } else { + return module.emitError() << "unable to infer output domain value"; + } + } + } + continue; + } + + if (isa(type)) { + SmallVector associations(circuitInfo.getNumDomains()); + auto *row = getDomainAssociationAsRow(port); + for (auto [typeID, term] : llvm::enumerate(row->elements)) { + auto *domain = find(term); + auto *val = dyn_cast(domain); + if (!val) + return module.emitError() << "unable to infer domain for port"; + auto arg = cast(val->value); + auto idx = arg.getArgNumber(); + associations[typeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), idx); + } + domainInfo[i] = ArrayAttr::get(context, associations); + continue; + } + } + + auto domainInfoAttr = ArrayAttr::get(module.getContext(), domainInfo); + module.setDomainInfoAttr(domainInfoAttr); + return success(); +} + +PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { + // If the port is hardware, we have to check the associated row of + // domains. If any associated domain is a variable, we solve the variable + // by generalizing the module with an additional input domain port. If any + // associated domain is defined internally to the module, we have to add + // an output domain port, to allow the domain to escape. + SmallVector> insertions; + DenseMap solutions; + size_t inserted = 0; + auto numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + if (!isa(type)) + continue; + + auto *row = getDomainAssociationAsRow(port); + for (auto [typeID, term] : llvm::enumerate(row->elements)) { + auto *domain = find(term); + auto *var = dyn_cast(domain); + + if (!var) + continue; + + if (solutions.contains(var)) + continue; + + // insert a new port for the variable. + auto domainDecl = circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::In; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + solutions[var] = solutionPortIndex; + ++inserted; + } + } + + // Put the ports in place. + module.insertPorts(insertions); + + llvm::errs() << "generalization complete\n"; + llvm::errs() << module << "\n"; + + // Solve the variables. + for (auto [var, portIndex] : solutions) { + auto *solution = allocate(module.getArgument(portIndex)); + solve(var, solution); + } + + return insertions; +} + LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { llvm::errs() << " unify associations of:\n"; llvm::errs() << " lhs=" << lhs << "\n"; @@ -536,6 +703,9 @@ LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { if (!lhs || !rhs) return success(); + if (lhs == rhs) + return success(); + auto *lhsTerm = getOptDomainAssociation(lhs); auto *rhsTerm = getOptDomainAssociation(rhs); @@ -627,7 +797,8 @@ void InferModuleDomains::setDomainAssociation(Value value, Term *term) { assert(term); term = find(term); associationTable.insert({value, term}); - llvm::errs() << " set domain association: " << value << " -> " << term << "\n"; + llvm::errs() << " set domain association: " << value << " -> " << term + << "\n"; } RowTerm *InferModuleDomains::allocateRow() { @@ -703,7 +874,8 @@ InferDomainsPass::processModule(const CircuitDomainInfo &circuitInfo, // if (isa(port.type)) { // newDomainInfo.push_back(port.domains // ? port.domains - // : ArrayAttr::get(module.getContext(), {})); + // : ArrayAttr::get(module.getContext(), + // {})); // continue; // } @@ -723,8 +895,8 @@ InferDomainsPass::processModule(const CircuitDomainInfo &circuitInfo, // if (auto blockArg = dyn_cast(ir)) { // // This is a reference to a domain port // domainIndices.push_back(IntegerAttr::get( - // IntegerType::get(module.getContext(), 32, IntegerType::Unsigned), - // blockArg.getArgNumber())); + // IntegerType::get(module.getContext(), 32, + // IntegerType::Unsigned), blockArg.getArgNumber())); // } // } From 366a640fd26f23c97d5f9b818a6e1beaf83790ff Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 24 Sep 2025 17:50:05 -0400 Subject: [PATCH 07/20] Get output domain ports semi working --- .../FIRRTL/Transforms/InferDomains.cpp | 412 ++++++++++-------- 1 file changed, 235 insertions(+), 177 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index a04bd20ba032..511196f01886 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -21,13 +21,12 @@ #include "circt/Support/Debug.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/TrailingObjects.h" #define DEBUG_TYPE "firrtl-infer-domains" +#undef NDEBUG namespace circt { namespace firrtl { @@ -38,17 +37,18 @@ namespace firrtl { using namespace circt; using namespace firrtl; -using llvm::TrailingObjects; - -namespace { using InstanceIterator = InstanceGraphNode::UseIterator; using InstanceRange = llvm::iterator_range; using PortInsertions = SmallVector>; +//====-------------------------------------------------------------------------- +// Helpers for working with module or instance domain info. +//====-------------------------------------------------------------------------- + /// From a domain info attribute, get the domain-type of a domain value at /// index i. -StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { +static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { if (info.empty()) return nullptr; auto ref = cast(info[i]); @@ -57,12 +57,16 @@ StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { /// From a domain info attribute, get the row of associated domains for a /// hardware value at index i. -ArrayAttr getPortDomainAssociation(ArrayAttr info, size_t i) { +static ArrayAttr getPortDomainAssociation(ArrayAttr info, size_t i) { if (info.empty()) return info; return cast(info[i]); } +//====-------------------------------------------------------------------------- +// CircuitDomainInfo: Information about the domains declared in a circuit. +//====-------------------------------------------------------------------------- + /// Each declared domain in the circuit is assigned an index, based on the order /// in which it appears. Domain associations for hardware values are represented /// as a list of domains, sorted by the index of the domain type. @@ -71,6 +75,7 @@ using DomainIndex = size_t; /// Information about the domains in the circuit. Able to map domains to their /// domain-index, which in this pass is the canonical way to reference the type /// of a domain. +namespace { struct CircuitDomainInfo { static CircuitDomainInfo get(CircuitOp circuit) { CircuitDomainInfo info; @@ -138,6 +143,13 @@ struct CircuitDomainInfo { SmallVector domainTable; DenseMap indexTable; }; +} // namespace + +//====-------------------------------------------------------------------------- +// Terms: Syntax for unifying domain and domain-rows. +//====-------------------------------------------------------------------------- + +namespace { /// The different sorts of terms in the unification engine. enum class TermKind { @@ -296,6 +308,13 @@ void solve(Term *lhs, Term *rhs) { assert(result.succeeded()); } +} // namespace + +//====-------------------------------------------------------------------------- +// InferModuleDomains: Primary workhorse for inferring domains on modules. +//====-------------------------------------------------------------------------- + +namespace { class InferModuleDomains { public: /// Run infer-domains on a module. @@ -327,6 +346,12 @@ class InferModuleDomains { /// solved domain associations into the port domain info attribute. LogicalResult updatePortDomainAssociations(FModuleOp); + /// After updating the port domain associations, walk the body of the module + /// to fix up any child instance modules. + LogicalResult updateDomainAssociationsInBody(FModuleOp); + LogicalResult updateOpDomainAssociations(Operation *); + LogicalResult updateOpDomainAssociations(InstanceOp); + /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. @@ -385,9 +410,13 @@ class InferModuleDomains { /// A map from hardware values to their associated row of domains, as a term. DenseMap associationTable; + /// A map from local domain definition to its export, as a port. + DenseMap exportTable; + /// A boolean tracking if a non-fatal error occurred, or not. bool ok = true; }; +} // namespace LogicalResult InferModuleDomains::run(const CircuitDomainInfo &circuitInfo, FModuleOp module) { @@ -398,6 +427,10 @@ InferModuleDomains::InferModuleDomains(const CircuitDomainInfo &circuitInfo) : circuitInfo(circuitInfo) {} LogicalResult InferModuleDomains::operator()(FModuleOp module) { + llvm::errs() << "================================================\n"; + llvm::errs() << "infer module domains: " << module.getModuleName() << "\n"; + llvm::errs() << "================================================\n"; + if (failed(processPorts(module))) return failure(); @@ -492,39 +525,44 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { } LogicalResult InferModuleDomains::processOp(InstanceOp op) { - DenseMap portDomainIndexTable; + DenseMap domainPortTypeIDTable; auto domainInfo = op.getDomainInfoAttr(); for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { Value port = op.getResult(i); - // This is a domain port. + llvm::errs() << "handling instance port: " << port << "\n"; + if (isa(port.getType())) { - auto index = circuitInfo.getDomainIndex(domainInfo, i); - portDomainIndexTable[i] = index; + auto typeID = circuitInfo.getDomainIndex(domainInfo, i); + domainPortTypeIDTable[i] = typeID; if (op.getPortDirection(i) == Direction::Out) { setTermForDomain(port, allocate(port)); - } else { - setTermForDomain(port, allocate()); } continue; } - // This is a port, which may have explicit domain information. - SmallVector associations(circuitInfo.getNumDomains()); - auto domains = cast(domainInfo).getAsRange(); - for (auto domainPortIndexAttr : domains) { + if (!isa(port.getType())) + continue; + + // This is a port, which may have explicit domain information. Associate the + // port with a row of domains, where each element is derived from the domain + // associations recorded in the domain info attribute of the instance. + SmallVector elements(circuitInfo.getNumDomains()); + auto associations = + getPortDomainAssociation(domainInfo, i).getAsRange(); + for (auto domainPortIndexAttr : associations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto domainIndex = portDomainIndexTable[domainPortIndex]; + auto typeID = domainPortTypeIDTable[domainPortIndex]; auto *term = getTermForDomain(op.getResult(domainPortIndex)); - associations[domainIndex] = term; + elements[typeID] = term; } // Since we are processing bottom-up, we must have complete domain info // for each port on the instance. - for (auto *domain : associations) - assert(domain && "must have complete domain information."); + for (auto *element : elements) + assert(element && "must have complete domain information."); - setDomainAssociation(port, allocateRow(associations)); + setDomainAssociation(port, allocateRow(elements)); } return success(); @@ -554,6 +592,9 @@ LogicalResult InferModuleDomains::updateModule(FModuleOp op) { if (failed(updatePortDomainAssociations(op))) return failure(); + if (failed(updateDomainAssociationsInBody(op))) + return failure(); + return success(); } @@ -590,18 +631,24 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { } } + // Get the underlying value of the output port. + auto *term = getTermForDomain(port); + term = find(term); + auto *val = dyn_cast(term); + if (!val) + return module.emitError() << "unable to infer output domain value"; + + // If the output port is not driven, drive it. if (!driven) { - auto *term = getTermForDomain(port); - term = find(term); - if (auto *val = dyn_cast(term)) { - auto loc = port.getLoc(); - auto type = getDomainPortTypeName(oldDomainInfo, i); - auto value = val->value; - DomainDefineOp::create(builder, loc, port, value, type); - } else { - return module.emitError() << "unable to infer output domain value"; - } + auto loc = port.getLoc(); + auto typeName = getDomainPortTypeName(oldDomainInfo, i); + auto typeRef = FlatSymbolRefAttr::get(typeName); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value, typeRef); } + + // Record the output port as an export of the underlying value. + exportTable.insert({val->value, port}); } continue; } @@ -614,7 +661,11 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { auto *val = dyn_cast(domain); if (!val) return module.emitError() << "unable to infer domain for port"; - auto arg = cast(val->value); + + auto arg = dyn_cast(val->value); + if (!arg) + arg = exportTable.at(val->value); + auto idx = arg.getArgNumber(); associations[typeID] = IntegerAttr::get( IntegerType::get(context, 32, IntegerType::Unsigned), idx); @@ -636,7 +687,9 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { // associated domain is defined internally to the module, we have to add // an output domain port, to allow the domain to escape. SmallVector> insertions; - DenseMap solutions; + DenseMap pendingSolutions; + llvm::MapVector pendingExports; + size_t inserted = 0; auto numPorts = module.getNumPorts(); for (size_t i = 0; i < numPorts; ++i) { @@ -649,34 +702,70 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { auto *row = getDomainAssociationAsRow(port); for (auto [typeID, term] : llvm::enumerate(row->elements)) { auto *domain = find(term); - auto *var = dyn_cast(domain); - - if (!var) - continue; - - if (solutions.contains(var)) - continue; - - // insert a new port for the variable. - auto domainDecl = circuitInfo.getDomain(typeID); - auto domainName = domainDecl.getNameAttr(); - - auto portInsertionPoint = i; - auto portName = domainName; - auto portType = DomainType::get(module.getContext()); - auto portDirection = Direction::In; - auto portSym = StringAttr(); - auto portLoc = port.getLoc(); - auto portAnnos = std::nullopt; - auto portDomainInfo = FlatSymbolRefAttr::get(domainName); - PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, - portAnnos, portDomainInfo); - insertions.push_back({portInsertionPoint, portInfo}); - - // Record the pending solution. - auto solutionPortIndex = inserted + portInsertionPoint; - solutions[var] = solutionPortIndex; - ++inserted; + if (auto *val = dyn_cast(domain)) { + // If the domain value is defined inside the module body, we must output + // export the domain, so it may appear in the signature of the + // module. + auto result = dyn_cast(val->value); + if (!result) + continue; + + // The domain is defined internally. If there is already an aliasing + // port, we are done. + if (exportTable.contains(val->value)) + continue; + + // If there is already a pending export, we are also done. + if (pendingExports.contains(val->value)) + continue; + + // We must insert a new output domain port. + auto domainDecl = circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::Out; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending export. + auto exportedPortIndex = inserted + portInsertionPoint; + pendingExports[val->value] = exportedPortIndex; + ++inserted; + } + + if (auto *var = dyn_cast(domain)) { + if (pendingSolutions.contains(var)) + continue; + + // insert a new input domain port for the variable. + auto domainDecl = circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::In; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + pendingSolutions[var] = solutionPortIndex; + ++inserted; + } } } @@ -687,14 +776,86 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { llvm::errs() << module << "\n"; // Solve the variables. - for (auto [var, portIndex] : solutions) { + for (auto [var, portIndex] : pendingSolutions) { auto *solution = allocate(module.getArgument(portIndex)); solve(var, solution); } + // Drive the exports. + auto domainInfo = module.getDomainInfoAttr(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + + for (auto [value, portIndex] : pendingExports) { + auto port = module.getArgument(portIndex); + auto typeName = getDomainPortTypeName(domainInfo, portIndex); + auto typeNameRef = FlatSymbolRefAttr::get(typeName); + DomainDefineOp::create(builder, port.getLoc(), port, value, typeNameRef); + exportTable[value] = port; + setTermForDomain(port, allocate(value)); + } + return insertions; } +LogicalResult +InferModuleDomains::updateDomainAssociationsInBody(FModuleOp module) { + auto result = success(); + module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { + if (failed(updateOpDomainAssociations(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +LogicalResult InferModuleDomains::updateOpDomainAssociations(Operation *op) { + if (auto inst = dyn_cast(op)) + return updateOpDomainAssociations(inst); + return success(); +} + +LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { + auto *context = op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(op); + auto domainInfo = op.getDomainInfoAttr(); + auto numPorts = op->getNumResults(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + auto direction = op.getPortDirection(i); + if (isa(type)) { + if (direction == Direction::In) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + if (!driven) { + auto *term = getTermForDomain(port); + term = find(term); + if (auto *val = dyn_cast(term)) { + auto loc = port.getLoc(); + auto typeName = getDomainPortTypeName(domainInfo, i); + auto typeRef = FlatSymbolRefAttr::get(typeName); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value, typeRef); + } else { + return op.emitError() << "unable to infer input domain value"; + } + } + } + } + } + return success(); +} + LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { llvm::errs() << " unify associations of:\n"; llvm::errs() << " lhs=" << lhs << "\n"; @@ -733,7 +894,7 @@ Term *InferModuleDomains::getTermForDomain(Value value) { if (auto *term = getOptTermForDomain(value)) return term; auto *term = allocate(); - setDomainAssociation(value, term); + setTermForDomain(value, term); return term; } @@ -823,7 +984,7 @@ ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { if (size == 0) return {}; - auto result = allocator.Allocate(size); + auto *result = allocator.Allocate(size); llvm::uninitialized_copy(elements, result); for (size_t i = 0; i < size; ++i) if (!result[i]) @@ -832,136 +993,33 @@ ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { return ArrayRef(result, elements.size()); } -//////////////////////////////////////////////////////////////////////////////// +//===--------------------------------------------------------------------------- +// InferDomainsPass: Top-level pass implementation. +//===--------------------------------------------------------------------------- -/// Domain inference and checking pass implementation. -/// Uses canonical domain representation to allow domain order independence -/// and duplicate domain handling. -class InferDomainsPass +namespace { +struct InferDomainsPass : public circt::firrtl::impl::InferDomainsBase { - -public: - InferDomainsPass() = default; - - /// Copy the pass by allocating fresh state. - InferDomainsPass(const InferDomainsPass &) : InferDomainsPass() {} - void runOnOperation() override; - -private: - /// Process a module and infer domains. - LogicalResult processModule(const CircuitDomainInfo &, FModuleOp, - InstanceRange); }; - } // namespace -LogicalResult -InferDomainsPass::processModule(const CircuitDomainInfo &circuitInfo, - FModuleOp module, InstanceRange instances) { - LLVM_DEBUG(llvm::dbgs() << "Processing module: " << module.getName() << "\n"); - return InferModuleDomains::run(circuitInfo, module); - - // Insert domain ports if needed. - // TODO: - - // // Update domain information for all non-domain ports - // SmallVector newDomainInfo; - // bool anyUpdated = false; - - // for (auto [index, port] : llvm::enumerate(module.getPorts())) { - // // Skip domain ports - they don't need domain information - // if (isa(port.type)) { - // newDomainInfo.push_back(port.domains - // ? port.domains - // : ArrayAttr::get(module.getContext(), - // {})); - // continue; - // } - - // // Get the inferred domains for this port - // Value portValue = module.getArgument(index); - // const auto &domains = domainUF.getDomains(portValue); - - // // Convert domain values to domain indices - // SmallVector domainIndices; - // for (auto *term : domains) { - // auto *value = dyn_cast(find(term)); - // if (!value) - // continue; - // // TODO - - // auto ir = value->value; - // if (auto blockArg = dyn_cast(ir)) { - // // This is a reference to a domain port - // domainIndices.push_back(IntegerAttr::get( - // IntegerType::get(module.getContext(), 32, - // IntegerType::Unsigned), blockArg.getArgNumber())); - // } - // } - - // ArrayAttr newDomains = ArrayAttr::get(module.getContext(), - // domainIndices); - - // // Check if this is different from the existing domain information - // if (!port.domains || port.domains != newDomains) { - // anyUpdated = true; - // } - - // newDomainInfo.push_back(newDomains); -} - -// Update the module's domain information if anything changed -// if (anyUpdated) { -// module->setAttr("domainInfo", -// ArrayAttr::get(module.getContext(), newDomainInfo)); -// } - -// LLVM_DEBUG({ -// for (auto [index, port] : llvm::enumerate(module.getPorts())) { -// llvm::dbgs() << " - port: " << port.getName() << "\n" -// << " domains:\n"; -// auto domains = domainUF.getDomains(module.getArgument(index)); -// if (domains.empty()) { -// llvm::dbgs() << " - inferred\n"; -// continue; -// } -// for (auto term : domains) { -// auto leader = find(term); -// Value domain; -// if (auto value = dyn_cast(leader)) -// domain = value->value; -// if (auto port = dyn_cast(domain)) { -// llvm::dbgs() << " - " << -// module.getPortName(port.getArgNumber()) -// << "\n"; -// continue; -// } -// } -// } -// }); - -// return success(); -// } - void InferDomainsPass::runOnOperation() { LLVM_DEBUG(debugPassHeader(this) << "\n"); auto circuit = getOperation(); auto &instanceGraph = getAnalysis(); - auto circuitInfo = CircuitDomainInfo::get(circuit); - - // Process each module in the circuit. DenseSet visited; for (auto *root : instanceGraph) { for (auto *node : llvm::post_order_ext(root, visited)) { - if (auto module = dyn_cast(node->getModule())) - if (failed(processModule(circuitInfo, module, node->uses()))) { + auto *op = node->getModule(); + if (auto module = llvm::dyn_cast_if_present(op)) { + if (failed(InferModuleDomains::run(circuitInfo, module))) { signalPassFailure(); return; } + } } } - LLVM_DEBUG(debugFooter() << "\n"); } From 13f61927585fa1254ddef0e1e206678f8b74117f Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 1 Oct 2025 16:30:56 -0400 Subject: [PATCH 08/20] WIP: Another checkin --- .../FIRRTL/Transforms/InferDomains.cpp | 428 +++++++++++++----- test/Dialect/FIRRTL/infer-domains-errors.mlir | 59 +-- test/Dialect/FIRRTL/infer-domains.mlir | 109 ++--- 3 files changed, 390 insertions(+), 206 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 511196f01886..bffdbbb3f151 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -17,12 +17,15 @@ #include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/FIRRTL/FIRRTLOps.h" +#include "circt/Dialect/FIRRTL/FIRRTLUtils.h" #include "circt/Dialect/FIRRTL/Passes.h" #include "circt/Support/Debug.h" +#include "circt/Support/Namespace.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TinyPtrVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "firrtl-infer-domains" @@ -63,6 +66,19 @@ static ArrayAttr getPortDomainAssociation(ArrayAttr info, size_t i) { return cast(info[i]); } +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, BlockArgument arg) { + return arg.getOwner()->getParentOp() == module; +} + +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, Value value) { + auto arg = dyn_cast(value); + if (!arg) + return false; + return isPort(module, arg); +} + //====-------------------------------------------------------------------------- // CircuitDomainInfo: Information about the domains declared in a circuit. //====-------------------------------------------------------------------------- @@ -128,9 +144,6 @@ struct CircuitDomainInfo { clear(); for (auto decl : circuit.getOps()) processDomain(decl); - - for (auto [i, domain] : llvm::enumerate(domainTable)) - llvm::errs() << "domain " << i << " = " << domain << "\n"; } void processDomain(DomainOp op) { @@ -191,54 +204,63 @@ struct RowTerm : public TermBase { ArrayRef elements; }; -template -T &operator<<(T &out, const Term &term); +/// A helper for assigning low numeric IDs to variables for user-facing output. +struct VariableIDTable { + size_t get(VariableTerm *term) { + auto [it, inserted] = table.insert({term, table.size() + 1}); + return it->second; + } -template -T &operator<<(T &out, const VariableTerm &term) { - return out << "var@" << (void *)&term << "{leader=" << term.leader << "}"; + DenseMap table; +}; + +#ifndef NDEBUG + +raw_ostream &dump(llvm::raw_ostream &out, const Term *term); + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const VariableTerm *term) { + return out << "var@" << (void *)term << "{leader=" << term->leader << "}"; } -template -T &operator<<(T &out, const ValueTerm &term) { - return out << "value@" << (void *)&term << "{" << term.value << "}"; +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const ValueTerm *term) { + return out << "val@" << term << "{" << term->value << "}"; } -template -T &operator<<(T &out, const RowTerm &term) { - out << "row@" << (void *)&term << "{"; +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const RowTerm *term) { + out << "row@" << term << "{"; bool first = true; - for (auto *element : term.elements) { + for (auto *element : term->elements) { if (!first) out << ", "; - out << element; + dump(out, element); first = false; } out << "}"; return out; } -template -T &operator<<(T &out, const Term &term) { - if (auto *var = dyn_cast(&term)) - return out << *var; - if (auto *val = dyn_cast(&term)) - return out << *val; - if (auto *row = dyn_cast(&term)) - return out << *row; - assert(0); - llvm_unreachable("unknown term"); - return out; -} - -template -T &operator<<(T &out, const Term *term) { +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const Term *term) { if (!term) return out << "null"; - return out << *term; + if (auto *var = dyn_cast(term)) + return dump(out, var); + if (auto *val = dyn_cast(term)) + return dump(out, val); + if (auto *row = dyn_cast(term)) + return dump(out, row); + llvm_unreachable("unknown term"); } +#endif // DEBUG +// NOLINTNEXTLINE(misc-no-recursion) Term *find(Term *x) { + if (!x) + return nullptr; + if (auto *var = dyn_cast(x)) { if (var->leader == nullptr) return var; @@ -270,12 +292,13 @@ LogicalResult unify(ValueTerm *xv, Term *y) { return failure(); } +// NOLINTNEXTLINE(misc-no-recursion) LogicalResult unify(RowTerm *lhsRow, Term *rhs) { - if (auto rhsVar = dyn_cast(rhs)) { + if (auto *rhsVar = dyn_cast(rhs)) { rhsVar->leader = lhsRow; return success(); } - if (auto rhsRow = dyn_cast(rhs)) { + if (auto *rhsRow = dyn_cast(rhs)) { assert(lhsRow->elements.size() == rhsRow->elements.size()); for (auto [x, y] : llvm::zip(lhsRow->elements, rhsRow->elements)) { if (failed(unify(x, y))) @@ -287,8 +310,12 @@ LogicalResult unify(RowTerm *lhsRow, Term *rhs) { return failure(); } +// NOLINTNEXTLINE(misc-no-recursion) LogicalResult unify(Term *lhs, Term *rhs) { - llvm::errs() << "unify x=" << *lhs << " y=" << *rhs << "\n"; + LLVM_DEBUG(auto &out = llvm::errs(); out << "unify x="; dump(out, lhs); + out << " y="; dump(out, rhs); out << "\n";); + if (!lhs || !rhs) + return success(); lhs = find(lhs); rhs = find(rhs); if (lhs == rhs) @@ -342,6 +369,10 @@ class InferModuleDomains { LogicalResult updateModule(FModuleOp); + /// Build a table of exported domains: a map from domains defined internally, + /// to their aliasing output ports. + void initializeExportTable(FModuleOp); + /// After generalizing the module, all domains should be solved. Reflect the /// solved domain associations into the port domain info attribute. LogicalResult updatePortDomainAssociations(FModuleOp); @@ -352,6 +383,17 @@ class InferModuleDomains { LogicalResult updateOpDomainAssociations(Operation *); LogicalResult updateOpDomainAssociations(InstanceOp); + /// Copy the domain associations from the module domain info attribute into a + /// small vector. + SmallVector copyPortDomainAssociations(ArrayAttr, size_t); + + /// For a concrete domain value, get the unique aliasing port. When a hardware + /// port is associated to a domain, we must ensure that the domain is + /// available as a port of the module. Fails if the domain is not + /// exported by a port, or if the domain is exported by multiple ports. + LogicalResult getExportingPortIndex(FModuleOp, Value, size_t &); + LogicalResult getExportingPort(FModuleOp, Value, BlockArgument &); + /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. @@ -360,7 +402,10 @@ class InferModuleDomains { void generalizeInstance(InstanceOp, const PortInsertions &); /// Unify the associated domain rows of two terms. - LogicalResult unifyAssociations(Value, Value); + LogicalResult unifyAssociations(Operation *, Value, Value); + + /// If the domain value is an alias, returns the domain it aliases. + Value getUnderlyingDomain(Value); /// Record a mapping from domain in the IR to its corresponding term. void setTermForDomain(Value, Term *); @@ -378,10 +423,13 @@ class InferModuleDomains { /// Get the associated domain row, forced to be at least a row. RowTerm *getDomainAssociationAsRow(Value); + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, allocate a variable to stand for + /// the row of domains. Term *getDomainAssociation(Value); - /// Get the term which represents the row of domains associated with a - /// hardware value in the design. + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, returns nullptr. Term *getOptDomainAssociation(Value) const; /// Allocate a row, where each domain is a variable. @@ -398,6 +446,12 @@ class InferModuleDomains { /// replace them with a new variable. ArrayRef allocateArray(ArrayRef); + /// Print a term in a user-friendly way. + void render(Diagnostic &, Term *) const; + void render(Diagnostic &, VariableIDTable &, Term *) const; + + void emitPortDomainCrossingError(BlockArgument, size_t, Term *, Term *) const; + /// Information about the domains in a circuit. const CircuitDomainInfo &circuitInfo; @@ -410,8 +464,8 @@ class InferModuleDomains { /// A map from hardware values to their associated row of domains, as a term. DenseMap associationTable; - /// A map from local domain definition to its export, as a port. - DenseMap exportTable; + /// A map from local domain definition to its aliasing output ports. + DenseMap> exportTable; /// A boolean tracking if a non-fatal error occurred, or not. bool ok = true; @@ -427,9 +481,11 @@ InferModuleDomains::InferModuleDomains(const CircuitDomainInfo &circuitInfo) : circuitInfo(circuitInfo) {} LogicalResult InferModuleDomains::operator()(FModuleOp module) { - llvm::errs() << "================================================\n"; - llvm::errs() << "infer module domains: " << module.getModuleName() << "\n"; - llvm::errs() << "================================================\n"; + LLVM_DEBUG( + llvm::errs() << "================================================\n"; + llvm::errs() << "infer module domains: " << module.getModuleName() + << "\n"; + llvm::errs() << "================================================\n";); if (failed(processPorts(module))) return failure(); @@ -437,15 +493,19 @@ LogicalResult InferModuleDomains::operator()(FModuleOp module) { if (failed(processBody(module))) return failure(); - for (auto association : associationTable) { + LLVM_DEBUG(for (auto association : associationTable) { llvm::errs() << "association:\n"; llvm::errs() << " " << association.first << "\n"; llvm::errs() << " " << association.second << "\n"; - } + }); + // PortInsertions insertions; if (failed(updateModule(module))) return failure(); + // if (failed(updateInstances(insertions))) + // return failure(); + return llvm::success(ok); } @@ -456,7 +516,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { // Process module ports - domain ports define explicit domains. DenseMap domainIndexTable; for (size_t i = 0; i < numPorts; ++i) { - Value port = module.getArgument(i); + BlockArgument port = module.getArgument(i); // This is a domain port. if (isa(port.getType())) { @@ -477,7 +537,13 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { for (auto domainPortIndexAttr : portDomains.getAsRange()) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto domainIndex = domainIndexTable[domainPortIndex]; - auto *term = getTermForDomain(module.getArgument(domainPortIndex)); + auto domainValue = module.getArgument(domainPortIndex); + auto *term = getTermForDomain(domainValue); + auto &slot = elements[domainIndex]; + if (failed(unify(slot, term))) { + emitPortDomainCrossingError(port, domainIndex, slot, term); + return failure(); + } elements[domainIndex] = term; } auto *row = allocateRow(elements); @@ -500,7 +566,7 @@ LogicalResult InferModuleDomains::processBody(FModuleOp module) { } LogicalResult InferModuleDomains::processOp(Operation *op) { - llvm::errs() << "process op: " << *op << "\n"; + LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); if (auto instance = dyn_cast(op)) return processOp(instance); @@ -512,12 +578,12 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { // the same domain associations. Value lhs; for (auto rhs : op->getOperands()) { - if (failed(unifyAssociations(lhs, rhs))) + if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; } for (auto rhs : op->getResults()) { - if (failed(unifyAssociations(lhs, rhs))) + if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; } @@ -530,7 +596,7 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { Value port = op.getResult(i); - llvm::errs() << "handling instance port: " << port << "\n"; + LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); if (isa(port.getType())) { auto typeID = circuitInfo.getDomainIndex(domainInfo, i); @@ -571,7 +637,7 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { auto domains = op.getDomains(); if (domains.empty()) - return unifyAssociations(op.getInput(), op.getResult()); + return unifyAssociations(op, op.getInput(), op.getResult()); auto input = op.getInput(); RowTerm *inputRow = getDomainAssociationAsRow(input); @@ -587,8 +653,9 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { } LogicalResult InferModuleDomains::updateModule(FModuleOp op) { - auto insertions = generalizeModule(op); + initializeExportTable(op); + auto insertions = generalizeModule(op); if (failed(updatePortDomainAssociations(op))) return failure(); @@ -598,25 +665,40 @@ LogicalResult InferModuleDomains::updateModule(FModuleOp op) { return success(); } +void InferModuleDomains::initializeExportTable(FModuleOp module) { + size_t numPorts = module.getNumPorts(); + auto directions = module.getPortDirections(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type)) + continue; + auto direction = direction::get(directions[i]); + if (direction == Direction::Out) { + auto value = getUnderlyingDomain(port); + if (value) + exportTable[value].push_back(port); + } + } +} + LogicalResult InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { // At this point, all domain variables mentioned in ports have been // solved by generalizing the module (adding input domain ports). Now, we have // to form the new port domain information for the module by examining the - // solutions to the associated domains of each port. + // the associated domains of each port. auto *context = module.getContext(); - auto builder = OpBuilder::atBlockBegin(module.getBodyBlock()); - auto oldDomainInfo = module.getDomainInfoAttr(); + auto numDomains = circuitInfo.getNumDomains(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + auto oldModuleDomainInfo = module.getDomainInfoAttr(); auto numPorts = module.getNumPorts(); - SmallVector domainInfo(numPorts); + SmallVector newModuleDomainInfo(numPorts); for (size_t i = 0; i < numPorts; ++i) { auto port = module.getArgument(i); auto type = port.getType(); - // By default, copy the old domain info over. - domainInfo[i] = oldDomainInfo[i]; - // If the port is an output domain, we may need to drive the output with // a value. If we don't know what value to drive to the port, error. if (isa(type)) { @@ -641,42 +723,111 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { // If the output port is not driven, drive it. if (!driven) { auto loc = port.getLoc(); - auto typeName = getDomainPortTypeName(oldDomainInfo, i); + auto typeName = getDomainPortTypeName(oldModuleDomainInfo, i); auto typeRef = FlatSymbolRefAttr::get(typeName); auto value = val->value; DomainDefineOp::create(builder, loc, port, value, typeRef); } - - // Record the output port as an export of the underlying value. - exportTable.insert({val->value, port}); } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; continue; } if (isa(type)) { - SmallVector associations(circuitInfo.getNumDomains()); + auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); auto *row = getDomainAssociationAsRow(port); - for (auto [typeID, term] : llvm::enumerate(row->elements)) { - auto *domain = find(term); - auto *val = dyn_cast(domain); - if (!val) - return module.emitError() << "unable to infer domain for port"; - - auto arg = dyn_cast(val->value); - if (!arg) - arg = exportTable.at(val->value); - - auto idx = arg.getArgNumber(); - associations[typeID] = IntegerAttr::get( - IntegerType::get(context, 32, IntegerType::Unsigned), idx); + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) + continue; + auto domain = cast(find(row->elements[domainTypeID]))->value; + size_t domainPortIndex = 0; + if (auto arg = dyn_cast(domain)) { + if (arg.getOwner()->getParentOp() == module) { + domainPortIndex = arg.getArgNumber(); + } + } else { + auto exports = exportTable.lookup(domain); + if (exports.empty()) { + auto diag = + module.emitOpError("Failed to infer domain information"); + diag.attachNote(module.getPortLocation(i)) << "for port # " << i; + diag.attachNote() << "the domain is not exported"; + return failure(); + } + + if (exports.size() > 1) { + auto diag = + module.emitOpError("Failed to infer domain information"); + diag.attachNote(module.getPortLocation(i)) << "for port # " << i; + diag.attachNote() << "cannot choose between aliasing ports"; + for (auto arg : exports) { + diag.attachNote(module.getPortLocation(arg.getArgNumber())) + << "aliased here"; + } + return failure(); + } + } + associations[domainTypeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), + domainPortIndex); } - domainInfo[i] = ArrayAttr::get(context, associations); + + newModuleDomainInfo[i] = ArrayAttr::get(context, associations); continue; } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; } - auto domainInfoAttr = ArrayAttr::get(module.getContext(), domainInfo); - module.setDomainInfoAttr(domainInfoAttr); + auto newModuleDomainInfoAttr = + ArrayAttr::get(module.getContext(), newModuleDomainInfo); + module.setDomainInfoAttr(newModuleDomainInfoAttr); + return success(); +} + +SmallVector +InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, + size_t portIndex) { + SmallVector result(circuitInfo.getNumDomains()); + auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); + for (auto domainPortIndexAttr : oldAssociations.getAsRange()) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = + circuitInfo.getDomainIndex(moduleDomainInfo, domainPortIndex); + result[domainTypeID] = domainPortIndexAttr; + }; + return result; +} + +LogicalResult InferModuleDomains::getExportingPortIndex(FModuleOp module, + Value value, + size_t &result) { + BlockArgument arg; + if (failed(getExportingPort(module, value, arg))) + return failure(); + + result = arg.getArgNumber(); + return success(); +} + +LogicalResult InferModuleDomains::getExportingPort(FModuleOp module, + Value value, + BlockArgument &result) { + if (auto arg = dyn_cast(value)) { + if (arg.getOwner()->getParentOp() == module) { + result = arg; + return success(); + } + } + + auto exports = exportTable.lookup(value); + assert(!exports.empty()); + + if (exports.size() > 1) + return failure(); + + result = exports[0]; return success(); } @@ -702,23 +853,20 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { auto *row = getDomainAssociationAsRow(port); for (auto [typeID, term] : llvm::enumerate(row->elements)) { auto *domain = find(term); + if (auto *val = dyn_cast(domain)) { + auto value = val->value; // If the domain value is defined inside the module body, we must output // export the domain, so it may appear in the signature of the // module. - auto result = dyn_cast(val->value); - if (!result) + if (isPort(module, value)) continue; - // The domain is defined internally. If there is already an aliasing - // port, we are done. - if (exportTable.contains(val->value)) + // The domain is defined internally. If there value is already exported, + // or will be exported, we are done. + if (exportTable.contains(value) || pendingExports.contains(value)) continue; - // If there is already a pending export, we are also done. - if (pendingExports.contains(val->value)) - continue; - // We must insert a new output domain port. auto domainDecl = circuitInfo.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); @@ -772,9 +920,6 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { // Put the ports in place. module.insertPorts(insertions); - llvm::errs() << "generalization complete\n"; - llvm::errs() << module << "\n"; - // Solve the variables. for (auto [var, portIndex] : pendingSolutions) { auto *solution = allocate(module.getArgument(portIndex)); @@ -790,7 +935,7 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { auto typeName = getDomainPortTypeName(domainInfo, portIndex); auto typeNameRef = FlatSymbolRefAttr::get(typeName); DomainDefineOp::create(builder, port.getLoc(), port, value, typeNameRef); - exportTable[value] = port; + exportTable[value].push_back(port); setTermForDomain(port, allocate(value)); } @@ -856,10 +1001,11 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { return success(); } -LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { - llvm::errs() << " unify associations of:\n"; - llvm::errs() << " lhs=" << lhs << "\n"; - llvm::errs() << " rhs=" << rhs << "\n"; +LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, + Value rhs) { + LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n";); if (!lhs || !rhs) return success(); @@ -872,7 +1018,20 @@ LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { if (lhsTerm) { if (rhsTerm) { - return unify(lhsTerm, rhsTerm); + if (failed(unify(lhsTerm, rhsTerm))) { + auto diag = op->emitOpError("illegal domain crossing in operation"); + auto ¬e1 = diag.attachNote(lhs.getLoc()); + + note1 << "1st operand has domains: "; + VariableIDTable idTable; + render(note1, idTable, lhsTerm); + + auto ¬e2 = diag.attachNote(rhs.getLoc()); + note2 << "2nd operand has domains: "; + render(note2, idTable, rhsTerm); + + return failure(); + } } setDomainAssociation(rhs, lhsTerm); return success(); @@ -889,6 +1048,14 @@ LogicalResult InferModuleDomains::unifyAssociations(Value lhs, Value rhs) { return success(); } +Value InferModuleDomains::getUnderlyingDomain(Value value) { + assert(isa(value.getType())); + auto *term = getOptTermForDomain(value); + if (auto *val = llvm::dyn_cast_if_present(term)) + return val->value; + return nullptr; +} + Term *InferModuleDomains::getTermForDomain(Value value) { assert(isa(value.getType())); if (auto *term = getOptTermForDomain(value)) @@ -958,8 +1125,8 @@ void InferModuleDomains::setDomainAssociation(Value value, Term *term) { assert(term); term = find(term); associationTable.insert({value, term}); - llvm::errs() << " set domain association: " << value << " -> " << term - << "\n"; + LLVM_DEBUG(llvm::errs() << " set domain association: " << value; + llvm::errs() << " -> " << term << "\n";); } RowTerm *InferModuleDomains::allocateRow() { @@ -993,6 +1160,65 @@ ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { return ArrayRef(result, elements.size()); } +void InferModuleDomains::render(Diagnostic &out, Term *term) const { + VariableIDTable idTable; + render(out, idTable, term); +} + +// NOLINTNEXTLINE(misc-no-recursion) +void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, + Term *term) const { + term = find(term); + if (auto *var = dyn_cast(term)) { + out << "?" << idTable.get(var); + return; + } + if (auto *val = dyn_cast(term)) { + auto value = val->value; + auto [name, rooted] = getFieldName(FieldRef(value, 0), false); + out << name; + // out.attachNote(val->value.getLoc()) << name << " defined here"; + return; + } + if (auto *row = dyn_cast(term)) { + bool first = true; + out << "["; + for (size_t i = 0, e = circuitInfo.getNumDomains(); i < e; ++i) { + auto domainOp = circuitInfo.getDomain(i); + if (!first) { + out << ", "; + first = false; + } + out << domainOp.getName() << ": "; + render(out, idTable, row->elements[i]); + } + out << "]"; + return; + } +} + +void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, + size_t domainTypeID, + Term *term1, + Term *term2) const { + VariableIDTable idTable; + + auto portIndex = port.getArgNumber(); + auto domainType = circuitInfo.getDomain(domainTypeID); + auto domainTypeName = domainType.getName(); + + auto diag = emitError(port.getLoc()); + diag << "illegal " << domainTypeName << " crossing in port #" << portIndex; + + auto ¬e1 = diag.attachNote(); + note1 << "1st instance: "; + render(note1, term1); + + auto ¬e2 = diag.attachNote(); + note2 << "2nd instance: "; + render(note2, term2); +} + //===--------------------------------------------------------------------------- // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index aa7aa2bdd773..74a70cfd0714 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -1,63 +1,50 @@ // RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --verify-diagnostics --split-input-file -// Test case 1: Illegal domain crossing - both matchingconnect and connect should fail -firrtl.circuit "IllegalDomainCrossing" { +// Port annotated with same domain type twice. +firrtl.circuit "DomainCrossOnPort" { firrtl.domain @ClockDomain {} - firrtl.module @IllegalDomainCrossing( - // expected-note@below {{operand is in domain defined here}} + firrtl.module @DomainCrossOnPort( in %A: !firrtl.domain of @ClockDomain, - // expected-note@below {{first operand is in domain defined here}} in %B: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A], - out %b: !firrtl.uint<1> domains [%B] - ) { - // expected-error @below {{illegal domain crossing in operation}} - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - - // expected-error @below {{illegal domain crossing in operation}} - firrtl.connect %b, %a : !firrtl.uint<1>, !firrtl.uint<1> - } + // expected-error @below {{illegal ClockDomain crossing in port #2}} + // expected-note @below {{1st instance: A}} + // expected-note @below {{2nd instance: B}} + in %p: !firrtl.uint<1> domains [%A, %B] + ) {} } // ----- -// Test case 2: Multiple domain crossings -firrtl.circuit "MultipleDomainCrossings" { +// Illegal domain crossing - connect op. +firrtl.circuit "IllegalDomainCrossing" { firrtl.domain @ClockDomain {} - firrtl.module @MultipleDomainCrossings( - // expected-note@below {{operand is in domain defined here}} + firrtl.module @IllegalDomainCrossing( in %A: !firrtl.domain of @ClockDomain, - // expected-note@below {{first operand is in domain defined here}} in %B: !firrtl.domain of @ClockDomain, - // expected-note@below {{first operand is in domain defined here}} - in %C: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} in %a: !firrtl.uint<1> domains [%A], - out %b: !firrtl.uint<1> domains [%B], - out %c: !firrtl.uint<1> domains [%C] + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] ) { - // expected-error@below {{illegal domain crossing in operation}} - firrtl.matchingconnect %b, %a : !firrtl.uint<1> - - // expected-error@below {{illegal domain crossing in operation}} - firrtl.matchingconnect %c, %a : !firrtl.uint<1> + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1> } } // ----- -// Test case 3: Domain sequence mismatch - different lengths -firrtl.circuit "SequenceLengthMismatch" { +// Illegal domain crossing at matchingconnect op. +firrtl.circuit "IllegalDomainCrossing" { firrtl.domain @ClockDomain {} - firrtl.module @SequenceLengthMismatch( - // expected-note@below {{operand (domain 1 of 2) is in domain defined here}} + firrtl.module @IllegalDomainCrossing( in %A: !firrtl.domain of @ClockDomain, - // expected-note@below {{first operand is in domain defined here}} - // expected-note@below {{operand (domain 2 of 2) is in domain defined here}} in %B: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A, %B], + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} out %b: !firrtl.uint<1> domains [%B] ) { - // expected-error@below {{illegal domain crossing in operation}} + // expected-error @below {{illegal domain crossing in operation}} firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index c065c3aa7bf8..0894a0425917 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -1,11 +1,11 @@ // RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --split-input-file | FileCheck %s -// Test case 1: Legal domain usage - no crossing +// Legal domain usage - no crossing. firrtl.circuit "LegalDomains" { firrtl.domain @ClockDomain {} firrtl.module @LegalDomains( - in %A: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A], + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], out %b: !firrtl.uint<1> domains [%A] ) { // Connecting within the same domain is legal. @@ -16,7 +16,7 @@ firrtl.circuit "LegalDomains" { // ----- -// Test case 2: Domain inference through connections +// Domain inference through connections. firrtl.circuit "DomainInference" { firrtl.domain @ClockDomain {} firrtl.module @DomainInference( @@ -38,7 +38,7 @@ firrtl.circuit "DomainInference" { // ----- -// Test case 3: Unsafe domain cast +// Unsafe domain cast firrtl.circuit "UnsafeDomainCast" { firrtl.domain @ClockDomain {} firrtl.module @UnsafeDomainCast( @@ -58,28 +58,29 @@ firrtl.circuit "UnsafeDomainCast" { // ----- -// Test case 4: Domain sequence matching - legal case +// Domain sequence matching. firrtl.circuit "LegalSequences" { firrtl.domain @ClockDomain {} + firrtl.domain @PowerDomain {} firrtl.module @LegalSequences( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A, %B], - out %b: !firrtl.uint<1> domains [%A, %B] + in %C: !firrtl.domain of @ClockDomain, + in %P: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%C, %P], + out %b: !firrtl.uint<1> domains [%C, %P] ) { firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "LegalSequences" // ----- -// Test case 5: Domain sequence order equivalence - should be legal +// Domain sequence order equivalence - should be legal firrtl.circuit "SequenceOrderEquivalence" { firrtl.domain @ClockDomain {} + firrtl.domain @PowerDomain {} firrtl.module @SequenceOrderEquivalence( in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, in %a: !firrtl.uint<1> domains [%A, %B], out %b: !firrtl.uint<1> domains [%B, %A] ) { @@ -91,12 +92,13 @@ firrtl.circuit "SequenceOrderEquivalence" { // ----- -// Test case 6: Domain sequence inference +// Domain sequence inference firrtl.circuit "SequenceInference" { firrtl.domain @ClockDomain {} + firrtl.domain @PowerDomain {} firrtl.module @SequenceInference( in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, in %a: !firrtl.uint<1> domains [%A, %B], out %d: !firrtl.uint<1> ) { @@ -109,12 +111,10 @@ firrtl.circuit "SequenceInference" { firrtl.matchingconnect %d, %c : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "SequenceInference" -// CHECK: out %d: !firrtl.uint<1> domains [%A, %B] // ----- -// Test case 7: Domain duplicate equivalence - should be legal +// Domain duplicate equivalence - should be legal. firrtl.circuit "DuplicateDomainEquivalence" { firrtl.domain @ClockDomain {} firrtl.module @DuplicateDomainEquivalence( @@ -122,59 +122,33 @@ firrtl.circuit "DuplicateDomainEquivalence" { in %a: !firrtl.uint<1> domains [%A, %A], out %b: !firrtl.uint<1> domains [%A] ) { - // This should be legal since duplicate domains are canonicalized + // This should be legal since duplicate domains are canonicalized. firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "DuplicateDomainEquivalence" // ----- -// Test case 8: Unsafe domain cast with sequences +// Unsafe domain cast with sequences firrtl.circuit "UnsafeSequenceCast" { firrtl.domain @ClockDomain {} - firrtl.module @UnsafeSequenceCast( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - in %C: !firrtl.domain of @ClockDomain, - in %a: !firrtl.uint<1> domains [%A, %B], - out %c: !firrtl.uint<1> domains [%C] - ) { - %0 = firrtl.unsafe_domain_cast %a domains %C : !firrtl.uint<1> - firrtl.matchingconnect %c, %0 : !firrtl.uint<1> - } -} -// CHECK-LABEL: firrtl.circuit "UnsafeSequenceCast" -// CHECK: out %c: !firrtl.uint<1> domains [%C] + firrtl.domain @PowerDomain {} -// ----- - -// Test case 9: Multiple port domain inference -firrtl.circuit "MultiplePortInference" { - firrtl.domain @ClockDomain {} - firrtl.module @MultiplePortInference( - in %A: !firrtl.domain of @ClockDomain, - in %B: !firrtl.domain of @ClockDomain, - in %inputA: !firrtl.uint<1> domains [%A], - in %inputB: !firrtl.uint<1> domains [%B], - in %inputAB: !firrtl.uint<1> domains [%A, %B], - out %outputA: !firrtl.uint<1>, - out %outputB: !firrtl.uint<1>, - out %outputAB: !firrtl.uint<1> + firrtl.module @UnsafeSequenceCast( + in %C1: !firrtl.domain of @ClockDomain, + in %C2: !firrtl.domain of @ClockDomain, + in %P1: !firrtl.domain of @PowerDomain, + in %i: !firrtl.uint<1> domains [%C1, %P1], + out %o: !firrtl.uint<1> domains [%C2, %P1] ) { - firrtl.matchingconnect %outputA, %inputA : !firrtl.uint<1> - firrtl.matchingconnect %outputB, %inputB : !firrtl.uint<1> - firrtl.matchingconnect %outputAB, %inputAB : !firrtl.uint<1> + %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> + firrtl.matchingconnect %o, %0 : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "MultiplePortInference" -// CHECK: out %outputA: !firrtl.uint<1> domains [%A] -// CHECK: out %outputB: !firrtl.uint<1> domains [%B] -// CHECK: out %outputAB: !firrtl.uint<1> domains [%A, %B] // ----- -// Test case 10: Different port types domain inference +// Different port types domain inference. firrtl.circuit "DifferentPortTypes" { firrtl.domain @ClockDomain {} firrtl.module @DifferentPortTypes( @@ -188,18 +162,18 @@ firrtl.circuit "DifferentPortTypes" { firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> } } -// CHECK-LABEL: firrtl.circuit "DifferentPortTypes" -// CHECK: out %uint_output: !firrtl.uint<8> domains [%A] -// CHECK: out %sint_output: !firrtl.sint<4> domains [%A] // ----- -// Test case 11: Domain inference through wires +// Domain inference through wires. + +// CHECK-LABEL: DomainInferenceThroughWires firrtl.circuit "DomainInferenceThroughWires" { firrtl.domain @ClockDomain {} firrtl.module @DomainInferenceThroughWires( in %A: !firrtl.domain of @ClockDomain, in %input: !firrtl.uint<1> domains [%A], + // CHECK: out %output: !firrtl.uint<1> domains [%A] out %output: !firrtl.uint<1> ) { %wire1 = firrtl.wire : !firrtl.uint<1> @@ -210,18 +184,18 @@ firrtl.circuit "DomainInferenceThroughWires" { firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "DomainInferenceThroughWires" -// CHECK: out %output: !firrtl.uint<1> domains [%A] // ----- -// Test case 12: Register inference +// Register inference/ firrtl.circuit "RegisterInference" { firrtl.domain @ClockDomain {} firrtl.module @RegisterInference( - in %A: !firrtl.domain of @ClockDomain, - in %clock: !firrtl.clock domains [%A], - in %d: !firrtl.uint<1>, + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + // CHECK: in %d: !firrtl.uint<1> domains [%A] + in %d: !firrtl.uint<1>, + // CHECK: out %q: !firrtl.uint<1> domains [%A] out %q: !firrtl.uint<1> ) { %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> @@ -229,6 +203,3 @@ firrtl.circuit "RegisterInference" { firrtl.matchingconnect %q, %r : !firrtl.uint<1> } } -// CHECK-LABEL: firrtl.circuit "RegisterInference" -// CHECK: in %d: !firrtl.uint<1> domains [%A] -// CHECK: out %q: !firrtl.uint<1> domains [%A] From 5dc882bfd1461ae9da2955aa040bd232efa5df4d Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 1 Oct 2025 16:41:43 -0400 Subject: [PATCH 09/20] More stuff --- lib/Dialect/FIRRTL/Transforms/InferDomains.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index bffdbbb3f151..d31a5ca54155 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -366,6 +366,7 @@ class InferModuleDomains { LogicalResult processOp(Operation *); LogicalResult processOp(InstanceOp); LogicalResult processOp(UnsafeDomainCastOp); + LogicalResult processOp(DomainDefineOp); LogicalResult updateModule(FModuleOp); @@ -572,17 +573,23 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { return processOp(instance); if (auto cast = dyn_cast(op)) return processOp(cast); + if (auto def = dyn_cast(op)) + return processOp(def); // For all operations (including connections), propagate domains from operands // to results This is a conservative approach - all operands and results share // the same domain associations. Value lhs; for (auto rhs : op->getOperands()) { + if (!isa(rhs.getType())) + continue; if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; } for (auto rhs : op->getResults()) { + if (!isa(rhs.getType())) + continue; if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; @@ -652,6 +659,14 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { return success(); } +LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *term = getTermForDomain(src); + setTermForDomain(dst, term); + return success(); +} + LogicalResult InferModuleDomains::updateModule(FModuleOp op) { initializeExportTable(op); From 05a4b64979d7400d52ed4115adcbc7a4abcd182a Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 9 Oct 2025 17:42:39 -0400 Subject: [PATCH 10/20] Check in --- lib/Dialect/FIRRTL/FIRRTLOps.cpp | 28 ------------------- .../FIRRTL/Transforms/InferDomains.cpp | 16 ++--------- 2 files changed, 3 insertions(+), 41 deletions(-) diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index ce699a5b9b01..d772e391e876 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -4343,34 +4343,6 @@ LogicalResult DomainDefineOp::verify() { return success(); } -LogicalResult DomainDefineOp::verify() { - if (failed(checkConnectFlow(*this))) - return failure(); - - for (auto *user : getDest().getUsers()) { - auto connection = dyn_cast(user); - if (!connection || connection == *this || connection.getDest() != getDest()) - continue; - return emitError("destination domains cannot be reused by multiple " - "operations, it can only capture a unique dataflow"); - } - - return success(); -} - -LogicalResult DomainDefineOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // if (failed(verifyPortSymbolUses(*this, symbolTable))) - // return failure(); - - // auto circuitOp = getOperation()->getParentOfType(); - // for (auto layer : getLayers()) { - // if (!symbolTable.lookupSymbolIn(circuitOp, cast(layer))) - // return emitOpError() << "enables undefined layer '" << layer << "'"; - // } - - return success(); -} - void WhenOp::createElseRegion() { assert(!hasElseRegion() && "already has an else region"); getElseRegion().push_back(new Block()); diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index d31a5ca54155..4b9488253c28 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -738,10 +738,8 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { // If the output port is not driven, drive it. if (!driven) { auto loc = port.getLoc(); - auto typeName = getDomainPortTypeName(oldModuleDomainInfo, i); - auto typeRef = FlatSymbolRefAttr::get(typeName); auto value = val->value; - DomainDefineOp::create(builder, loc, port, value, typeRef); + DomainDefineOp::create(builder, loc, port, value); } } @@ -942,14 +940,10 @@ PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { } // Drive the exports. - auto domainInfo = module.getDomainInfoAttr(); auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); - for (auto [value, portIndex] : pendingExports) { auto port = module.getArgument(portIndex); - auto typeName = getDomainPortTypeName(domainInfo, portIndex); - auto typeNameRef = FlatSymbolRefAttr::get(typeName); - DomainDefineOp::create(builder, port.getLoc(), port, value, typeNameRef); + DomainDefineOp::create(builder, port.getLoc(), port, value); exportTable[value].push_back(port); setTermForDomain(port, allocate(value)); } @@ -980,7 +974,6 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { auto *context = op.getContext(); OpBuilder builder(context); builder.setInsertionPointAfter(op); - auto domainInfo = op.getDomainInfoAttr(); auto numPorts = op->getNumResults(); for (size_t i = 0; i < numPorts; ++i) { auto port = op.getResult(i); @@ -1002,10 +995,8 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { term = find(term); if (auto *val = dyn_cast(term)) { auto loc = port.getLoc(); - auto typeName = getDomainPortTypeName(domainInfo, i); - auto typeRef = FlatSymbolRefAttr::get(typeName); auto value = val->value; - DomainDefineOp::create(builder, loc, port, value, typeRef); + DomainDefineOp::create(builder, loc, port, value); } else { return op.emitError() << "unable to infer input domain value"; } @@ -1192,7 +1183,6 @@ void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, auto value = val->value; auto [name, rooted] = getFieldName(FieldRef(value, 0), false); out << name; - // out.attachNote(val->value.getLoc()) << name << " defined here"; return; } if (auto *row = dyn_cast(term)) { From 0be6951502be81e364c917688c54d822616dcf15 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 10 Oct 2025 09:19:32 -0400 Subject: [PATCH 11/20] Rename DomainIndex to DomainTypeID --- .../FIRRTL/Transforms/InferDomains.cpp | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 4b9488253c28..5543755198d1 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -86,10 +86,10 @@ static bool isPort(FModuleOp module, Value value) { /// Each declared domain in the circuit is assigned an index, based on the order /// in which it appears. Domain associations for hardware values are represented /// as a list of domains, sorted by the index of the domain type. -using DomainIndex = size_t; +using DomainTypeID = size_t; /// Information about the domains in the circuit. Able to map domains to their -/// domain-index, which in this pass is the canonical way to reference the type +/// type ID, which in this pass is the canonical way to reference the type /// of a domain. namespace { struct CircuitDomainInfo { @@ -101,22 +101,22 @@ struct CircuitDomainInfo { ArrayRef getDomains() const { return domainTable; } size_t getNumDomains() const { return domainTable.size(); } - DomainOp getDomain(DomainIndex id) const { return domainTable[id]; } + DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } - DomainIndex getDomainIndex(DomainOp op) const { - return indexTable.at(op.getNameAttr()); + DomainTypeID getDomainTypeID(DomainOp op) const { + return typeIDTable.at(op.getNameAttr()); } - DomainIndex getDomainIndex(StringAttr name) const { - return indexTable.at(name); + DomainTypeID getDomainTypeID(StringAttr name) const { + return typeIDTable.at(name); } - DomainIndex getDomainIndex(FlatSymbolRefAttr ref) const { - return getDomainIndex(ref.getAttr()); + DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const { + return getDomainTypeID(ref.getAttr()); } - DomainIndex getDomainIndex(ArrayAttr info, size_t i) const { + DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const { auto name = getDomainPortTypeName(info, i); - return getDomainIndex(name); + return getDomainTypeID(name); } - DomainIndex getDomainIndex(Value value) const { + DomainTypeID getDomainTypeID(Value value) const { assert(isa(value.getType())); if (auto arg = dyn_cast(value)) { auto *block = arg.getOwner(); @@ -124,7 +124,7 @@ struct CircuitDomainInfo { auto module = cast(owner); auto info = module.getDomainInfoAttr(); auto i = arg.getArgNumber(); - return getDomainIndex(info, i); + return getDomainTypeID(info, i); } auto result = dyn_cast(value); @@ -132,29 +132,27 @@ struct CircuitDomainInfo { auto instance = cast(owner); auto info = instance.getDomainInfoAttr(); auto i = result.getResultNumber(); - return getDomainIndex(info, i); + return getDomainTypeID(info, i); } - void clear() { - domainTable.clear(); - indexTable.clear(); +private: + void processDomain(DomainOp op) { + auto index = domainTable.size(); + auto name = op.getNameAttr(); + domainTable.push_back(op); + typeIDTable.insert({name, index}); } void processCircuit(CircuitOp circuit) { - clear(); for (auto decl : circuit.getOps()) processDomain(decl); } - void processDomain(DomainOp op) { - auto index = domainTable.size(); - auto name = op.getNameAttr(); - domainTable.push_back(op); - indexTable.insert({name, index}); - } - + /// A map from domain type ID to op. SmallVector domainTable; - DenseMap indexTable; + + /// A map from domain name to type ID. + DenseMap typeIDTable; }; } // namespace @@ -515,14 +513,14 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { auto numPorts = module.getNumPorts(); // Process module ports - domain ports define explicit domains. - DenseMap domainIndexTable; + DenseMap domainTypeIDTable; for (size_t i = 0; i < numPorts; ++i) { BlockArgument port = module.getArgument(i); // This is a domain port. if (isa(port.getType())) { - auto index = circuitInfo.getDomainIndex(portDomainInfo, i); - domainIndexTable[i] = index; + auto typeID = circuitInfo.getDomainTypeID(portDomainInfo, i); + domainTypeIDTable[i] = typeID; if (module.getPortDirection(i) == Direction::In) { setTermForDomain(port, allocate(port)); } @@ -537,15 +535,15 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { SmallVector elements(circuitInfo.getNumDomains()); for (auto domainPortIndexAttr : portDomains.getAsRange()) { auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto domainIndex = domainIndexTable[domainPortIndex]; + auto domainTypeID = domainTypeIDTable[domainPortIndex]; auto domainValue = module.getArgument(domainPortIndex); auto *term = getTermForDomain(domainValue); - auto &slot = elements[domainIndex]; + auto &slot = elements[domainTypeID]; if (failed(unify(slot, term))) { - emitPortDomainCrossingError(port, domainIndex, slot, term); + emitPortDomainCrossingError(port, domainTypeID, slot, term); return failure(); } - elements[domainIndex] = term; + elements[domainTypeID] = term; } auto *row = allocateRow(elements); setDomainAssociation(port, row); @@ -576,9 +574,9 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { if (auto def = dyn_cast(op)) return processOp(def); - // For all operations (including connections), propagate domains from operands - // to results This is a conservative approach - all operands and results share - // the same domain associations. + // For all other operations (including connections), propagate domains from + // operands to results This is a conservative approach - all operands and + // results share the same domain associations. Value lhs; for (auto rhs : op->getOperands()) { if (!isa(rhs.getType())) @@ -606,7 +604,7 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); if (isa(port.getType())) { - auto typeID = circuitInfo.getDomainIndex(domainInfo, i); + auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); domainPortTypeIDTable[i] = typeID; if (op.getPortDirection(i) == Direction::Out) { setTermForDomain(port, allocate(port)); @@ -650,7 +648,7 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { RowTerm *inputRow = getDomainAssociationAsRow(input); SmallVector elements(inputRow->elements); for (auto domain : op.getDomains()) { - auto index = circuitInfo.getDomainIndex(domain); + auto index = circuitInfo.getDomainTypeID(domain); elements[index] = getTermForDomain(domain); } @@ -750,10 +748,10 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { if (isa(type)) { auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); auto *row = getDomainAssociationAsRow(port); - for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { - if (associations[domainTypeID]) + for (size_t DomainTypeID = 0; DomainTypeID < numDomains; ++DomainTypeID) { + if (associations[DomainTypeID]) continue; - auto domain = cast(find(row->elements[domainTypeID]))->value; + auto domain = cast(find(row->elements[DomainTypeID]))->value; size_t domainPortIndex = 0; if (auto arg = dyn_cast(domain)) { if (arg.getOwner()->getParentOp() == module) { @@ -781,7 +779,7 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { return failure(); } } - associations[domainTypeID] = IntegerAttr::get( + associations[DomainTypeID] = IntegerAttr::get( IntegerType::get(context, 32, IntegerType::Unsigned), domainPortIndex); } @@ -806,9 +804,9 @@ InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); for (auto domainPortIndexAttr : oldAssociations.getAsRange()) { auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto domainTypeID = - circuitInfo.getDomainIndex(moduleDomainInfo, domainPortIndex); - result[domainTypeID] = domainPortIndexAttr; + auto DomainTypeID = + circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); + result[DomainTypeID] = domainPortIndexAttr; }; return result; } @@ -1209,11 +1207,11 @@ void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, VariableIDTable idTable; auto portIndex = port.getArgNumber(); - auto domainType = circuitInfo.getDomain(domainTypeID); - auto domainTypeName = domainType.getName(); + auto domainOp = circuitInfo.getDomain(domainTypeID); + auto domainName = domainOp.getName(); auto diag = emitError(port.getLoc()); - diag << "illegal " << domainTypeName << " crossing in port #" << portIndex; + diag << "illegal " << domainName << " crossing in port #" << portIndex; auto ¬e1 = diag.attachNote(); note1 << "1st instance: "; From 7795379e331a0abfc524e98d0e262bf5e9cd610b Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 10 Oct 2025 17:05:08 -0400 Subject: [PATCH 12/20] WIP --- .../circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h | 22 +++ include/circt/Support/InstanceGraph.h | 7 + .../FIRRTL/Transforms/InferDomains.cpp | 134 +++++++++++++++--- 3 files changed, 142 insertions(+), 21 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h index e8339dd7320b..44a9903b5520 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h @@ -85,6 +85,28 @@ struct PortInfo { annotations(annos), domains(domains) {} }; +inline bool operator==(const PortInfo &lhs, const PortInfo &rhs) { + if (lhs.name != rhs.name) + return false; + if (lhs.type != rhs.type) + return false; + if (lhs.direction != rhs.direction) + return false; + if (lhs.sym != rhs.sym) + return false; + if (lhs.loc != rhs.loc) + return false; + if (lhs.annotations != rhs.annotations) + return false; + if (lhs.domains != rhs.domains) + return false; + return true; +} + +inline bool operator!=(const PortInfo &lhs, const PortInfo &rhs) { + return !(lhs == rhs); +} + enum class ConnectBehaviorKind { /// Classic FIRRTL connections: last connect 'wins' across paths; /// conditionally applied under 'when'. diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h index a26cf5ddce77..70fdcf2e67f9 100644 --- a/include/circt/Support/InstanceGraph.h +++ b/include/circt/Support/InstanceGraph.h @@ -59,6 +59,11 @@ class InstanceGraphNode; class InstanceRecord : public llvm::ilist_node_with_parent { public: + /// Get the op that this is tracking. + Operation *getOperation() { + return instance.getOperation(); + } + /// Get the instance-like op that this is tracking. template auto getInstance() { @@ -113,6 +118,8 @@ class InstanceGraphNode : public llvm::ilist_node { public: InstanceGraphNode() : module(nullptr) {} + Operation *getOperation() { return module.getOperation(); } + /// Get the module that this node is tracking. template auto getModule() { diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 5543755198d1..e8e5e9ca44eb 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -156,6 +156,59 @@ struct CircuitDomainInfo { }; } // namespace +//====-------------------------------------------------------------------------- +// ModuleUpdateInfo: Summary of about how a module's interface has changed. +//====-------------------------------------------------------------------------- + +namespace { +struct ModuleUpdateInfo { + ArrayAttr portDomainInfo; + PortInsertions portInsertions; +}; +} // namespace + +static bool operator==(const ModuleUpdateInfo &lhs, + const ModuleUpdateInfo &rhs) { + if (lhs.portDomainInfo != rhs.portDomainInfo) + return false; + + if (lhs.portInsertions.size() != rhs.portInsertions.size()) + return false; + + for (size_t i = 0, e = lhs.portInsertions.size(); i < e; ++i) { + if (lhs.portInsertions != rhs.portInsertions) + return false; + } + + return true; +} + +static void updateInstance(OpBuilder &builder, const ModuleUpdateInfo &info, + InstanceOp op) { + auto clone = op.insertPorts(info.portInsertions); + clone.setDomainInfoAttr(info.portDomainInfo); +} + +static void updateInstance(OpBuilder &builder, const ModuleUpdateInfo &info, + InstanceChoiceOp op) { + auto clone = op.insertPorts(info.portInsertions); + clone.setDomainInfoAttr(info.portDomainInfo); +} + +using ModuleUpdateInfoTable = DenseMap; + +LogicalResult updateInstance(const ModuleUpdateInfoTable &table, + InstanceOp op) { + return success(); +} + +LogicalResult updateInstance(const ModuleUpdateInfoTable &table, + InstanceChoiceOp op) { + // verify that all modules have the same update. + auto _mod = op.getDefaultTargetAttr(); + return success(); +} + //====-------------------------------------------------------------------------- // Terms: Syntax for unifying domain and domain-rows. //====-------------------------------------------------------------------------- @@ -343,14 +396,14 @@ namespace { class InferModuleDomains { public: /// Run infer-domains on a module. - static LogicalResult run(const CircuitDomainInfo &, FModuleOp); + static LogicalResult run(const CircuitDomainInfo &, InstanceGraphNode *node); private: /// Initialize module-level state. InferModuleDomains(const CircuitDomainInfo &); /// Execute on the given module. - LogicalResult operator()(FModuleOp); + LogicalResult operator()(InstanceGraphNode *node); /// Record the domain associations of hardware ports, and record the /// underlying value of output domain ports. @@ -366,10 +419,10 @@ class InferModuleDomains { LogicalResult processOp(UnsafeDomainCastOp); LogicalResult processOp(DomainDefineOp); - LogicalResult updateModule(FModuleOp); + LogicalResult updateModule(FModuleOp, PortInsertions &); /// Build a table of exported domains: a map from domains defined internally, - /// to their aliasing output ports. + /// to their set of aliasing output ports. void initializeExportTable(FModuleOp); /// After generalizing the module, all domains should be solved. Reflect the @@ -396,7 +449,10 @@ class InferModuleDomains { /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. - PortInsertions generalizeModule(FModuleOp); + PortInsertions generalizeModule(FModuleOp, PortInsertions &); + + LogicalResult updateInstances(InstanceGraphNode *node, + const PortInsertions &); void generalizeInstance(InstanceOp, const PortInsertions &); @@ -472,14 +528,18 @@ class InferModuleDomains { } // namespace LogicalResult InferModuleDomains::run(const CircuitDomainInfo &circuitInfo, - FModuleOp module) { - return InferModuleDomains(circuitInfo)(module); + InstanceGraphNode *node) { + return InferModuleDomains(circuitInfo)(node); } InferModuleDomains::InferModuleDomains(const CircuitDomainInfo &circuitInfo) : circuitInfo(circuitInfo) {} -LogicalResult InferModuleDomains::operator()(FModuleOp module) { +LogicalResult InferModuleDomains::operator()(InstanceGraphNode *node) { + auto module = dyn_cast(node->getOperation()); + if (!module) + return success(); + LLVM_DEBUG( llvm::errs() << "================================================\n"; llvm::errs() << "infer module domains: " << module.getModuleName() @@ -498,12 +558,12 @@ LogicalResult InferModuleDomains::operator()(FModuleOp module) { llvm::errs() << " " << association.second << "\n"; }); - // PortInsertions insertions; - if (failed(updateModule(module))) + PortInsertions insertions; + if (failed(updateModule(module, insertions))) return failure(); - // if (failed(updateInstances(insertions))) - // return failure(); + if (failed(updateInstances(node, insertions))) + return failure(); return llvm::success(ok); } @@ -665,10 +725,11 @@ LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { return success(); } -LogicalResult InferModuleDomains::updateModule(FModuleOp op) { +LogicalResult InferModuleDomains::updateModule(FModuleOp op, + PortInsertions &insertions) { initializeExportTable(op); - auto insertions = generalizeModule(op); + generalizeModule(op, insertions); if (failed(updatePortDomainAssociations(op))) return failure(); @@ -842,13 +903,14 @@ LogicalResult InferModuleDomains::getExportingPort(FModuleOp module, return success(); } -PortInsertions InferModuleDomains::generalizeModule(FModuleOp module) { +PortInsertions +InferModuleDomains::generalizeModule(FModuleOp module, + PortInsertions &insertions) { // If the port is hardware, we have to check the associated row of // domains. If any associated domain is a variable, we solve the variable // by generalizing the module with an additional input domain port. If any // associated domain is defined internally to the module, we have to add // an output domain port, to allow the domain to escape. - SmallVector> insertions; DenseMap pendingSolutions; llvm::MapVector pendingExports; @@ -1005,6 +1067,38 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { return success(); } +LogicalResult +InferModuleDomains::updateInstances(InstanceGraphNode *node, + const PortInsertions &insertions) { + // for (auto *use : node->uses()) { + // auto *op = use->getOperation(); + // if (auto inst = dyn_cast(op)) + // if (failed(updateInstance(inst, insertions))) + // return failure(); + + // if (auto inst = dyn_cast(op)) + // if (failed(updateInstance(inst, insertions))) + // return failure(); + + // op->emitError("don't know how to update this instances"); + // return failure(); + // } + return success(); +} + +// LogicalResult +// InferModuleDomains::updateInstance(InstanceOp op, +// const PortInsertions &insertions) { +// op.cloneAndInsertPorts(insertions); +// } + +// LogicalResult +// InferModuleDomains::updateInstance(InstanceChoiceOp op, +// const PortInsertions &insertions) { +// /// Only update the instance choice op if we are the default. +// if (auto) +// } + LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, Value rhs) { LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; @@ -1242,11 +1336,9 @@ void InferDomainsPass::runOnOperation() { for (auto *root : instanceGraph) { for (auto *node : llvm::post_order_ext(root, visited)) { auto *op = node->getModule(); - if (auto module = llvm::dyn_cast_if_present(op)) { - if (failed(InferModuleDomains::run(circuitInfo, module))) { - signalPassFailure(); - return; - } + if (failed(InferModuleDomains::run(circuitInfo, node))) { + signalPassFailure(); + return; } } } From 7b8d930de3e1132b4638233844e73c890b630218 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 15 Oct 2025 16:10:21 -0400 Subject: [PATCH 13/20] More fixes --- .../FIRRTL/Transforms/InferDomains.cpp | 214 +++++++++--------- test/Dialect/FIRRTL/infer-domains-errors.mlir | 6 +- test/Dialect/FIRRTL/infer-domains.mlir | 30 +-- 3 files changed, 120 insertions(+), 130 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index e8e5e9ca44eb..efa3943f297e 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -60,10 +60,10 @@ static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { /// From a domain info attribute, get the row of associated domains for a /// hardware value at index i. -static ArrayAttr getPortDomainAssociation(ArrayAttr info, size_t i) { +static auto getPortDomainAssociation(ArrayAttr info, size_t i) { if (info.empty()) - return info; - return cast(info[i]); + return info.getAsRange(); + return cast(info[i]).getAsRange(); } /// Return true if the value is a port on the module. @@ -161,8 +161,12 @@ struct CircuitDomainInfo { //====-------------------------------------------------------------------------- namespace { +/// Information about the changes made to the interface of a module, which can +/// be replayed onto an instance. struct ModuleUpdateInfo { + /// The updated domain information for a module. ArrayAttr portDomainInfo; + /// The domain ports which have been inserted into a module. PortInsertions portInsertions; }; } // namespace @@ -183,32 +187,6 @@ static bool operator==(const ModuleUpdateInfo &lhs, return true; } -static void updateInstance(OpBuilder &builder, const ModuleUpdateInfo &info, - InstanceOp op) { - auto clone = op.insertPorts(info.portInsertions); - clone.setDomainInfoAttr(info.portDomainInfo); -} - -static void updateInstance(OpBuilder &builder, const ModuleUpdateInfo &info, - InstanceChoiceOp op) { - auto clone = op.insertPorts(info.portInsertions); - clone.setDomainInfoAttr(info.portDomainInfo); -} - -using ModuleUpdateInfoTable = DenseMap; - -LogicalResult updateInstance(const ModuleUpdateInfoTable &table, - InstanceOp op) { - return success(); -} - -LogicalResult updateInstance(const ModuleUpdateInfoTable &table, - InstanceChoiceOp op) { - // verify that all modules have the same update. - auto _mod = op.getDefaultTargetAttr(); - return success(); -} - //====-------------------------------------------------------------------------- // Terms: Syntax for unifying domain and domain-rows. //====-------------------------------------------------------------------------- @@ -413,13 +391,24 @@ class InferModuleDomains { /// value of domains, defined within the body of the module. LogicalResult processBody(FModuleOp); - /// Record the domain associations of any operands or results. + /// Record the domain associations of any operands or results, updating the op + /// if necessary. LogicalResult processOp(Operation *); LogicalResult processOp(InstanceOp); + LogicalResult processOp(InstanceChoiceOp); LogicalResult processOp(UnsafeDomainCastOp); LogicalResult processOp(DomainDefineOp); + LogicalResult processOp(WhenOp); + + /// Apply the port changes of a module onto an instance-like op. + template + T updateInstancePorts(T op, const ModuleUpdateInfo &update); + + /// Record the domain associations of the ports of an instance-like op. + template + void processInstancePorts(T op); - LogicalResult updateModule(FModuleOp, PortInsertions &); + LogicalResult updateModule(FModuleOp); /// Build a table of exported domains: a map from domains defined internally, /// to their set of aliasing output ports. @@ -449,10 +438,7 @@ class InferModuleDomains { /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. - PortInsertions generalizeModule(FModuleOp, PortInsertions &); - - LogicalResult updateInstances(InstanceGraphNode *node, - const PortInsertions &); + void generalizeModule(FModuleOp); void generalizeInstance(InstanceOp, const PortInsertions &); @@ -487,6 +473,10 @@ class InferModuleDomains { /// domains. If no mapping has been defined, returns nullptr. Term *getOptDomainAssociation(Value) const; + /// Get the changes made to a module, by the module's name. + const ModuleUpdateInfo &getModuleUpdateInfo(StringAttr name) const; + const ModuleUpdateInfo &getModuleUpdateInfo(FlatSymbolRefAttr ref) const; + /// Allocate a row, where each domain is a variable. RowTerm *allocateRow(); @@ -522,6 +512,9 @@ class InferModuleDomains { /// A map from local domain definition to its aliasing output ports. DenseMap> exportTable; + /// A map from module name to the updates made to it. + DenseMap moduleUpdateInfoTable; + /// A boolean tracking if a non-fatal error occurred, or not. bool ok = true; }; @@ -558,11 +551,7 @@ LogicalResult InferModuleDomains::operator()(InstanceGraphNode *node) { llvm::errs() << " " << association.second << "\n"; }); - PortInsertions insertions; - if (failed(updateModule(module, insertions))) - return failure(); - - if (failed(updateInstances(node, insertions))) + if (failed(updateModule(module))) return failure(); return llvm::success(ok); @@ -593,7 +582,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { continue; SmallVector elements(circuitInfo.getNumDomains()); - for (auto domainPortIndexAttr : portDomains.getAsRange()) { + for (auto domainPortIndexAttr : portDomains) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto domainTypeID = domainTypeIDTable[domainPortIndex]; auto domainValue = module.getArgument(domainPortIndex); @@ -635,7 +624,7 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { return processOp(def); // For all other operations (including connections), propagate domains from - // operands to results This is a conservative approach - all operands and + // operands to results. This is a conservative approach - all operands and // results share the same domain associations. Value lhs; for (auto rhs : op->getOperands()) { @@ -656,6 +645,58 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { } LogicalResult InferModuleDomains::processOp(InstanceOp op) { + const auto &update = getModuleUpdateInfo(op.getReferencedModuleNameAttr()); + op = updateInstancePorts(op, update); + processInstancePorts(op); + return success(); +} + +LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { + const auto &update = getModuleUpdateInfo(op.getDefaultTargetAttr().getAttr()); + op = updateInstancePorts(op, update); + processInstancePorts(op); + return success(); +} + +LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(op, op.getInput(), op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto index = circuitInfo.getDomainTypeID(domain); + elements[index] = getTermForDomain(domain); + } + + auto *row = allocateRow(elements); + setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *term = getTermForDomain(src); + setTermForDomain(dst, term); + return success(); +} + +LogicalResult InferModuleDomains::processOp(WhenOp op) { return failure(); } + +template +T InferModuleDomains::updateInstancePorts(T op, + const ModuleUpdateInfo &update) { + auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); + clone.setDomainInfoAttr(update.portDomainInfo); + op->erase(); + return clone; +} + +template +void InferModuleDomains::processInstancePorts(T op) { DenseMap domainPortTypeIDTable; auto domainInfo = op.getDomainInfoAttr(); for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { @@ -679,8 +720,7 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { // port with a row of domains, where each element is derived from the domain // associations recorded in the domain info attribute of the instance. SmallVector elements(circuitInfo.getNumDomains()); - auto associations = - getPortDomainAssociation(domainInfo, i).getAsRange(); + auto associations = getPortDomainAssociation(domainInfo, i); for (auto domainPortIndexAttr : associations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto typeID = domainPortTypeIDTable[domainPortIndex]; @@ -695,41 +735,12 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { setDomainAssociation(port, allocateRow(elements)); } - - return success(); } -LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { - auto domains = op.getDomains(); - if (domains.empty()) - return unifyAssociations(op, op.getInput(), op.getResult()); - - auto input = op.getInput(); - RowTerm *inputRow = getDomainAssociationAsRow(input); - SmallVector elements(inputRow->elements); - for (auto domain : op.getDomains()) { - auto index = circuitInfo.getDomainTypeID(domain); - elements[index] = getTermForDomain(domain); - } - - auto *row = allocateRow(elements); - setDomainAssociation(op.getResult(), row); - return success(); -} - -LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { - auto src = op.getSrc(); - auto dst = op.getDest(); - auto *term = getTermForDomain(src); - setTermForDomain(dst, term); - return success(); -} - -LogicalResult InferModuleDomains::updateModule(FModuleOp op, - PortInsertions &insertions) { +LogicalResult InferModuleDomains::updateModule(FModuleOp op) { initializeExportTable(op); - generalizeModule(op, insertions); + generalizeModule(op); if (failed(updatePortDomainAssociations(op))) return failure(); @@ -863,11 +874,11 @@ InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, size_t portIndex) { SmallVector result(circuitInfo.getNumDomains()); auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); - for (auto domainPortIndexAttr : oldAssociations.getAsRange()) { + for (auto domainPortIndexAttr : oldAssociations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); - auto DomainTypeID = + auto domainTypeID = circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); - result[DomainTypeID] = domainPortIndexAttr; + result[domainTypeID] = domainPortIndexAttr; }; return result; } @@ -903,9 +914,8 @@ LogicalResult InferModuleDomains::getExportingPort(FModuleOp module, return success(); } -PortInsertions -InferModuleDomains::generalizeModule(FModuleOp module, - PortInsertions &insertions) { +void InferModuleDomains::generalizeModule(FModuleOp module) { + PortInsertions insertions; // If the port is hardware, we have to check the associated row of // domains. If any associated domain is a variable, we solve the variable // by generalizing the module with an additional input domain port. If any @@ -990,7 +1000,7 @@ InferModuleDomains::generalizeModule(FModuleOp module, } } - // Put the ports in place. + // Put the domain ports in place. module.insertPorts(insertions); // Solve the variables. @@ -1008,7 +1018,9 @@ InferModuleDomains::generalizeModule(FModuleOp module, setTermForDomain(port, allocate(value)); } - return insertions; + // Record the insertions, so we can replay them on instances later. + auto &info = moduleUpdateInfoTable[module.getNameAttr()]; + info.portInsertions = std::move(insertions); } LogicalResult @@ -1067,37 +1079,15 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { return success(); } -LogicalResult -InferModuleDomains::updateInstances(InstanceGraphNode *node, - const PortInsertions &insertions) { - // for (auto *use : node->uses()) { - // auto *op = use->getOperation(); - // if (auto inst = dyn_cast(op)) - // if (failed(updateInstance(inst, insertions))) - // return failure(); - - // if (auto inst = dyn_cast(op)) - // if (failed(updateInstance(inst, insertions))) - // return failure(); - - // op->emitError("don't know how to update this instances"); - // return failure(); - // } - return success(); +const ModuleUpdateInfo & +InferModuleDomains::getModuleUpdateInfo(StringAttr name) const { + return moduleUpdateInfoTable.at(name); } -// LogicalResult -// InferModuleDomains::updateInstance(InstanceOp op, -// const PortInsertions &insertions) { -// op.cloneAndInsertPorts(insertions); -// } - -// LogicalResult -// InferModuleDomains::updateInstance(InstanceChoiceOp op, -// const PortInsertions &insertions) { -// /// Only update the instance choice op if we are the default. -// if (auto) -// } +const ModuleUpdateInfo & +InferModuleDomains::getModuleUpdateInfo(FlatSymbolRefAttr ref) const { + return getModuleUpdateInfo(ref.getAttr()); +} LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, Value rhs) { diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index 74a70cfd0714..c396abf527d1 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -2,7 +2,7 @@ // Port annotated with same domain type twice. firrtl.circuit "DomainCrossOnPort" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @DomainCrossOnPort( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain, @@ -17,7 +17,7 @@ firrtl.circuit "DomainCrossOnPort" { // Illegal domain crossing - connect op. firrtl.circuit "IllegalDomainCrossing" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @IllegalDomainCrossing( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain, @@ -35,7 +35,7 @@ firrtl.circuit "IllegalDomainCrossing" { // Illegal domain crossing at matchingconnect op. firrtl.circuit "IllegalDomainCrossing" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @IllegalDomainCrossing( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain, diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index 0894a0425917..5c7d20eef05b 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -2,7 +2,7 @@ // Legal domain usage - no crossing. firrtl.circuit "LegalDomains" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @LegalDomains( in %A: !firrtl.domain of @ClockDomain, in %a: !firrtl.uint<1> domains [%A], @@ -18,7 +18,7 @@ firrtl.circuit "LegalDomains" { // Domain inference through connections. firrtl.circuit "DomainInference" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @DomainInference( in %A: !firrtl.domain of @ClockDomain, in %a: !firrtl.uint<1> domains [%A], @@ -40,7 +40,7 @@ firrtl.circuit "DomainInference" { // Unsafe domain cast firrtl.circuit "UnsafeDomainCast" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @UnsafeDomainCast( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain, @@ -60,8 +60,8 @@ firrtl.circuit "UnsafeDomainCast" { // Domain sequence matching. firrtl.circuit "LegalSequences" { - firrtl.domain @ClockDomain {} - firrtl.domain @PowerDomain {} + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain firrtl.module @LegalSequences( in %C: !firrtl.domain of @ClockDomain, in %P: !firrtl.domain of @PowerDomain, @@ -76,8 +76,8 @@ firrtl.circuit "LegalSequences" { // Domain sequence order equivalence - should be legal firrtl.circuit "SequenceOrderEquivalence" { - firrtl.domain @ClockDomain {} - firrtl.domain @PowerDomain {} + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain firrtl.module @SequenceOrderEquivalence( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @PowerDomain, @@ -94,8 +94,8 @@ firrtl.circuit "SequenceOrderEquivalence" { // Domain sequence inference firrtl.circuit "SequenceInference" { - firrtl.domain @ClockDomain {} - firrtl.domain @PowerDomain {} + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain firrtl.module @SequenceInference( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @PowerDomain, @@ -116,7 +116,7 @@ firrtl.circuit "SequenceInference" { // Domain duplicate equivalence - should be legal. firrtl.circuit "DuplicateDomainEquivalence" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @DuplicateDomainEquivalence( in %A: !firrtl.domain of @ClockDomain, in %a: !firrtl.uint<1> domains [%A, %A], @@ -131,8 +131,8 @@ firrtl.circuit "DuplicateDomainEquivalence" { // Unsafe domain cast with sequences firrtl.circuit "UnsafeSequenceCast" { - firrtl.domain @ClockDomain {} - firrtl.domain @PowerDomain {} + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain firrtl.module @UnsafeSequenceCast( in %C1: !firrtl.domain of @ClockDomain, @@ -150,7 +150,7 @@ firrtl.circuit "UnsafeSequenceCast" { // Different port types domain inference. firrtl.circuit "DifferentPortTypes" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @DifferentPortTypes( in %A: !firrtl.domain of @ClockDomain, in %uint_input: !firrtl.uint<8> domains [%A], @@ -169,7 +169,7 @@ firrtl.circuit "DifferentPortTypes" { // CHECK-LABEL: DomainInferenceThroughWires firrtl.circuit "DomainInferenceThroughWires" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @DomainInferenceThroughWires( in %A: !firrtl.domain of @ClockDomain, in %input: !firrtl.uint<1> domains [%A], @@ -189,7 +189,7 @@ firrtl.circuit "DomainInferenceThroughWires" { // Register inference/ firrtl.circuit "RegisterInference" { - firrtl.domain @ClockDomain {} + firrtl.domain @ClockDomain firrtl.module @RegisterInference( in %A: !firrtl.domain of @ClockDomain, in %clock: !firrtl.clock domains [%A], From a3dd2916a593d778c39d954cc49f0d0b9b6d0721 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 15 Oct 2025 17:15:15 -0400 Subject: [PATCH 14/20] More updates --- .../FIRRTL/Transforms/InferDomains.cpp | 141 +++++++----------- test/Dialect/FIRRTL/infer-domains.mlir | 45 +++--- 2 files changed, 82 insertions(+), 104 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index efa3943f297e..f5f89fc2c9df 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -80,7 +80,7 @@ static bool isPort(FModuleOp module, Value value) { } //====-------------------------------------------------------------------------- -// CircuitDomainInfo: Information about the domains declared in a circuit. +// Circuit-wide state. //====-------------------------------------------------------------------------- /// Each declared domain in the circuit is assigned an index, based on the order @@ -93,11 +93,7 @@ using DomainTypeID = size_t; /// of a domain. namespace { struct CircuitDomainInfo { - static CircuitDomainInfo get(CircuitOp circuit) { - CircuitDomainInfo info; - info.processCircuit(circuit); - return info; - } + CircuitDomainInfo(CircuitOp circuit) { processCircuit(circuit); } ArrayRef getDomains() const { return domainTable; } size_t getNumDomains() const { return domainTable.size(); } @@ -106,16 +102,20 @@ struct CircuitDomainInfo { DomainTypeID getDomainTypeID(DomainOp op) const { return typeIDTable.at(op.getNameAttr()); } + DomainTypeID getDomainTypeID(StringAttr name) const { return typeIDTable.at(name); } + DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const { return getDomainTypeID(ref.getAttr()); } + DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const { auto name = getDomainPortTypeName(info, i); return getDomainTypeID(name); } + DomainTypeID getDomainTypeID(Value value) const { assert(isa(value.getType())); if (auto arg = dyn_cast(value)) { @@ -154,13 +154,7 @@ struct CircuitDomainInfo { /// A map from domain name to type ID. DenseMap typeIDTable; }; -} // namespace -//====-------------------------------------------------------------------------- -// ModuleUpdateInfo: Summary of about how a module's interface has changed. -//====-------------------------------------------------------------------------- - -namespace { /// Information about the changes made to the interface of a module, which can /// be replayed onto an instance. struct ModuleUpdateInfo { @@ -169,23 +163,15 @@ struct ModuleUpdateInfo { /// The domain ports which have been inserted into a module. PortInsertions portInsertions; }; -} // namespace -static bool operator==(const ModuleUpdateInfo &lhs, - const ModuleUpdateInfo &rhs) { - if (lhs.portDomainInfo != rhs.portDomainInfo) - return false; +struct GlobalState { + GlobalState(CircuitOp circuit) : circuitInfo(circuit) {} - if (lhs.portInsertions.size() != rhs.portInsertions.size()) - return false; - - for (size_t i = 0, e = lhs.portInsertions.size(); i < e; ++i) { - if (lhs.portInsertions != rhs.portInsertions) - return false; - } + CircuitDomainInfo circuitInfo; + DenseMap moduleUpdateTable; +}; - return true; -} +} // namespace //====-------------------------------------------------------------------------- // Terms: Syntax for unifying domain and domain-rows. @@ -374,14 +360,14 @@ namespace { class InferModuleDomains { public: /// Run infer-domains on a module. - static LogicalResult run(const CircuitDomainInfo &, InstanceGraphNode *node); + static LogicalResult run(GlobalState &, FModuleOp); private: /// Initialize module-level state. - InferModuleDomains(const CircuitDomainInfo &); + InferModuleDomains(GlobalState &); /// Execute on the given module. - LogicalResult operator()(InstanceGraphNode *node); + LogicalResult operator()(FModuleOp); /// Record the domain associations of hardware ports, and record the /// underlying value of output domain ports. @@ -473,10 +459,6 @@ class InferModuleDomains { /// domains. If no mapping has been defined, returns nullptr. Term *getOptDomainAssociation(Value) const; - /// Get the changes made to a module, by the module's name. - const ModuleUpdateInfo &getModuleUpdateInfo(StringAttr name) const; - const ModuleUpdateInfo &getModuleUpdateInfo(FlatSymbolRefAttr ref) const; - /// Allocate a row, where each domain is a variable. RowTerm *allocateRow(); @@ -498,7 +480,7 @@ class InferModuleDomains { void emitPortDomainCrossingError(BlockArgument, size_t, Term *, Term *) const; /// Information about the domains in a circuit. - const CircuitDomainInfo &circuitInfo; + GlobalState &globals; /// Term allocator. llvm::BumpPtrAllocator allocator; @@ -512,27 +494,19 @@ class InferModuleDomains { /// A map from local domain definition to its aliasing output ports. DenseMap> exportTable; - /// A map from module name to the updates made to it. - DenseMap moduleUpdateInfoTable; - /// A boolean tracking if a non-fatal error occurred, or not. bool ok = true; }; } // namespace -LogicalResult InferModuleDomains::run(const CircuitDomainInfo &circuitInfo, - InstanceGraphNode *node) { - return InferModuleDomains(circuitInfo)(node); +LogicalResult InferModuleDomains::run(GlobalState &globals, FModuleOp module) { + return InferModuleDomains(globals)(module); } -InferModuleDomains::InferModuleDomains(const CircuitDomainInfo &circuitInfo) - : circuitInfo(circuitInfo) {} - -LogicalResult InferModuleDomains::operator()(InstanceGraphNode *node) { - auto module = dyn_cast(node->getOperation()); - if (!module) - return success(); +InferModuleDomains::InferModuleDomains(GlobalState &globals) + : globals(globals) {} +LogicalResult InferModuleDomains::operator()(FModuleOp module) { LLVM_DEBUG( llvm::errs() << "================================================\n"; llvm::errs() << "infer module domains: " << module.getModuleName() @@ -568,7 +542,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { // This is a domain port. if (isa(port.getType())) { - auto typeID = circuitInfo.getDomainTypeID(portDomainInfo, i); + auto typeID = globals.circuitInfo.getDomainTypeID(portDomainInfo, i); domainTypeIDTable[i] = typeID; if (module.getPortDirection(i) == Direction::In) { setTermForDomain(port, allocate(port)); @@ -581,7 +555,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { if (portDomains.empty()) continue; - SmallVector elements(circuitInfo.getNumDomains()); + SmallVector elements(globals.circuitInfo.getNumDomains()); for (auto domainPortIndexAttr : portDomains) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto domainTypeID = domainTypeIDTable[domainPortIndex]; @@ -645,14 +619,16 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { } LogicalResult InferModuleDomains::processOp(InstanceOp op) { - const auto &update = getModuleUpdateInfo(op.getReferencedModuleNameAttr()); + const auto &update = + globals.moduleUpdateTable.at(op.getReferencedModuleNameAttr()); op = updateInstancePorts(op, update); processInstancePorts(op); return success(); } LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { - const auto &update = getModuleUpdateInfo(op.getDefaultTargetAttr().getAttr()); + const auto &update = + globals.moduleUpdateTable.at(op.getDefaultTargetAttr().getAttr()); op = updateInstancePorts(op, update); processInstancePorts(op); return success(); @@ -667,8 +643,8 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { RowTerm *inputRow = getDomainAssociationAsRow(input); SmallVector elements(inputRow->elements); for (auto domain : op.getDomains()) { - auto index = circuitInfo.getDomainTypeID(domain); - elements[index] = getTermForDomain(domain); + auto typeID = globals.circuitInfo.getDomainTypeID(domain); + elements[typeID] = getTermForDomain(domain); } auto *row = allocateRow(elements); @@ -705,7 +681,7 @@ void InferModuleDomains::processInstancePorts(T op) { LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); if (isa(port.getType())) { - auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); + auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); domainPortTypeIDTable[i] = typeID; if (op.getPortDirection(i) == Direction::Out) { setTermForDomain(port, allocate(port)); @@ -719,7 +695,7 @@ void InferModuleDomains::processInstancePorts(T op) { // This is a port, which may have explicit domain information. Associate the // port with a row of domains, where each element is derived from the domain // associations recorded in the domain info attribute of the instance. - SmallVector elements(circuitInfo.getNumDomains()); + SmallVector elements(globals.circuitInfo.getNumDomains()); auto associations = getPortDomainAssociation(domainInfo, i); for (auto domainPortIndexAttr : associations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); @@ -774,7 +750,7 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { // to form the new port domain information for the module by examining the // the associated domains of each port. auto *context = module.getContext(); - auto numDomains = circuitInfo.getNumDomains(); + auto numDomains = globals.circuitInfo.getNumDomains(); auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); auto oldModuleDomainInfo = module.getDomainInfoAttr(); auto numPorts = module.getNumPorts(); @@ -820,10 +796,10 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { if (isa(type)) { auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); auto *row = getDomainAssociationAsRow(port); - for (size_t DomainTypeID = 0; DomainTypeID < numDomains; ++DomainTypeID) { - if (associations[DomainTypeID]) + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) continue; - auto domain = cast(find(row->elements[DomainTypeID]))->value; + auto domain = cast(find(row->elements[domainTypeID]))->value; size_t domainPortIndex = 0; if (auto arg = dyn_cast(domain)) { if (arg.getOwner()->getParentOp() == module) { @@ -851,7 +827,7 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { return failure(); } } - associations[DomainTypeID] = IntegerAttr::get( + associations[domainTypeID] = IntegerAttr::get( IntegerType::get(context, 32, IntegerType::Unsigned), domainPortIndex); } @@ -866,18 +842,23 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { auto newModuleDomainInfoAttr = ArrayAttr::get(module.getContext(), newModuleDomainInfo); module.setDomainInfoAttr(newModuleDomainInfoAttr); + + // record the domain info for replaying on instances. + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portDomainInfo = newModuleDomainInfoAttr; + return success(); } SmallVector InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, size_t portIndex) { - SmallVector result(circuitInfo.getNumDomains()); + SmallVector result(globals.circuitInfo.getNumDomains()); auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); for (auto domainPortIndexAttr : oldAssociations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); auto domainTypeID = - circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); + globals.circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); result[domainTypeID] = domainPortIndexAttr; }; return result; @@ -951,7 +932,7 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { continue; // We must insert a new output domain port. - auto domainDecl = circuitInfo.getDomain(typeID); + auto domainDecl = globals.circuitInfo.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); auto portInsertionPoint = i; @@ -977,7 +958,7 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { continue; // insert a new input domain port for the variable. - auto domainDecl = circuitInfo.getDomain(typeID); + auto domainDecl = globals.circuitInfo.getDomain(typeID); auto domainName = domainDecl.getNameAttr(); auto portInsertionPoint = i; @@ -1019,8 +1000,8 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { } // Record the insertions, so we can replay them on instances later. - auto &info = moduleUpdateInfoTable[module.getNameAttr()]; - info.portInsertions = std::move(insertions); + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portInsertions = std::move(insertions); } LogicalResult @@ -1079,16 +1060,6 @@ LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { return success(); } -const ModuleUpdateInfo & -InferModuleDomains::getModuleUpdateInfo(StringAttr name) const { - return moduleUpdateInfoTable.at(name); -} - -const ModuleUpdateInfo & -InferModuleDomains::getModuleUpdateInfo(FlatSymbolRefAttr ref) const { - return getModuleUpdateInfo(ref.getAttr()); -} - LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, Value rhs) { LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; @@ -1219,7 +1190,7 @@ void InferModuleDomains::setDomainAssociation(Value value, Term *term) { RowTerm *InferModuleDomains::allocateRow() { SmallVector elements; - elements.resize(circuitInfo.getNumDomains()); + elements.resize(globals.circuitInfo.getNumDomains()); return allocateRow(elements); } @@ -1270,8 +1241,8 @@ void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, if (auto *row = dyn_cast(term)) { bool first = true; out << "["; - for (size_t i = 0, e = circuitInfo.getNumDomains(); i < e; ++i) { - auto domainOp = circuitInfo.getDomain(i); + for (size_t i = 0, e = globals.circuitInfo.getNumDomains(); i < e; ++i) { + auto domainOp = globals.circuitInfo.getDomain(i); if (!first) { out << ", "; first = false; @@ -1291,7 +1262,7 @@ void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, VariableIDTable idTable; auto portIndex = port.getArgNumber(); - auto domainOp = circuitInfo.getDomain(domainTypeID); + auto domainOp = globals.circuitInfo.getDomain(domainTypeID); auto domainName = domainOp.getName(); auto diag = emitError(port.getLoc()); @@ -1321,12 +1292,16 @@ void InferDomainsPass::runOnOperation() { LLVM_DEBUG(debugPassHeader(this) << "\n"); auto circuit = getOperation(); auto &instanceGraph = getAnalysis(); - auto circuitInfo = CircuitDomainInfo::get(circuit); + + GlobalState globals(circuit); DenseSet visited; for (auto *root : instanceGraph) { for (auto *node : llvm::post_order_ext(root, visited)) { - auto *op = node->getModule(); - if (failed(InferModuleDomains::run(circuitInfo, node))) { + auto module = dyn_cast(node->getOperation()); + if (!module) + continue; + + if (failed(InferModuleDomains::run(globals, module))) { signalPassFailure(); return; } diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index 5c7d20eef05b..e9e04ed3cd39 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --split-input-file | FileCheck %s +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s | FileCheck %s // Legal domain usage - no crossing. firrtl.circuit "LegalDomains" { @@ -14,8 +14,6 @@ firrtl.circuit "LegalDomains" { } // CHECK-LABEL: firrtl.circuit "LegalDomains" -// ----- - // Domain inference through connections. firrtl.circuit "DomainInference" { firrtl.domain @ClockDomain @@ -36,8 +34,6 @@ firrtl.circuit "DomainInference" { // CHECK-LABEL: firrtl.circuit "DomainInference" // CHECK: out %c: !firrtl.uint<1> domains [%A] -// ----- - // Unsafe domain cast firrtl.circuit "UnsafeDomainCast" { firrtl.domain @ClockDomain @@ -56,8 +52,6 @@ firrtl.circuit "UnsafeDomainCast" { } // CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" -// ----- - // Domain sequence matching. firrtl.circuit "LegalSequences" { firrtl.domain @ClockDomain @@ -72,8 +66,6 @@ firrtl.circuit "LegalSequences" { } } -// ----- - // Domain sequence order equivalence - should be legal firrtl.circuit "SequenceOrderEquivalence" { firrtl.domain @ClockDomain @@ -90,8 +82,6 @@ firrtl.circuit "SequenceOrderEquivalence" { } // CHECK-LABEL: firrtl.circuit "SequenceOrderEquivalence" -// ----- - // Domain sequence inference firrtl.circuit "SequenceInference" { firrtl.domain @ClockDomain @@ -112,8 +102,6 @@ firrtl.circuit "SequenceInference" { } } -// ----- - // Domain duplicate equivalence - should be legal. firrtl.circuit "DuplicateDomainEquivalence" { firrtl.domain @ClockDomain @@ -127,8 +115,6 @@ firrtl.circuit "DuplicateDomainEquivalence" { } } -// ----- - // Unsafe domain cast with sequences firrtl.circuit "UnsafeSequenceCast" { firrtl.domain @ClockDomain @@ -146,9 +132,9 @@ firrtl.circuit "UnsafeSequenceCast" { } } -// ----- - // Different port types domain inference. + +// CHECK-LABEL: DifferentPortTypes firrtl.circuit "DifferentPortTypes" { firrtl.domain @ClockDomain firrtl.module @DifferentPortTypes( @@ -163,8 +149,6 @@ firrtl.circuit "DifferentPortTypes" { } } -// ----- - // Domain inference through wires. // CHECK-LABEL: DomainInferenceThroughWires @@ -185,9 +169,9 @@ firrtl.circuit "DomainInferenceThroughWires" { } } -// ----- +// Register inference. -// Register inference/ +// CHECK-LABEL: RegisterInference firrtl.circuit "RegisterInference" { firrtl.domain @ClockDomain firrtl.module @RegisterInference( @@ -203,3 +187,22 @@ firrtl.circuit "RegisterInference" { firrtl.matchingconnect %q, %r : !firrtl.uint<1> } } + +// Update domain on instance. + +// CHECK-LABEL: InstanceUpdate +firrtl.circuit "InstanceUpdate" { + firrtl.domain @ClockDomain + // CHECK: firrtl.module @Foo(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} From 8ae2ff3540d12ed7e47f76cadd42a0f930172e1f Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 16 Oct 2025 09:39:20 -0400 Subject: [PATCH 15/20] Infer-domains: support instance-choice --- .../FIRRTL/Transforms/InferDomains.cpp | 15 ++++++++--- test/Dialect/FIRRTL/infer-domains.mlir | 26 ++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index f5f89fc2c9df..5b7e4e071e38 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -408,7 +408,9 @@ class InferModuleDomains { /// to fix up any child instance modules. LogicalResult updateDomainAssociationsInBody(FModuleOp); LogicalResult updateOpDomainAssociations(Operation *); - LogicalResult updateOpDomainAssociations(InstanceOp); + + template + LogicalResult updateInstanceDomainAssociations(T op); /// Copy the domain associations from the module domain info attribute into a /// small vector. @@ -592,6 +594,8 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { if (auto instance = dyn_cast(op)) return processOp(instance); + if (auto instance = dyn_cast(op)) + return processOp(instance); if (auto cast = dyn_cast(op)) return processOp(cast); if (auto def = dyn_cast(op)) @@ -1018,12 +1022,15 @@ InferModuleDomains::updateDomainAssociationsInBody(FModuleOp module) { } LogicalResult InferModuleDomains::updateOpDomainAssociations(Operation *op) { - if (auto inst = dyn_cast(op)) - return updateOpDomainAssociations(inst); + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); return success(); } -LogicalResult InferModuleDomains::updateOpDomainAssociations(InstanceOp op) { +template +LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { auto *context = op.getContext(); OpBuilder builder(context); builder.setInsertionPointAfter(op); diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index e9e04ed3cd39..1df48356bc4e 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -193,7 +193,7 @@ firrtl.circuit "RegisterInference" { // CHECK-LABEL: InstanceUpdate firrtl.circuit "InstanceUpdate" { firrtl.domain @ClockDomain - // CHECK: firrtl.module @Foo(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { @@ -206,3 +206,27 @@ firrtl.circuit "InstanceUpdate" { firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> } } + +// CHECK-LABEL: InstanceChoiceUpdate +firrtl.circuit "InstanceChoiceUpdate" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + firrtl.module @Bar(in %i : !firrtl.uint<1>) {} + firrtl.module @Baz(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { + %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) + firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} From 4683a28f93b13ab1d1cf391e89350302cd0f5bff Mon Sep 17 00:00:00 2001 From: Robert Young Date: Thu, 16 Oct 2025 11:28:41 -0400 Subject: [PATCH 16/20] Domain inference: support constants, extmodules --- .../FIRRTL/Transforms/InferDomains.cpp | 56 ++++++++++++++++--- test/Dialect/FIRRTL/infer-domains-errors.mlir | 37 ++++++++++++ 2 files changed, 84 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 5b7e4e071e38..edc08880f9e9 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -481,6 +481,11 @@ class InferModuleDomains { void emitPortDomainCrossingError(BlockArgument, size_t, Term *, Term *) const; + /// Emit an error when we fail to infer the concrete domain to drive to a + /// domain port. + template + void emitDomainPortInferenceError(T op, size_t i) const; + /// Information about the domains in a circuit. GlobalState &globals; @@ -608,6 +613,9 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { for (auto rhs : op->getOperands()) { if (!isa(rhs.getType())) continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; @@ -615,6 +623,9 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { for (auto rhs : op->getResults()) { if (!isa(rhs.getType())) continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; if (failed(unifyAssociations(op, lhs, rhs))) return failure(); lhs = rhs; @@ -623,17 +634,21 @@ LogicalResult InferModuleDomains::processOp(Operation *op) { } LogicalResult InferModuleDomains::processOp(InstanceOp op) { - const auto &update = - globals.moduleUpdateTable.at(op.getReferencedModuleNameAttr()); - op = updateInstancePorts(op, update); + auto module = op.getReferencedModuleNameAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + processInstancePorts(op); return success(); } LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { - const auto &update = - globals.moduleUpdateTable.at(op.getDefaultTargetAttr().getAttr()); - op = updateInstancePorts(op, update); + auto module = op.getDefaultTargetAttr().getAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + processInstancePorts(op); return success(); } @@ -782,8 +797,10 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { auto *term = getTermForDomain(port); term = find(term); auto *val = dyn_cast(term); - if (!val) - return module.emitError() << "unable to infer output domain value"; + if (!val) { + emitDomainPortInferenceError(module, i); + return failure(); + } // If the output port is not driven, drive it. if (!driven) { @@ -1058,7 +1075,8 @@ LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { auto value = val->value; DomainDefineOp::create(builder, loc, port, value); } else { - return op.emitError() << "unable to infer input domain value"; + emitDomainPortInferenceError(op, i); + return failure(); } } } @@ -1284,6 +1302,26 @@ void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, render(note2, term2); } +template +void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { + auto name = op.getPortName(i); + auto diag = emitError(op->getLoc()); + auto info = op.getDomainInfo(); + diag << "unable to infer value for domain port " << name; + for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { + if (auto assocs = dyn_cast(info[j])) { + for (auto assoc : assocs) { + if (i == cast(assoc).getValue()) { + auto name = op.getPortName(j); + auto loc = op.getPortLocation(j); + diag.attachNote(loc) << "associated with hardware port " << name; + break; + } + } + } + } +} + //===--------------------------------------------------------------------------- // InferDomainsPass: Top-level pass implementation. //===--------------------------------------------------------------------------- diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index c396abf527d1..2777f894ba37 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -48,3 +48,40 @@ firrtl.circuit "IllegalDomainCrossing" { firrtl.matchingconnect %b, %a : !firrtl.uint<1> } } + +// ----- + +// Unable to infer domain of port, when port is driven by constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<1>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> + } +} + +// ----- + +// Unable to infer domain of port, when port is driven by arithmetic on constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<2>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) + firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> + } +} + +// ----- + +// Incomplete extmodule domain information. \ No newline at end of file From 16636d3865603545a0510eb16aac41087eb17690 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 17 Oct 2025 09:18:23 -0400 Subject: [PATCH 17/20] More fixes --- include/circt/Firtool/Firtool.h | 8 +++ .../FIRRTL/Transforms/InferDomains.cpp | 60 ++++++++++++------- lib/Firtool/Firtool.cpp | 11 +++- test/Dialect/FIRRTL/infer-domains-errors.mlir | 14 ++++- test/Dialect/FIRRTL/infer-domains.mlir | 18 ++++++ 5 files changed, 88 insertions(+), 23 deletions(-) diff --git a/include/circt/Firtool/Firtool.h b/include/circt/Firtool/Firtool.h index 28e5f6d5d897..5069dca326d2 100644 --- a/include/circt/Firtool/Firtool.h +++ b/include/circt/Firtool/Firtool.h @@ -144,6 +144,8 @@ class FirtoolOptions { bool getEmitAllBindFiles() const { return emitAllBindFiles; } + bool shouldInferDomains() const { return inferDomains; } + // Setters, used by the CAPI FirtoolOptions &setOutputFilename(StringRef name) { outputFilename = name; @@ -393,6 +395,11 @@ class FirtoolOptions { return *this; } + FirtoolOptions &setInferDomains(bool value) { + inferDomains = value; + return *this; + } + private: std::string outputFilename; @@ -447,6 +454,7 @@ class FirtoolOptions { bool lintStaticAsserts; bool lintXmrsInDesign; bool emitAllBindFiles; + bool inferDomains; }; void registerFirtoolCLOptions(); diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index edc08880f9e9..fa1054ea74a3 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -384,7 +384,6 @@ class InferModuleDomains { LogicalResult processOp(InstanceChoiceOp); LogicalResult processOp(UnsafeDomainCastOp); LogicalResult processOp(DomainDefineOp); - LogicalResult processOp(WhenOp); /// Apply the port changes of a module onto an instance-like op. template @@ -392,7 +391,7 @@ class InferModuleDomains { /// Record the domain associations of the ports of an instance-like op. template - void processInstancePorts(T op); + LogicalResult processInstancePorts(T op); LogicalResult updateModule(FModuleOp); @@ -408,7 +407,7 @@ class InferModuleDomains { /// to fix up any child instance modules. LogicalResult updateDomainAssociationsInBody(FModuleOp); LogicalResult updateOpDomainAssociations(Operation *); - + template LogicalResult updateInstanceDomainAssociations(T op); @@ -638,9 +637,7 @@ LogicalResult InferModuleDomains::processOp(InstanceOp op) { auto lookup = globals.moduleUpdateTable.find(module); if (lookup != globals.moduleUpdateTable.end()) op = updateInstancePorts(op, lookup->second); - - processInstancePorts(op); - return success(); + return processInstancePorts(op); } LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { @@ -648,9 +645,7 @@ LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { auto lookup = globals.moduleUpdateTable.find(module); if (lookup != globals.moduleUpdateTable.end()) op = updateInstancePorts(op, lookup->second); - - processInstancePorts(op); - return success(); + return processInstancePorts(op); } LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { @@ -674,13 +669,22 @@ LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { auto src = op.getSrc(); auto dst = op.getDest(); - auto *term = getTermForDomain(src); - setTermForDomain(dst, term); - return success(); + auto *srcTerm = getTermForDomain(src); + auto *dstTerm = getTermForDomain(dst); + if (failed(unify(dstTerm, srcTerm))) { + VariableIDTable idTable; + auto diag = op->emitOpError("failed to propagate source to destination"); + auto ¬e1 = diag.attachNote(); + note1 << "destination has underlying value: "; + render(note1, idTable, dstTerm); + + auto ¬e2 = diag.attachNote(src.getLoc()); + note2 << "source has underlying value: "; + render(note2, idTable, srcTerm); + } + return unify(dstTerm, srcTerm); } -LogicalResult InferModuleDomains::processOp(WhenOp op) { return failure(); } - template T InferModuleDomains::updateInstancePorts(T op, const ModuleUpdateInfo &update) { @@ -691,7 +695,9 @@ T InferModuleDomains::updateInstancePorts(T op, } template -void InferModuleDomains::processInstancePorts(T op) { +LogicalResult InferModuleDomains::processInstancePorts(T op) { + auto circuitInfo = globals.circuitInfo; + auto numDomainTypes = circuitInfo.getNumDomains(); DenseMap domainPortTypeIDTable; auto domainInfo = op.getDomainInfoAttr(); for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { @@ -700,7 +706,7 @@ void InferModuleDomains::processInstancePorts(T op) { LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); if (isa(port.getType())) { - auto typeID = globals.circuitInfo.getDomainTypeID(domainInfo, i); + auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); domainPortTypeIDTable[i] = typeID; if (op.getPortDirection(i) == Direction::Out) { setTermForDomain(port, allocate(port)); @@ -714,7 +720,7 @@ void InferModuleDomains::processInstancePorts(T op) { // This is a port, which may have explicit domain information. Associate the // port with a row of domains, where each element is derived from the domain // associations recorded in the domain info attribute of the instance. - SmallVector elements(globals.circuitInfo.getNumDomains()); + SmallVector elements(numDomainTypes); auto associations = getPortDomainAssociation(domainInfo, i); for (auto domainPortIndexAttr : associations) { auto domainPortIndex = domainPortIndexAttr.getUInt(); @@ -723,13 +729,25 @@ void InferModuleDomains::processInstancePorts(T op) { elements[typeID] = term; } - // Since we are processing bottom-up, we must have complete domain info - // for each port on the instance. - for (auto *element : elements) - assert(element && "must have complete domain information."); + // Confirm that we have complete domain information for the port. We can be + // missing information if, for example, this was an instance of an + // extmodule. + for (size_t domainTypeID = 0; domainTypeID < numDomainTypes; + ++domainTypeID) { + if (elements[domainTypeID]) + continue; + auto domainDecl = circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto portName = op.getPortName(i); + op->emitOpError() << "missing " << domainName << " association for port " + << portName; + return failure(); + } setDomainAssociation(port, allocateRow(elements)); } + + return success(); } LogicalResult InferModuleDomains::updateModule(FModuleOp op) { diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index 6d846422507b..602053ed22fe 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,6 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); + if (opt.shouldInferDomains()) + pm.nest().addPass(firrtl::createInferDomains()); + return success(); } @@ -757,6 +760,11 @@ struct FirtoolCmdOptions { llvm::cl::desc("Emit bindfiles for private modules"), llvm::cl::init(false)}; + llvm::cl::opt inferDomains{ + "infer-domains", + llvm::cl::desc("Enable domain inference and checking"), + llvm::cl::init(false)}; + //===----------------------------------------------------------------------=== // Lint options //===----------------------------------------------------------------------=== @@ -808,7 +816,7 @@ circt::firtool::FirtoolOptions::FirtoolOptions() disableCSEinClasses(false), selectDefaultInstanceChoice(false), symbolicValueLowering(verif::SymbolicValueLowering::ExtModule), disableWireElimination(false), lintStaticAsserts(true), - lintXmrsInDesign(true), emitAllBindFiles(false) { + lintXmrsInDesign(true), emitAllBindFiles(false), inferDomains(false) { if (!clOptions.isConstructed()) return; outputFilename = clOptions->outputFilename; @@ -861,4 +869,5 @@ circt::firtool::FirtoolOptions::FirtoolOptions() lintStaticAsserts = clOptions->lintStaticAsserts; lintXmrsInDesign = clOptions->lintXmrsInDesign; emitAllBindFiles = clOptions->emitAllBindFiles; + inferDomains = clOptions->inferDomains; } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index 2777f894ba37..5eb07891753f 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -84,4 +84,16 @@ firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { // ----- -// Incomplete extmodule domain information. \ No newline at end of file +// Incomplete extmodule domain information. + +firrtl.circuit "IncompleteDomainInfoForExtModule" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in i: !firrtl.uint<1>) + + firrtl.module @IncompleteDomainInfoForExtModule(in %i: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance' op missing "ClockDomain" association for port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %i : !firrtl.uint<1> + } +} diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir index 1df48356bc4e..c2cc0d96a83e 100644 --- a/test/Dialect/FIRRTL/infer-domains.mlir +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -230,3 +230,21 @@ firrtl.circuit "InstanceChoiceUpdate" { firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> } } + +// CHECK-LABEL: ConstantInMultipleDomains +firrtl.circuit "ConstantInMultipleDomains" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + + firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %x_A, %A + firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> + + %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %y_A, %B + firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> + } +} From e03cb73a1fe6eb8ede20092e67de598a44c81984 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 17 Oct 2025 12:42:35 -0400 Subject: [PATCH 18/20] More domains --- .../FIRRTL/Transforms/InferDomains.cpp | 135 +++++++----------- test/Dialect/FIRRTL/infer-domains-errors.mlir | 31 +++- 2 files changed, 82 insertions(+), 84 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index fa1054ea74a3..a81d9e9bf79d 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -415,13 +415,6 @@ class InferModuleDomains { /// small vector. SmallVector copyPortDomainAssociations(ArrayAttr, size_t); - /// For a concrete domain value, get the unique aliasing port. When a hardware - /// port is associated to a domain, we must ensure that the domain is - /// available as a port of the module. Fails if the domain is not - /// exported by a port, or if the domain is exported by multiple ports. - LogicalResult getExportingPortIndex(FModuleOp, Value, size_t &); - LogicalResult getExportingPort(FModuleOp, Value, BlockArgument &); - /// Add domain ports for any uninferred domains associated to hardware. /// Returns the inserted ports, which will be used later to generalize the /// instances of this module. @@ -478,7 +471,8 @@ class InferModuleDomains { void render(Diagnostic &, Term *) const; void render(Diagnostic &, VariableIDTable &, Term *) const; - void emitPortDomainCrossingError(BlockArgument, size_t, Term *, Term *) const; + void emitPortDomainCrossingError(StringAttr, Location, DomainTypeID, Term *, + Term *) const; /// Emit an error when we fail to infer the concrete domain to drive to a /// domain port. @@ -569,7 +563,10 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { auto *term = getTermForDomain(domainValue); auto &slot = elements[domainTypeID]; if (failed(unify(slot, term))) { - emitPortDomainCrossingError(port, domainTypeID, slot, term); + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + emitPortDomainCrossingError(portName, portLoc, domainTypeID, slot, + term); return failure(); } elements[domainTypeID] = term; @@ -738,7 +735,7 @@ LogicalResult InferModuleDomains::processInstancePorts(T op) { continue; auto domainDecl = circuitInfo.getDomain(domainTypeID); auto domainName = domainDecl.getNameAttr(); - auto portName = op.getPortName(i); + auto portName = op.getPortNameAttr(i); op->emitOpError() << "missing " << domainName << " association for port " << portName; return failure(); @@ -765,18 +762,14 @@ LogicalResult InferModuleDomains::updateModule(FModuleOp op) { void InferModuleDomains::initializeExportTable(FModuleOp module) { size_t numPorts = module.getNumPorts(); - auto directions = module.getPortDirections(); for (size_t i = 0; i < numPorts; ++i) { auto port = module.getArgument(i); auto type = port.getType(); if (!isa(type)) continue; - auto direction = direction::get(directions[i]); - if (direction == Direction::Out) { - auto value = getUnderlyingDomain(port); - if (value) - exportTable[value].push_back(port); - } + auto value = getUnderlyingDomain(port); + if (value) + exportTable[value].push_back(port); } } @@ -838,34 +831,39 @@ InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { if (associations[domainTypeID]) continue; + auto domain = cast(find(row->elements[domainTypeID]))->value; - size_t domainPortIndex = 0; - if (auto arg = dyn_cast(domain)) { - if (arg.getOwner()->getParentOp() == module) { - domainPortIndex = arg.getArgNumber(); - } - } else { - auto exports = exportTable.lookup(domain); - if (exports.empty()) { - auto diag = - module.emitOpError("Failed to infer domain information"); - diag.attachNote(module.getPortLocation(i)) << "for port # " << i; - diag.attachNote() << "the domain is not exported"; - return failure(); - } + auto &exports = exportTable[domain]; + if (exports.empty()) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "private " << domainName << " association for port " + << portName; + diag.attachNote(domain.getLoc()) << "associated domain: " << domain; + return failure(); + } - if (exports.size() > 1) { - auto diag = - module.emitOpError("Failed to infer domain information"); - diag.attachNote(module.getPortLocation(i)) << "for port # " << i; - diag.attachNote() << "cannot choose between aliasing ports"; - for (auto arg : exports) { - diag.attachNote(module.getPortLocation(arg.getArgNumber())) - << "aliased here"; - } - return failure(); + if (exports.size() > 1) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "ambiguous " << domainName << " association for port " + << portName; + for (auto arg : exports) { + auto name = module.getPortNameAttr(arg.getArgNumber()); + auto loc = module.getPortLocation(arg.getArgNumber()); + diag.attachNote(loc) << "candidate association " << name; } + return failure(); } + + auto argument = cast(exports[0]); + auto domainPortIndex = argument.getArgNumber(); associations[domainTypeID] = IntegerAttr::get( IntegerType::get(context, 32, IntegerType::Unsigned), domainPortIndex); @@ -903,37 +901,6 @@ InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, return result; } -LogicalResult InferModuleDomains::getExportingPortIndex(FModuleOp module, - Value value, - size_t &result) { - BlockArgument arg; - if (failed(getExportingPort(module, value, arg))) - return failure(); - - result = arg.getArgNumber(); - return success(); -} - -LogicalResult InferModuleDomains::getExportingPort(FModuleOp module, - Value value, - BlockArgument &result) { - if (auto arg = dyn_cast(value)) { - if (arg.getOwner()->getParentOp() == module) { - result = arg; - return success(); - } - } - - auto exports = exportTable.lookup(value); - assert(!exports.empty()); - - if (exports.size() > 1) - return failure(); - - result = exports[0]; - return success(); -} - void InferModuleDomains::generalizeModule(FModuleOp module) { PortInsertions insertions; // If the port is hardware, we have to check the associated row of @@ -1023,13 +990,15 @@ void InferModuleDomains::generalizeModule(FModuleOp module) { // Put the domain ports in place. module.insertPorts(insertions); - // Solve the variables. + // Solve the variables and record them as "self-exporting". for (auto [var, portIndex] : pendingSolutions) { - auto *solution = allocate(module.getArgument(portIndex)); + auto port = module.getArgument(portIndex); + auto *solution = allocate(port); solve(var, solution); + exportTable[port].push_back(port); } - // Drive the exports. + // Drive the pending exports. auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); for (auto [value, portIndex] : pendingExports) { auto port = module.getArgument(portIndex); @@ -1298,18 +1267,18 @@ void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, } } -void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, +void InferModuleDomains::emitPortDomainCrossingError(StringAttr portName, + Location portLoc, size_t domainTypeID, Term *term1, Term *term2) const { VariableIDTable idTable; - auto portIndex = port.getArgNumber(); - auto domainOp = globals.circuitInfo.getDomain(domainTypeID); - auto domainName = domainOp.getName(); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); - auto diag = emitError(port.getLoc()); - diag << "illegal " << domainName << " crossing in port #" << portIndex; + auto diag = emitError(portLoc); + diag << "illegal " << domainName << " crossing in port " << portName; auto ¬e1 = diag.attachNote(); note1 << "1st instance: "; @@ -1322,7 +1291,7 @@ void InferModuleDomains::emitPortDomainCrossingError(BlockArgument port, template void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { - auto name = op.getPortName(i); + auto name = op.getPortNameAttr(i); auto diag = emitError(op->getLoc()); auto info = op.getDomainInfo(); diag << "unable to infer value for domain port " << name; @@ -1330,7 +1299,7 @@ void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { if (auto assocs = dyn_cast(info[j])) { for (auto assoc : assocs) { if (i == cast(assoc).getValue()) { - auto name = op.getPortName(j); + auto name = op.getPortNameAttr(j); auto loc = op.getPortLocation(j); diag.attachNote(loc) << "associated with hardware port " << name; break; diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir index 5eb07891753f..da419dcac465 100644 --- a/test/Dialect/FIRRTL/infer-domains-errors.mlir +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -6,7 +6,7 @@ firrtl.circuit "DomainCrossOnPort" { firrtl.module @DomainCrossOnPort( in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain, - // expected-error @below {{illegal ClockDomain crossing in port #2}} + // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} // expected-note @below {{1st instance: A}} // expected-note @below {{2nd instance: B}} in %p: !firrtl.uint<1> domains [%A, %B] @@ -97,3 +97,32 @@ firrtl.circuit "IncompleteDomainInfoForExtModule" { firrtl.matchingconnect %foo_i, %i : !firrtl.uint<1> } } + +// ----- + +// Domain not exported like it should be. + +// ----- + +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + From c920473c565ec18d6ef73592d215c7655372891c Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 17 Oct 2025 13:05:38 -0400 Subject: [PATCH 19/20] More cleanup --- .../FIRRTL/Transforms/InferDomains.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index a81d9e9bf79d..2887b0997ba4 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -471,13 +471,14 @@ class InferModuleDomains { void render(Diagnostic &, Term *) const; void render(Diagnostic &, VariableIDTable &, Term *) const; - void emitPortDomainCrossingError(StringAttr, Location, DomainTypeID, Term *, + template + void emitPortDomainCrossingError(T, size_t, DomainTypeID, Term *, Term *) const; /// Emit an error when we fail to infer the concrete domain to drive to a /// domain port. template - void emitDomainPortInferenceError(T op, size_t i) const; + void emitDomainPortInferenceError(T, size_t) const; /// Information about the domains in a circuit. GlobalState &globals; @@ -563,10 +564,7 @@ LogicalResult InferModuleDomains::processPorts(FModuleOp module) { auto *term = getTermForDomain(domainValue); auto &slot = elements[domainTypeID]; if (failed(unify(slot, term))) { - auto portName = module.getPortNameAttr(i); - auto portLoc = module.getPortLocation(i); - emitPortDomainCrossingError(portName, portLoc, domainTypeID, slot, - term); + emitPortDomainCrossingError(module, i, domainTypeID, slot, term); return failure(); } elements[domainTypeID] = term; @@ -1267,13 +1265,15 @@ void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, } } -void InferModuleDomains::emitPortDomainCrossingError(StringAttr portName, - Location portLoc, +template +void InferModuleDomains::emitPortDomainCrossingError(T op, size_t i, size_t domainTypeID, Term *term1, Term *term2) const { VariableIDTable idTable; + auto portName = op.getPortNameAttr(i); + auto portLoc = op.getPortLocation(i); auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); auto domainName = domainDecl.getNameAttr(); @@ -1282,11 +1282,11 @@ void InferModuleDomains::emitPortDomainCrossingError(StringAttr portName, auto ¬e1 = diag.attachNote(); note1 << "1st instance: "; - render(note1, term1); + render(note1, idTable, term1); auto ¬e2 = diag.attachNote(); note2 << "2nd instance: "; - render(note2, term2); + render(note2, idTable, term2); } template From d5a833c4495a01ec4fb2dbf19c3d94b8d9a53d45 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Fri, 17 Oct 2025 16:31:36 -0400 Subject: [PATCH 20/20] Clean up --- lib/Dialect/FIRRTL/Transforms/InferDomains.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp index 2887b0997ba4..c989fef5b224 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -1161,15 +1161,18 @@ RowTerm *InferModuleDomains::getDomainAssociationAsRow(Value value) { } // If the term is already a row, return it. - auto *row = dyn_cast(term); - if (row) + if (auto *row = dyn_cast(term)) return row; // Otherwise, unify the term with a fresh row of domains. - row = allocateRow(); - auto result = unify(row, term); - assert(result.succeeded()); - return row; + if (auto *var = dyn_cast(term)) { + auto *row = allocateRow(); + solve(var, row); + return row; + } + + assert(false && "unhandled term type"); + return nullptr; } Term *InferModuleDomains::getDomainAssociation(Value value) {