diff --git a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td index b1f0ece66118..42705051d7bf 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td @@ -119,17 +119,27 @@ def InstanceOp : HardwareDeclOp<"instance", [ CArg<"hw::InnerSymAttr", "hw::InnerSymAttr()">:$innerSym)>]; let extraClassDeclaration = [{ + /// Return the number of ports for this instance. + size_t getNumPorts() { + return getNumResults(); + } + /// Return the port direction for the specified result number. Direction getPortDirection(size_t resultNo) { return direction::get(getPortDirections()[resultNo]); } /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo) { + StringAttr getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } - StringRef getPortNameStr(size_t resultNo) { - return getPortName(resultNo).getValue(); + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); + } + + Location getPortLocation(size_t) { + return getLoc(); } /// Hooks for port annotations. @@ -208,15 +218,29 @@ def InstanceChoiceOp : HardwareDeclOp<"instance_choice", [ let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ + /// Return the number of ports for this instance. + size_t getNumPorts() { + return getNumResults(); + } + /// Return the port direction for the specified result number. Direction getPortDirection(size_t resultNo) { return direction::get(getPortDirections()[resultNo]); } /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo) { + StringAttr getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); + } + + Location getPortLocation(size_t resultNo) { + return getLoc(); + } + /// Return the default target attribute. FlatSymbolRefAttr getDefaultTargetAttr() { return llvm::cast(getModuleNamesAttr()[0]); @@ -313,9 +337,10 @@ def MemOp : HardwareDeclOp<"mem", [DeclareOpInterfaceMethods]> { size_t getMaskBits(); /// Return the port name for the specified result number. - StringAttr getPortName(size_t resultNo); - StringRef getPortNameStr(size_t resultNo) { - return getPortName(resultNo).getValue(); + StringAttr getPortNameAttr(size_t resultNo); + + StringRef getPortName(size_t resultNo) { + return getPortNameAttr(resultNo).getValue(); } /// Return the port type for the specified result number. diff --git a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h index e8339dd7320b..44a9903b5520 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.h @@ -85,6 +85,28 @@ struct PortInfo { annotations(annos), domains(domains) {} }; +inline bool operator==(const PortInfo &lhs, const PortInfo &rhs) { + if (lhs.name != rhs.name) + return false; + if (lhs.type != rhs.type) + return false; + if (lhs.direction != rhs.direction) + return false; + if (lhs.sym != rhs.sym) + return false; + if (lhs.loc != rhs.loc) + return false; + if (lhs.annotations != rhs.annotations) + return false; + if (lhs.domains != rhs.domains) + return false; + return true; +} + +inline bool operator!=(const PortInfo &lhs, const PortInfo &rhs) { + return !(lhs == rhs); +} + enum class ConnectBehaviorKind { /// Classic FIRRTL connections: last connect 'wins' across paths; /// conditionally applied under 'when'. diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index 15cf686d618c..caf8eeb7dffb 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -909,6 +909,19 @@ def CheckLayers : Pass<"firrtl-check-layers", "firrtl::CircuitOp"> { }]; } +def InferDomains : Pass<"firrtl-infer-domains", "firrtl::CircuitOp"> { + let summary = "Infer and type check all firrtl domains"; + let description = [{ + This pass does domain inference on a FIRRTL circuit. The end result of this + is either a corrrctly domain-checked FIRRTL circuit or failure with verbose + error messages indicating why the FIRRTL circuit has illegal domain + constructs. + + E.g., this pass can be used to check for illegal clock-domain-crossings if + clock domains are specified for signals in the design. + }]; +} + def LowerDomains : Pass<"firrtl-lower-domains", "firrtl::CircuitOp"> { let summary = "lower domain information to properties"; let description = [{ diff --git a/include/circt/Firtool/Firtool.h b/include/circt/Firtool/Firtool.h index 28e5f6d5d897..5069dca326d2 100644 --- a/include/circt/Firtool/Firtool.h +++ b/include/circt/Firtool/Firtool.h @@ -144,6 +144,8 @@ class FirtoolOptions { bool getEmitAllBindFiles() const { return emitAllBindFiles; } + bool shouldInferDomains() const { return inferDomains; } + // Setters, used by the CAPI FirtoolOptions &setOutputFilename(StringRef name) { outputFilename = name; @@ -393,6 +395,11 @@ class FirtoolOptions { return *this; } + FirtoolOptions &setInferDomains(bool value) { + inferDomains = value; + return *this; + } + private: std::string outputFilename; @@ -447,6 +454,7 @@ class FirtoolOptions { bool lintStaticAsserts; bool lintXmrsInDesign; bool emitAllBindFiles; + bool inferDomains; }; void registerFirtoolCLOptions(); diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h index a26cf5ddce77..70fdcf2e67f9 100644 --- a/include/circt/Support/InstanceGraph.h +++ b/include/circt/Support/InstanceGraph.h @@ -59,6 +59,11 @@ class InstanceGraphNode; class InstanceRecord : public llvm::ilist_node_with_parent { public: + /// Get the op that this is tracking. + Operation *getOperation() { + return instance.getOperation(); + } + /// Get the instance-like op that this is tracking. template auto getInstance() { @@ -113,6 +118,8 @@ class InstanceGraphNode : public llvm::ilist_node { public: InstanceGraphNode() : module(nullptr) {} + Operation *getOperation() { return module.getOperation(); } + /// Get the module that this node is tracking. template auto getModule() { diff --git a/lib/Dialect/FIRRTL/Export/FIREmitter.cpp b/lib/Dialect/FIRRTL/Export/FIREmitter.cpp index e9fb68ad408a..324f00e5eb2a 100644 --- a/lib/Dialect/FIRRTL/Export/FIREmitter.cpp +++ b/lib/Dialect/FIRRTL/Export/FIREmitter.cpp @@ -1126,7 +1126,7 @@ void Emitter::emitStatement(InstanceOp op) { portName.push_back('.'); unsigned baseLen = portName.size(); for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - portName.append(legalize(op.getPortName(i))); + portName.append(legalize(op.getPortNameAttr(i))); addValueName(op.getResult(i), portName); portName.resize(baseLen); } @@ -1153,7 +1153,7 @@ void Emitter::emitStatement(InstanceChoiceOp op) { portName.push_back('.'); unsigned baseLen = portName.size(); for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - portName.append(legalize(op.getPortName(i))); + portName.append(legalize(op.getPortNameAttr(i))); addValueName(op.getResult(i), portName); portName.resize(baseLen); } diff --git a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp index d31008b383ca..26eb0506c607 100644 --- a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp @@ -29,7 +29,7 @@ using llvm::StringRef; static LogicalResult updateExpandedPort(StringRef field, AnnoTarget &ref) { if (auto mem = dyn_cast(ref.getOp())) for (size_t p = 0, pe = mem.getPortNames().size(); p < pe; ++p) - if (mem.getPortNameStr(p) == field) { + if (mem.getPortName(p) == field) { ref = PortAnnoTarget(mem, p); return success(); } diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index cca9cace386d..d772e391e876 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -2736,7 +2736,7 @@ InstanceOp InstanceOp::cloneWithInsertedPorts( } newPortDirections.push_back(getPortDirection(i)); - newPortNames.push_back(getPortName(i)); + newPortNames.push_back(getPortNameAttr(i)); newPortTypes.push_back(getType(i)); newPortAnnos.push_back(getPortAnnotation(i)); newDomainInfo.push_back(getDomainInfo()[i]); @@ -2887,7 +2887,7 @@ void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { base = "inst"; for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) { - setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str()); + setNameFn(getResult(i), (base + "_" + getPortName(i)).str()); } } @@ -3219,7 +3219,7 @@ InstanceChoiceOp InstanceChoiceOp::cloneWithInsertedPorts( } newPortDirections.push_back(getPortDirection(i)); - newPortNames.push_back(getPortName(i)); + newPortNames.push_back(getPortNameAttr(i)); newPortTypes.push_back(getType(i)); newPortAnnos.push_back(getPortAnnotations()[i]); newDomainInfo.push_back(getDomainInfo()[i]); @@ -3343,7 +3343,7 @@ LogicalResult MemOp::verify() { FIRRTLType oldDataType; for (size_t i = 0, e = getNumResults(); i != e; ++i) { - auto portName = getPortName(i); + auto portName = getPortNameAttr(i); // Get a bundle type representing this port, stripping an outer // flip if it exists. If this is not a bundle<> or @@ -3453,10 +3453,10 @@ LogicalResult MemOp::verify() { // Error if the type of the current port was not the same as the // last port, but skip checking the first port. if (oldDataType && oldDataType != dataType) { - emitOpError() << "port " << getPortName(i) - << " has a different type than port " << getPortName(i - 1) - << " (expected " << oldDataType << ", but got " << dataType - << ")"; + emitOpError() << "port " << getPortNameAttr(i) + << " has a different type than port " + << getPortNameAttr(i - 1) << " (expected " << oldDataType + << ", but got " << dataType << ")"; return failure(); } @@ -3539,7 +3539,7 @@ SmallVector MemOp::getPorts() { for (size_t i = 0, e = getNumResults(); i != e; ++i) { // Each port is a bundle. auto portType = type_cast(getResult(i).getType()); - result.push_back({getPortName(i), getMemPortKindFromType(portType)}); + result.push_back({getPortNameAttr(i), getMemPortKindFromType(portType)}); } return result; } @@ -3596,7 +3596,7 @@ FIRRTLBaseType MemOp::getDataType() { .getElementType(dataFieldName); } -StringAttr MemOp::getPortName(size_t resultNo) { +StringAttr MemOp::getPortNameAttr(size_t resultNo) { return cast(getPortNames()[resultNo]); } @@ -3717,7 +3717,7 @@ void MemOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { base = "mem"; for (size_t i = 0, e = (*this)->getNumResults(); i != e; ++i) { - setNameFn(getResult(i), (base + "_" + getPortNameStr(i)).str()); + setNameFn(getResult(i), (base + "_" + getPortName(i)).str()); } } @@ -4278,6 +4278,14 @@ LogicalResult PropAssignOp::verify() { return success(); } +template +static FlatSymbolRefAttr getDomainTypeNameOfResult(T op, size_t i) { + auto info = op.getDomainInfo(); + if (info.empty()) + return {}; + return dyn_cast(info[i]); +} + static FlatSymbolRefAttr getDomainTypeName(Value value) { if (!isa(value.getType())) return {}; @@ -4297,13 +4305,10 @@ static FlatSymbolRefAttr getDomainTypeName(Value value) { if (auto result = dyn_cast(value)) { auto *op = result.getDefiningOp(); - if (auto instance = dyn_cast(op)) { - auto info = instance.getDomainInfo(); - if (info.empty()) - return {}; - auto attr = info[result.getResultNumber()]; - return dyn_cast(attr); - } + if (auto instance = dyn_cast(op)) + return getDomainTypeNameOfResult(instance, result.getResultNumber()); + if (auto instance = dyn_cast(op)) + return getDomainTypeNameOfResult(instance, result.getResultNumber()); return {}; } diff --git a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp index 8e9dcae73a3d..9aad07765317 100644 --- a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp @@ -414,7 +414,7 @@ struct InstanceStubber : public OpReduction { for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) { auto result = instOp.getResult(i); auto name = builder.getStringAttr(Twine(instOp.getName()) + "_" + - instOp.getPortNameStr(i)); + instOp.getPortName(i)); auto wire = firrtl::WireOp::create(builder, result.getType(), name, firrtl::NameKindEnum::DroppableName, @@ -463,7 +463,7 @@ struct MemoryStubber : public OpReduction { for (unsigned i = 0, e = memOp.getNumResults(); i != e; ++i) { auto result = memOp.getResult(i); auto name = builder.getStringAttr(Twine(memOp.getName()) + "_" + - memOp.getPortNameStr(i)); + memOp.getPortName(i)); auto wire = firrtl::WireOp::create(builder, result.getType(), name, firrtl::NameKindEnum::DroppableName, @@ -1190,7 +1190,7 @@ struct EagerInliner : public OpReduction { for (unsigned i = 0, e = instOp.getNumResults(); i != e; ++i) { auto result = instOp.getResult(i); auto name = rewriter.getStringAttr(Twine(instOp.getName()) + "_" + - instOp.getPortNameStr(i)); + instOp.getPortName(i)); auto wire = WireOp::create(rewriter, instOp.getLoc(), result.getType(), name, NameKindEnum::DroppableName, instOp.getPortAnnotation(i), StringAttr{}) diff --git a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp index 289e1a526cab..0614e6efb98b 100644 --- a/lib/Dialect/FIRRTL/FIRRTLUtils.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLUtils.cpp @@ -571,8 +571,7 @@ static void getDeclName(Value value, SmallString<64> &string, bool nameSafe) { .Case([&](auto op) { string += op.getName(); string += nameSafe ? "_" : "."; - string += op.getPortName(cast(value).getResultNumber()) - .getValue(); + string += op.getPortName(cast(value).getResultNumber()); value = nullptr; }) .Case([&](auto op) { diff --git a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt index f8a1ec9e4045..d24f11c7563b 100755 --- a/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt +++ b/lib/Dialect/FIRRTL/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms GrandCentral.cpp IMConstProp.cpp IMDeadCodeElim.cpp + InferDomains.cpp InferReadWrite.cpp InferResets.cpp InferWidths.cpp diff --git a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp index 81bfa3c64d4c..d57ffd3a4ef5 100644 --- a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp +++ b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp @@ -1596,7 +1596,7 @@ fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, continue; LLVM_DEBUG(llvm::dbgs() << "- Updating instance port \"" << instOp.getInstanceName() - << "." << instOp.getPortName(index).getValue() << "\" from " + << "." << instOp.getPortName(index) << "\" from " << oldType << " to " << newType << "\n"); // If the type changed we transform it back to the old type with an diff --git a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp index 72b9421a424d..5d2a166649b7 100644 --- a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp +++ b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp @@ -563,7 +563,7 @@ void ExtractInstancesPass::extractInstances() { for (unsigned portIdx = 0; portIdx < numInstPorts; ++portIdx) { // Assemble the new port name as "_", where the prefix is // provided by the extraction annotation. - auto name = inst.getPortNameStr(portIdx); + auto name = inst.getPortName(portIdx); auto nameAttr = StringAttr::get( &getContext(), prefix.empty() ? Twine(name) : Twine(prefix) + "_" + name); @@ -913,7 +913,7 @@ void ExtractInstancesPass::groupInstances() { StringRef prefix(instPrefixNamesPair[inst].first); unsigned portNum = inst.getNumResults(); for (unsigned portIdx = 0; portIdx < portNum; ++portIdx) { - auto name = inst.getPortNameStr(portIdx); + auto name = inst.getPortName(portIdx); auto nameAttr = builder.getStringAttr( prefix.empty() ? Twine(name) : Twine(prefix) + "_" + name); PortInfo port{nameAttr, diff --git a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp index 2da67137e816..933c0c2e39db 100644 --- a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp +++ b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp @@ -138,7 +138,7 @@ struct FlattenMemoryPass auto result = memOp.getResult(index); auto wire = WireOp::create(builder, result.getType(), (memOp.getName() + "_" + - memOp.getPortName(index).getValue()) + memOp.getPortName(index)) .str()) .getResult(); result.replaceAllUsesWith(wire); diff --git a/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp new file mode 100644 index 000000000000..c989fef5b224 --- /dev/null +++ b/lib/Dialect/FIRRTL/Transforms/InferDomains.cpp @@ -0,0 +1,1346 @@ +//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements FIRRTL domain inference and checking with canonical +// domain representation. Domain sequences are canonicalized by sorting and +// removing duplicates, making domain order irrelevant and allowing duplicate +// domains to be treated as equivalent. The result of this pass is either a +// correctly domain-inferred circuit or pass failure if the circuit contains +// illegal domain crossings. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" +#include "circt/Dialect/FIRRTL/FIRRTLOps.h" +#include "circt/Dialect/FIRRTL/FIRRTLUtils.h" +#include "circt/Dialect/FIRRTL/Passes.h" +#include "circt/Support/Debug.h" +#include "circt/Support/Namespace.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "firrtl-infer-domains" +#undef NDEBUG + +namespace circt { +namespace firrtl { +#define GEN_PASS_DEF_INFERDOMAINS +#include "circt/Dialect/FIRRTL/Passes.h.inc" +} // namespace firrtl +} // namespace circt + +using namespace circt; +using namespace firrtl; + +using InstanceIterator = InstanceGraphNode::UseIterator; +using InstanceRange = llvm::iterator_range; +using PortInsertions = SmallVector>; + +//====-------------------------------------------------------------------------- +// Helpers for working with module or instance domain info. +//====-------------------------------------------------------------------------- + +/// From a domain info attribute, get the domain-type of a domain value at +/// index i. +static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) { + if (info.empty()) + return nullptr; + auto ref = cast(info[i]); + return ref.getAttr(); +} + +/// From a domain info attribute, get the row of associated domains for a +/// hardware value at index i. +static auto getPortDomainAssociation(ArrayAttr info, size_t i) { + if (info.empty()) + return info.getAsRange(); + return cast(info[i]).getAsRange(); +} + +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, BlockArgument arg) { + return arg.getOwner()->getParentOp() == module; +} + +/// Return true if the value is a port on the module. +static bool isPort(FModuleOp module, Value value) { + auto arg = dyn_cast(value); + if (!arg) + return false; + return isPort(module, arg); +} + +//====-------------------------------------------------------------------------- +// Circuit-wide state. +//====-------------------------------------------------------------------------- + +/// Each declared domain in the circuit is assigned an index, based on the order +/// in which it appears. Domain associations for hardware values are represented +/// as a list of domains, sorted by the index of the domain type. +using DomainTypeID = size_t; + +/// Information about the domains in the circuit. Able to map domains to their +/// type ID, which in this pass is the canonical way to reference the type +/// of a domain. +namespace { +struct CircuitDomainInfo { + CircuitDomainInfo(CircuitOp circuit) { processCircuit(circuit); } + + ArrayRef getDomains() const { return domainTable; } + size_t getNumDomains() const { return domainTable.size(); } + DomainOp getDomain(DomainTypeID id) const { return domainTable[id]; } + + DomainTypeID getDomainTypeID(DomainOp op) const { + return typeIDTable.at(op.getNameAttr()); + } + + DomainTypeID getDomainTypeID(StringAttr name) const { + return typeIDTable.at(name); + } + + DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const { + return getDomainTypeID(ref.getAttr()); + } + + DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const { + auto name = getDomainPortTypeName(info, i); + return getDomainTypeID(name); + } + + DomainTypeID getDomainTypeID(Value value) const { + assert(isa(value.getType())); + if (auto arg = dyn_cast(value)) { + auto *block = arg.getOwner(); + auto *owner = block->getParentOp(); + auto module = cast(owner); + auto info = module.getDomainInfoAttr(); + auto i = arg.getArgNumber(); + return getDomainTypeID(info, i); + } + + auto result = dyn_cast(value); + auto *owner = result.getOwner(); + auto instance = cast(owner); + auto info = instance.getDomainInfoAttr(); + auto i = result.getResultNumber(); + return getDomainTypeID(info, i); + } + +private: + void processDomain(DomainOp op) { + auto index = domainTable.size(); + auto name = op.getNameAttr(); + domainTable.push_back(op); + typeIDTable.insert({name, index}); + } + + void processCircuit(CircuitOp circuit) { + for (auto decl : circuit.getOps()) + processDomain(decl); + } + + /// A map from domain type ID to op. + SmallVector domainTable; + + /// A map from domain name to type ID. + DenseMap typeIDTable; +}; + +/// Information about the changes made to the interface of a module, which can +/// be replayed onto an instance. +struct ModuleUpdateInfo { + /// The updated domain information for a module. + ArrayAttr portDomainInfo; + /// The domain ports which have been inserted into a module. + PortInsertions portInsertions; +}; + +struct GlobalState { + GlobalState(CircuitOp circuit) : circuitInfo(circuit) {} + + CircuitDomainInfo circuitInfo; + DenseMap moduleUpdateTable; +}; + +} // namespace + +//====-------------------------------------------------------------------------- +// Terms: Syntax for unifying domain and domain-rows. +//====-------------------------------------------------------------------------- + +namespace { + +/// The different sorts of terms in the unification engine. +enum class TermKind { + Variable, + Value, + Row, +}; + +/// A term in the unification engine. +struct Term { + constexpr Term(TermKind kind) : kind(kind) {} + TermKind kind; +}; + +/// Helper to define a term kind. +template +struct TermBase : Term { + static bool classof(const Term *term) { return term->kind == K; } + TermBase() : Term(K) {} +}; + +/// An unknown value. +struct VariableTerm : public TermBase { + VariableTerm() : leader(nullptr) {} + VariableTerm(Term *leader) : leader(leader) {} + Term *leader; +}; + +/// A concrete value defined in the IR. +struct ValueTerm : public TermBase { + ValueTerm(Value value) : value(value) {} + Value getValue() const { return value; } + Value value; +}; + +/// A row of domains. +struct RowTerm : public TermBase { + RowTerm(ArrayRef elements) : elements(elements) {} + ArrayRef elements; +}; + +/// A helper for assigning low numeric IDs to variables for user-facing output. +struct VariableIDTable { + size_t get(VariableTerm *term) { + auto [it, inserted] = table.insert({term, table.size() + 1}); + return it->second; + } + + DenseMap table; +}; + +#ifndef NDEBUG + +raw_ostream &dump(llvm::raw_ostream &out, const Term *term); + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const VariableTerm *term) { + return out << "var@" << (void *)term << "{leader=" << term->leader << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const ValueTerm *term) { + return out << "val@" << term << "{" << term->value << "}"; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const RowTerm *term) { + out << "row@" << term << "{"; + bool first = true; + for (auto *element : term->elements) { + if (!first) + out << ", "; + dump(out, element); + first = false; + } + out << "}"; + return out; +} + +// NOLINTNEXTLINE(misc-no-recursion) +raw_ostream &dump(raw_ostream &out, const Term *term) { + if (!term) + return out << "null"; + if (auto *var = dyn_cast(term)) + return dump(out, var); + if (auto *val = dyn_cast(term)) + return dump(out, val); + if (auto *row = dyn_cast(term)) + return dump(out, row); + llvm_unreachable("unknown term"); +} +#endif // DEBUG + +// NOLINTNEXTLINE(misc-no-recursion) +Term *find(Term *x) { + if (!x) + return nullptr; + + if (auto *var = dyn_cast(x)) { + if (var->leader == nullptr) + return var; + + auto *leader = find(var->leader); + if (leader != var->leader) + var->leader = leader; + return leader; + } + + return x; +} + +LogicalResult unify(Term *lhs, Term *rhs); + +LogicalResult unify(VariableTerm *x, Term *y) { + x->leader = y; + return success(); +} + +LogicalResult unify(ValueTerm *xv, Term *y) { + if (auto *yv = dyn_cast(y)) { + yv->leader = xv; + return success(); + } + if (auto *yv = dyn_cast(y)) { + return success(xv == yv); + } + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(RowTerm *lhsRow, Term *rhs) { + if (auto *rhsVar = dyn_cast(rhs)) { + rhsVar->leader = lhsRow; + return success(); + } + if (auto *rhsRow = dyn_cast(rhs)) { + assert(lhsRow->elements.size() == rhsRow->elements.size()); + for (auto [x, y] : llvm::zip(lhsRow->elements, rhsRow->elements)) { + if (failed(unify(x, y))) + return failure(); + } + return success(); + } + + return failure(); +} + +// NOLINTNEXTLINE(misc-no-recursion) +LogicalResult unify(Term *lhs, Term *rhs) { + LLVM_DEBUG(auto &out = llvm::errs(); out << "unify x="; dump(out, lhs); + out << " y="; dump(out, rhs); out << "\n";); + if (!lhs || !rhs) + return success(); + lhs = find(lhs); + rhs = find(rhs); + if (lhs == rhs) + return success(); + if (auto *lhsVar = dyn_cast(lhs)) + return unify(lhsVar, rhs); + if (auto *lhsVal = dyn_cast(lhs)) + return unify(lhsVal, rhs); + if (auto *lhsRow = dyn_cast(lhs)) + return unify(lhsRow, rhs); + return failure(); +} + +void solve(Term *lhs, Term *rhs) { + auto result = unify(lhs, rhs); + (void)result; + assert(result.succeeded()); +} + +} // namespace + +//====-------------------------------------------------------------------------- +// InferModuleDomains: Primary workhorse for inferring domains on modules. +//====-------------------------------------------------------------------------- + +namespace { +class InferModuleDomains { +public: + /// Run infer-domains on a module. + static LogicalResult run(GlobalState &, FModuleOp); + +private: + /// Initialize module-level state. + InferModuleDomains(GlobalState &); + + /// Execute on the given module. + LogicalResult operator()(FModuleOp); + + /// Record the domain associations of hardware ports, and record the + /// underlying value of output domain ports. + LogicalResult processPorts(FModuleOp); + + /// Record the domain associations of hardware, and record the underlying + /// value of domains, defined within the body of the module. + LogicalResult processBody(FModuleOp); + + /// Record the domain associations of any operands or results, updating the op + /// if necessary. + LogicalResult processOp(Operation *); + LogicalResult processOp(InstanceOp); + LogicalResult processOp(InstanceChoiceOp); + LogicalResult processOp(UnsafeDomainCastOp); + LogicalResult processOp(DomainDefineOp); + + /// Apply the port changes of a module onto an instance-like op. + template + T updateInstancePorts(T op, const ModuleUpdateInfo &update); + + /// Record the domain associations of the ports of an instance-like op. + template + LogicalResult processInstancePorts(T op); + + LogicalResult updateModule(FModuleOp); + + /// Build a table of exported domains: a map from domains defined internally, + /// to their set of aliasing output ports. + void initializeExportTable(FModuleOp); + + /// After generalizing the module, all domains should be solved. Reflect the + /// solved domain associations into the port domain info attribute. + LogicalResult updatePortDomainAssociations(FModuleOp); + + /// After updating the port domain associations, walk the body of the module + /// to fix up any child instance modules. + LogicalResult updateDomainAssociationsInBody(FModuleOp); + LogicalResult updateOpDomainAssociations(Operation *); + + template + LogicalResult updateInstanceDomainAssociations(T op); + + /// Copy the domain associations from the module domain info attribute into a + /// small vector. + SmallVector copyPortDomainAssociations(ArrayAttr, size_t); + + /// Add domain ports for any uninferred domains associated to hardware. + /// Returns the inserted ports, which will be used later to generalize the + /// instances of this module. + void generalizeModule(FModuleOp); + + void generalizeInstance(InstanceOp, const PortInsertions &); + + /// Unify the associated domain rows of two terms. + LogicalResult unifyAssociations(Operation *, Value, Value); + + /// If the domain value is an alias, returns the domain it aliases. + Value getUnderlyingDomain(Value); + + /// Record a mapping from domain in the IR to its corresponding term. + void setTermForDomain(Value, Term *); + + /// Get the corresponding term for a domain in the IR. + Term *getTermForDomain(Value); + + /// Get the corresponding term for a domain in the IR, or null if unset. + Term *getOptTermForDomain(Value) const; + + /// Record a mapping from a hardware value in the IR to a term which + /// represents the row of domains it is associated with. + void setDomainAssociation(Value, Term *); + + /// Get the associated domain row, forced to be at least a row. + RowTerm *getDomainAssociationAsRow(Value); + + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, allocate a variable to stand for + /// the row of domains. + Term *getDomainAssociation(Value); + + /// For a hardware value, get the term which represents the row of associated + /// domains. If no mapping has been defined, returns nullptr. + Term *getOptDomainAssociation(Value) const; + + /// Allocate a row, where each domain is a variable. + RowTerm *allocateRow(); + + /// Allocate a row. + RowTerm *allocateRow(ArrayRef); + + /// Allocate a term. + template + T *allocate(Args &&...); + + /// Allocate an array of terms. If any terms were left null, automatically + /// replace them with a new variable. + ArrayRef allocateArray(ArrayRef); + + /// Print a term in a user-friendly way. + void render(Diagnostic &, Term *) const; + void render(Diagnostic &, VariableIDTable &, Term *) const; + + template + void emitPortDomainCrossingError(T, size_t, DomainTypeID, Term *, + Term *) const; + + /// Emit an error when we fail to infer the concrete domain to drive to a + /// domain port. + template + void emitDomainPortInferenceError(T, size_t) const; + + /// Information about the domains in a circuit. + GlobalState &globals; + + /// Term allocator. + llvm::BumpPtrAllocator allocator; + + /// Map from domains in the IR to their underlying term. + DenseMap termTable; + + /// A map from hardware values to their associated row of domains, as a term. + DenseMap associationTable; + + /// A map from local domain definition to its aliasing output ports. + DenseMap> exportTable; + + /// A boolean tracking if a non-fatal error occurred, or not. + bool ok = true; +}; +} // namespace + +LogicalResult InferModuleDomains::run(GlobalState &globals, FModuleOp module) { + return InferModuleDomains(globals)(module); +} + +InferModuleDomains::InferModuleDomains(GlobalState &globals) + : globals(globals) {} + +LogicalResult InferModuleDomains::operator()(FModuleOp module) { + LLVM_DEBUG( + llvm::errs() << "================================================\n"; + llvm::errs() << "infer module domains: " << module.getModuleName() + << "\n"; + llvm::errs() << "================================================\n";); + + if (failed(processPorts(module))) + return failure(); + + if (failed(processBody(module))) + return failure(); + + LLVM_DEBUG(for (auto association : associationTable) { + llvm::errs() << "association:\n"; + llvm::errs() << " " << association.first << "\n"; + llvm::errs() << " " << association.second << "\n"; + }); + + if (failed(updateModule(module))) + return failure(); + + return llvm::success(ok); +} + +LogicalResult InferModuleDomains::processPorts(FModuleOp module) { + auto portDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + + // Process module ports - domain ports define explicit domains. + DenseMap domainTypeIDTable; + for (size_t i = 0; i < numPorts; ++i) { + BlockArgument port = module.getArgument(i); + + // This is a domain port. + if (isa(port.getType())) { + auto typeID = globals.circuitInfo.getDomainTypeID(portDomainInfo, i); + domainTypeIDTable[i] = typeID; + if (module.getPortDirection(i) == Direction::In) { + setTermForDomain(port, allocate(port)); + } + continue; + } + + // This is a port, which may have explicit domain information. + auto portDomains = getPortDomainAssociation(portDomainInfo, i); + if (portDomains.empty()) + continue; + + SmallVector elements(globals.circuitInfo.getNumDomains()); + for (auto domainPortIndexAttr : portDomains) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = domainTypeIDTable[domainPortIndex]; + auto domainValue = module.getArgument(domainPortIndex); + auto *term = getTermForDomain(domainValue); + auto &slot = elements[domainTypeID]; + if (failed(unify(slot, term))) { + emitPortDomainCrossingError(module, i, domainTypeID, slot, term); + return failure(); + } + elements[domainTypeID] = term; + } + auto *row = allocateRow(elements); + setDomainAssociation(port, row); + } + + return success(); +} + +LogicalResult InferModuleDomains::processBody(FModuleOp module) { + LogicalResult result = success(); + module.getBody().walk([&](Operation *op) -> WalkResult { + if (failed(processOp(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +LogicalResult InferModuleDomains::processOp(Operation *op) { + LLVM_DEBUG(llvm::errs() << "process op: " << *op << "\n"); + + if (auto instance = dyn_cast(op)) + return processOp(instance); + if (auto instance = dyn_cast(op)) + return processOp(instance); + if (auto cast = dyn_cast(op)) + return processOp(cast); + if (auto def = dyn_cast(op)) + return processOp(def); + + // For all other operations (including connections), propagate domains from + // operands to results. This is a conservative approach - all operands and + // results share the same domain associations. + Value lhs; + for (auto rhs : op->getOperands()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(op, lhs, rhs))) + return failure(); + lhs = rhs; + } + for (auto rhs : op->getResults()) { + if (!isa(rhs.getType())) + continue; + if (auto *op = rhs.getDefiningOp(); + op && op->hasTrait()) + continue; + if (failed(unifyAssociations(op, lhs, rhs))) + return failure(); + lhs = rhs; + } + return success(); +} + +LogicalResult InferModuleDomains::processOp(InstanceOp op) { + auto module = op.getReferencedModuleNameAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + return processInstancePorts(op); +} + +LogicalResult InferModuleDomains::processOp(InstanceChoiceOp op) { + auto module = op.getDefaultTargetAttr().getAttr(); + auto lookup = globals.moduleUpdateTable.find(module); + if (lookup != globals.moduleUpdateTable.end()) + op = updateInstancePorts(op, lookup->second); + return processInstancePorts(op); +} + +LogicalResult InferModuleDomains::processOp(UnsafeDomainCastOp op) { + auto domains = op.getDomains(); + if (domains.empty()) + return unifyAssociations(op, op.getInput(), op.getResult()); + + auto input = op.getInput(); + RowTerm *inputRow = getDomainAssociationAsRow(input); + SmallVector elements(inputRow->elements); + for (auto domain : op.getDomains()) { + auto typeID = globals.circuitInfo.getDomainTypeID(domain); + elements[typeID] = getTermForDomain(domain); + } + + auto *row = allocateRow(elements); + setDomainAssociation(op.getResult(), row); + return success(); +} + +LogicalResult InferModuleDomains::processOp(DomainDefineOp op) { + auto src = op.getSrc(); + auto dst = op.getDest(); + auto *srcTerm = getTermForDomain(src); + auto *dstTerm = getTermForDomain(dst); + if (failed(unify(dstTerm, srcTerm))) { + VariableIDTable idTable; + auto diag = op->emitOpError("failed to propagate source to destination"); + auto ¬e1 = diag.attachNote(); + note1 << "destination has underlying value: "; + render(note1, idTable, dstTerm); + + auto ¬e2 = diag.attachNote(src.getLoc()); + note2 << "source has underlying value: "; + render(note2, idTable, srcTerm); + } + return unify(dstTerm, srcTerm); +} + +template +T InferModuleDomains::updateInstancePorts(T op, + const ModuleUpdateInfo &update) { + auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions); + clone.setDomainInfoAttr(update.portDomainInfo); + op->erase(); + return clone; +} + +template +LogicalResult InferModuleDomains::processInstancePorts(T op) { + auto circuitInfo = globals.circuitInfo; + auto numDomainTypes = circuitInfo.getNumDomains(); + DenseMap domainPortTypeIDTable; + auto domainInfo = op.getDomainInfoAttr(); + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) { + Value port = op.getResult(i); + + LLVM_DEBUG(llvm::errs() << "handling instance port: " << port << "\n"); + + if (isa(port.getType())) { + auto typeID = circuitInfo.getDomainTypeID(domainInfo, i); + domainPortTypeIDTable[i] = typeID; + if (op.getPortDirection(i) == Direction::Out) { + setTermForDomain(port, allocate(port)); + } + continue; + } + + if (!isa(port.getType())) + continue; + + // This is a port, which may have explicit domain information. Associate the + // port with a row of domains, where each element is derived from the domain + // associations recorded in the domain info attribute of the instance. + SmallVector elements(numDomainTypes); + auto associations = getPortDomainAssociation(domainInfo, i); + for (auto domainPortIndexAttr : associations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto typeID = domainPortTypeIDTable[domainPortIndex]; + auto *term = getTermForDomain(op.getResult(domainPortIndex)); + elements[typeID] = term; + } + + // Confirm that we have complete domain information for the port. We can be + // missing information if, for example, this was an instance of an + // extmodule. + for (size_t domainTypeID = 0; domainTypeID < numDomainTypes; + ++domainTypeID) { + if (elements[domainTypeID]) + continue; + auto domainDecl = circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto portName = op.getPortNameAttr(i); + op->emitOpError() << "missing " << domainName << " association for port " + << portName; + return failure(); + } + + setDomainAssociation(port, allocateRow(elements)); + } + + return success(); +} + +LogicalResult InferModuleDomains::updateModule(FModuleOp op) { + initializeExportTable(op); + + generalizeModule(op); + if (failed(updatePortDomainAssociations(op))) + return failure(); + + if (failed(updateDomainAssociationsInBody(op))) + return failure(); + + return success(); +} + +void InferModuleDomains::initializeExportTable(FModuleOp module) { + size_t numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + if (!isa(type)) + continue; + auto value = getUnderlyingDomain(port); + if (value) + exportTable[value].push_back(port); + } +} + +LogicalResult +InferModuleDomains::updatePortDomainAssociations(FModuleOp module) { + // At this point, all domain variables mentioned in ports have been + // solved by generalizing the module (adding input domain ports). Now, we have + // to form the new port domain information for the module by examining the + // the associated domains of each port. + auto *context = module.getContext(); + auto numDomains = globals.circuitInfo.getNumDomains(); + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + auto oldModuleDomainInfo = module.getDomainInfoAttr(); + auto numPorts = module.getNumPorts(); + SmallVector newModuleDomainInfo(numPorts); + + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + // If the port is an output domain, we may need to drive the output with + // a value. If we don't know what value to drive to the port, error. + if (isa(type)) { + if (module.getPortDirection(i) == Direction::Out) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + + // Get the underlying value of the output port. + auto *term = getTermForDomain(port); + term = find(term); + auto *val = dyn_cast(term); + if (!val) { + emitDomainPortInferenceError(module, i); + return failure(); + } + + // If the output port is not driven, drive it. + if (!driven) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + continue; + } + + if (isa(type)) { + auto associations = copyPortDomainAssociations(oldModuleDomainInfo, i); + auto *row = getDomainAssociationAsRow(port); + for (size_t domainTypeID = 0; domainTypeID < numDomains; ++domainTypeID) { + if (associations[domainTypeID]) + continue; + + auto domain = cast(find(row->elements[domainTypeID]))->value; + auto &exports = exportTable[domain]; + if (exports.empty()) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "private " << domainName << " association for port " + << portName; + diag.attachNote(domain.getLoc()) << "associated domain: " << domain; + return failure(); + } + + if (exports.size() > 1) { + auto portName = module.getPortNameAttr(i); + auto portLoc = module.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + auto diag = emitError(portLoc) + << "ambiguous " << domainName << " association for port " + << portName; + for (auto arg : exports) { + auto name = module.getPortNameAttr(arg.getArgNumber()); + auto loc = module.getPortLocation(arg.getArgNumber()); + diag.attachNote(loc) << "candidate association " << name; + } + return failure(); + } + + auto argument = cast(exports[0]); + auto domainPortIndex = argument.getArgNumber(); + associations[domainTypeID] = IntegerAttr::get( + IntegerType::get(context, 32, IntegerType::Unsigned), + domainPortIndex); + } + + newModuleDomainInfo[i] = ArrayAttr::get(context, associations); + continue; + } + + newModuleDomainInfo[i] = oldModuleDomainInfo[i]; + } + + auto newModuleDomainInfoAttr = + ArrayAttr::get(module.getContext(), newModuleDomainInfo); + module.setDomainInfoAttr(newModuleDomainInfoAttr); + + // record the domain info for replaying on instances. + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portDomainInfo = newModuleDomainInfoAttr; + + return success(); +} + +SmallVector +InferModuleDomains::copyPortDomainAssociations(ArrayAttr moduleDomainInfo, + size_t portIndex) { + SmallVector result(globals.circuitInfo.getNumDomains()); + auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex); + for (auto domainPortIndexAttr : oldAssociations) { + auto domainPortIndex = domainPortIndexAttr.getUInt(); + auto domainTypeID = + globals.circuitInfo.getDomainTypeID(moduleDomainInfo, domainPortIndex); + result[domainTypeID] = domainPortIndexAttr; + }; + return result; +} + +void InferModuleDomains::generalizeModule(FModuleOp module) { + PortInsertions insertions; + // If the port is hardware, we have to check the associated row of + // domains. If any associated domain is a variable, we solve the variable + // by generalizing the module with an additional input domain port. If any + // associated domain is defined internally to the module, we have to add + // an output domain port, to allow the domain to escape. + DenseMap pendingSolutions; + llvm::MapVector pendingExports; + + size_t inserted = 0; + auto numPorts = module.getNumPorts(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = module.getArgument(i); + auto type = port.getType(); + + if (!isa(type)) + continue; + + auto *row = getDomainAssociationAsRow(port); + for (auto [typeID, term] : llvm::enumerate(row->elements)) { + auto *domain = find(term); + + if (auto *val = dyn_cast(domain)) { + auto value = val->value; + // If the domain value is defined inside the module body, we must output + // export the domain, so it may appear in the signature of the + // module. + if (isPort(module, value)) + continue; + + // The domain is defined internally. If there value is already exported, + // or will be exported, we are done. + if (exportTable.contains(value) || pendingExports.contains(value)) + continue; + + // We must insert a new output domain port. + auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::Out; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending export. + auto exportedPortIndex = inserted + portInsertionPoint; + pendingExports[val->value] = exportedPortIndex; + ++inserted; + } + + if (auto *var = dyn_cast(domain)) { + if (pendingSolutions.contains(var)) + continue; + + // insert a new input domain port for the variable. + auto domainDecl = globals.circuitInfo.getDomain(typeID); + auto domainName = domainDecl.getNameAttr(); + + auto portInsertionPoint = i; + auto portName = domainName; + auto portType = DomainType::get(module.getContext()); + auto portDirection = Direction::In; + auto portSym = StringAttr(); + auto portLoc = port.getLoc(); + auto portAnnos = std::nullopt; + auto portDomainInfo = FlatSymbolRefAttr::get(domainName); + PortInfo portInfo(portName, portType, portDirection, portSym, portLoc, + portAnnos, portDomainInfo); + insertions.push_back({portInsertionPoint, portInfo}); + + // Record the pending solution. + auto solutionPortIndex = inserted + portInsertionPoint; + pendingSolutions[var] = solutionPortIndex; + ++inserted; + } + } + } + + // Put the domain ports in place. + module.insertPorts(insertions); + + // Solve the variables and record them as "self-exporting". + for (auto [var, portIndex] : pendingSolutions) { + auto port = module.getArgument(portIndex); + auto *solution = allocate(port); + solve(var, solution); + exportTable[port].push_back(port); + } + + // Drive the pending exports. + auto builder = OpBuilder::atBlockEnd(module.getBodyBlock()); + for (auto [value, portIndex] : pendingExports) { + auto port = module.getArgument(portIndex); + DomainDefineOp::create(builder, port.getLoc(), port, value); + exportTable[value].push_back(port); + setTermForDomain(port, allocate(value)); + } + + // Record the insertions, so we can replay them on instances later. + auto &update = globals.moduleUpdateTable[module.getNameAttr()]; + update.portInsertions = std::move(insertions); +} + +LogicalResult +InferModuleDomains::updateDomainAssociationsInBody(FModuleOp module) { + auto result = success(); + module.getBodyBlock()->walk([&](Operation *op) -> WalkResult { + if (failed(updateOpDomainAssociations(op))) { + result = failure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return result; +} + +LogicalResult InferModuleDomains::updateOpDomainAssociations(Operation *op) { + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); + if (auto instance = dyn_cast(op)) + return updateInstanceDomainAssociations(instance); + return success(); +} + +template +LogicalResult InferModuleDomains::updateInstanceDomainAssociations(T op) { + auto *context = op.getContext(); + OpBuilder builder(context); + builder.setInsertionPointAfter(op); + auto numPorts = op->getNumResults(); + for (size_t i = 0; i < numPorts; ++i) { + auto port = op.getResult(i); + auto type = port.getType(); + auto direction = op.getPortDirection(i); + if (isa(type)) { + if (direction == Direction::In) { + bool driven = false; + for (auto *user : port.getUsers()) { + if (auto connect = dyn_cast(user)) { + if (connect.getDest() == port) { + driven = true; + break; + } + } + } + if (!driven) { + auto *term = getTermForDomain(port); + term = find(term); + if (auto *val = dyn_cast(term)) { + auto loc = port.getLoc(); + auto value = val->value; + DomainDefineOp::create(builder, loc, port, value); + } else { + emitDomainPortInferenceError(op, i); + return failure(); + } + } + } + } + } + return success(); +} + +LogicalResult InferModuleDomains::unifyAssociations(Operation *op, Value lhs, + Value rhs) { + LLVM_DEBUG(llvm::errs() << " unify associations of:\n"; + llvm::errs() << " lhs=" << lhs << "\n"; + llvm::errs() << " rhs=" << rhs << "\n";); + + if (!lhs || !rhs) + return success(); + + if (lhs == rhs) + return success(); + + auto *lhsTerm = getOptDomainAssociation(lhs); + auto *rhsTerm = getOptDomainAssociation(rhs); + + if (lhsTerm) { + if (rhsTerm) { + if (failed(unify(lhsTerm, rhsTerm))) { + auto diag = op->emitOpError("illegal domain crossing in operation"); + auto ¬e1 = diag.attachNote(lhs.getLoc()); + + note1 << "1st operand has domains: "; + VariableIDTable idTable; + render(note1, idTable, lhsTerm); + + auto ¬e2 = diag.attachNote(rhs.getLoc()); + note2 << "2nd operand has domains: "; + render(note2, idTable, rhsTerm); + + return failure(); + } + } + setDomainAssociation(rhs, lhsTerm); + return success(); + } + + if (rhsTerm) { + setDomainAssociation(lhs, rhsTerm); + return success(); + } + + auto *var = allocate(); + setDomainAssociation(lhs, var); + setDomainAssociation(rhs, var); + return success(); +} + +Value InferModuleDomains::getUnderlyingDomain(Value value) { + assert(isa(value.getType())); + auto *term = getOptTermForDomain(value); + if (auto *val = llvm::dyn_cast_if_present(term)) + return val->value; + return nullptr; +} + +Term *InferModuleDomains::getTermForDomain(Value value) { + assert(isa(value.getType())); + if (auto *term = getOptTermForDomain(value)) + return term; + auto *term = allocate(); + setTermForDomain(value, term); + return term; +} + +Term *InferModuleDomains::getOptTermForDomain(Value value) const { + assert(isa(value.getType())); + auto it = termTable.find(value); + if (it == termTable.end()) + return nullptr; + return find(it->second); +} + +void InferModuleDomains::setTermForDomain(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + assert(!termTable.contains(value)); + termTable.insert({value, term}); +} + +RowTerm *InferModuleDomains::getDomainAssociationAsRow(Value value) { + assert(isa(value.getType())); + auto *term = getOptDomainAssociation(value); + + // If the term is unknown, allocate a fresh row and set the association. + if (!term) { + auto *row = allocateRow(); + setDomainAssociation(value, row); + return row; + } + + // If the term is already a row, return it. + if (auto *row = dyn_cast(term)) + return row; + + // Otherwise, unify the term with a fresh row of domains. + if (auto *var = dyn_cast(term)) { + auto *row = allocateRow(); + solve(var, row); + return row; + } + + assert(false && "unhandled term type"); + return nullptr; +} + +Term *InferModuleDomains::getDomainAssociation(Value value) { + auto *term = getOptDomainAssociation(value); + if (term) + return term; + term = allocate(); + setDomainAssociation(value, term); + return term; +} + +Term *InferModuleDomains::getOptDomainAssociation(Value value) const { + assert(isa(value.getType())); + auto it = associationTable.find(value); + if (it == associationTable.end()) + return nullptr; + return find(it->second); +} + +void InferModuleDomains::setDomainAssociation(Value value, Term *term) { + assert(isa(value.getType())); + assert(term); + term = find(term); + associationTable.insert({value, term}); + LLVM_DEBUG(llvm::errs() << " set domain association: " << value; + llvm::errs() << " -> " << term << "\n";); +} + +RowTerm *InferModuleDomains::allocateRow() { + SmallVector elements; + elements.resize(globals.circuitInfo.getNumDomains()); + return allocateRow(elements); +} + +RowTerm *InferModuleDomains::allocateRow(ArrayRef elements) { + auto ds = allocateArray(elements); + return allocate(ds); +} + +template +T *InferModuleDomains::allocate(Args &&...args) { + static_assert(std::is_base_of_v, "T must be a term"); + return new (allocator) T(std::forward(args)...); +} + +ArrayRef InferModuleDomains::allocateArray(ArrayRef elements) { + auto size = elements.size(); + if (size == 0) + return {}; + + auto *result = allocator.Allocate(size); + llvm::uninitialized_copy(elements, result); + for (size_t i = 0; i < size; ++i) + if (!result[i]) + result[i] = allocate(); + + return ArrayRef(result, elements.size()); +} + +void InferModuleDomains::render(Diagnostic &out, Term *term) const { + VariableIDTable idTable; + render(out, idTable, term); +} + +// NOLINTNEXTLINE(misc-no-recursion) +void InferModuleDomains::render(Diagnostic &out, VariableIDTable &idTable, + Term *term) const { + term = find(term); + if (auto *var = dyn_cast(term)) { + out << "?" << idTable.get(var); + return; + } + if (auto *val = dyn_cast(term)) { + auto value = val->value; + auto [name, rooted] = getFieldName(FieldRef(value, 0), false); + out << name; + return; + } + if (auto *row = dyn_cast(term)) { + bool first = true; + out << "["; + for (size_t i = 0, e = globals.circuitInfo.getNumDomains(); i < e; ++i) { + auto domainOp = globals.circuitInfo.getDomain(i); + if (!first) { + out << ", "; + first = false; + } + out << domainOp.getName() << ": "; + render(out, idTable, row->elements[i]); + } + out << "]"; + return; + } +} + +template +void InferModuleDomains::emitPortDomainCrossingError(T op, size_t i, + size_t domainTypeID, + Term *term1, + Term *term2) const { + VariableIDTable idTable; + + auto portName = op.getPortNameAttr(i); + auto portLoc = op.getPortLocation(i); + auto domainDecl = globals.circuitInfo.getDomain(domainTypeID); + auto domainName = domainDecl.getNameAttr(); + + auto diag = emitError(portLoc); + diag << "illegal " << domainName << " crossing in port " << portName; + + auto ¬e1 = diag.attachNote(); + note1 << "1st instance: "; + render(note1, idTable, term1); + + auto ¬e2 = diag.attachNote(); + note2 << "2nd instance: "; + render(note2, idTable, term2); +} + +template +void InferModuleDomains::emitDomainPortInferenceError(T op, size_t i) const { + auto name = op.getPortNameAttr(i); + auto diag = emitError(op->getLoc()); + auto info = op.getDomainInfo(); + diag << "unable to infer value for domain port " << name; + for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) { + if (auto assocs = dyn_cast(info[j])) { + for (auto assoc : assocs) { + if (i == cast(assoc).getValue()) { + auto name = op.getPortNameAttr(j); + auto loc = op.getPortLocation(j); + diag.attachNote(loc) << "associated with hardware port " << name; + break; + } + } + } + } +} + +//===--------------------------------------------------------------------------- +// InferDomainsPass: Top-level pass implementation. +//===--------------------------------------------------------------------------- + +namespace { +struct InferDomainsPass + : public circt::firrtl::impl::InferDomainsBase { + void runOnOperation() override; +}; +} // namespace + +void InferDomainsPass::runOnOperation() { + LLVM_DEBUG(debugPassHeader(this) << "\n"); + auto circuit = getOperation(); + auto &instanceGraph = getAnalysis(); + + GlobalState globals(circuit); + DenseSet visited; + for (auto *root : instanceGraph) { + for (auto *node : llvm::post_order_ext(root, visited)) { + auto module = dyn_cast(node->getOperation()); + if (!module) + continue; + + if (failed(InferModuleDomains::run(globals, module))) { + signalPassFailure(); + return; + } + } + } + LLVM_DEBUG(debugFooter() << "\n"); +} diff --git a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp index 594f65c8591d..1dc2096b189f 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferReadWrite.cpp @@ -89,7 +89,7 @@ struct InferReadWritePass Attribute portAnno; portAnno = memOp.getPortAnnotation(portIt.index()); if (memOp.getPortKind(portIt.index()) == MemOp::PortKind::Debug) { - resultNames.push_back(memOp.getPortName(portIt.index())); + resultNames.push_back(memOp.getPortNameAttr(portIt.index())); resultTypes.push_back(memOp.getResult(portIt.index()).getType()); portAnnotations.push_back(portAnno); continue; diff --git a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp index dfef3719f3b3..dceb7c8a123e 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp @@ -705,7 +705,7 @@ static bool getDeclName(Value value, SmallString<32> &string) { string += op.getName(); string += "."; string += - op.getPortName(cast(value).getResultNumber()).getValue(); + op.getPortName(cast(value).getResultNumber()); return true; }) .Case([&](auto op) { diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index 5722d989154a..af78bf5c804d 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -1439,7 +1439,7 @@ updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, // The path to the field is just this output's name. auto objectFieldPath = builder.getArrayAttr({FlatSymbolRefAttr::get( - firrtlInstance.getPortName(result.getResultNumber()))}); + firrtlInstance.getPortNameAttr(result.getResultNumber()))}); // Create the field access. auto objectField = ObjectFieldOp::create( diff --git a/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp b/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp index eb4442a73886..df64d628d161 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerIntmodules.cpp @@ -108,7 +108,7 @@ void LowerIntmodulesPass::runOnOperation() { } outputs.push_back( OutputInfo{inst.getResult(idx), - BundleType::BundleElement(inst.getPortName(idx), + BundleType::BundleElement(inst.getPortNameAttr(idx), /*isFlip=*/false, ftype)}); } diff --git a/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp b/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp index bffbc3057569..375414ca866c 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp @@ -547,7 +547,7 @@ LogicalResult Visitor::visitDecl(InstanceOp op) { // If not identity, mark this port for eventual removal. portsToErase.set(newIndex); - auto portName = op.getPortName(index); + auto portName = op.getPortNameAttr(index); auto portDirection = op.getPortDirection(index); auto portDomain = op.getPortDomain(index); auto loc = op.getLoc(); diff --git a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp index ba54761c42f7..b62e03209389 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp @@ -997,7 +997,7 @@ bool TypeLoweringVisitor::visitDecl(MemOp op) { } auto wire = builder->create( result.getType(), - (op.getName() + "_" + op.getPortName(index).getValue()).str()); + (op.getName() + "_" + op.getPortName(index)).str()); oldPorts.push_back(wire); result.replaceAllUsesWith(wire.getResult()); } @@ -1498,13 +1498,13 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector fieldTypes; if (!peelType(srcType, fieldTypes, mode)) { newDirs.push_back(op.getPortDirection(i)); - newNames.push_back(op.getPortName(i)); + newNames.push_back(op.getPortNameAttr(i)); newDomains.push_back(builder->getArrayAttr({})); resultTypes.push_back(srcType); newPortAnno.push_back(oldPortAnno[i]); } else { skip = false; - auto oldName = op.getPortNameStr(i); + auto oldName = op.getPortName(i); auto oldDir = op.getPortDirection(i); // Store the flat type for the new bundle type. for (const auto &field : fieldTypes) { diff --git a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp index d7eb00a47272..596a6c87b799 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp @@ -806,7 +806,7 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase { for (const auto &res : llvm::enumerate(mem.getResults())) { if (isa(mem.getResult(res.index()).getType())) continue; - resultNames.push_back(mem.getPortName(res.index())); + resultNames.push_back(mem.getPortNameAttr(res.index())); resultTypes.push_back(res.value().getType()); portAnnotations.push_back(mem.getPortAnnotation(res.index())); oldResults.push_back(res.value()); diff --git a/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp b/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp index ef2c0dcecf09..b671ebd31af6 100644 --- a/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp +++ b/lib/Dialect/FIRRTL/Transforms/MemToRegOfVec.cpp @@ -386,7 +386,7 @@ struct MemToRegOfVecPass // simpler to delete the memOp. auto wire = WireOp::create( builder, result.getType(), - (memOp.getName() + "_" + memOp.getPortName(index).getValue()).str(), + (memOp.getName() + "_" + memOp.getPortName(index)).str(), memOp.getNameKind()); result.replaceAllUsesWith(wire.getResult()); result = wire.getResult(); diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index 6d846422507b..602053ed22fe 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -49,6 +49,9 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, pm.nest().nest().addPass( firrtl::createLowerIntrinsics()); + if (opt.shouldInferDomains()) + pm.nest().addPass(firrtl::createInferDomains()); + return success(); } @@ -757,6 +760,11 @@ struct FirtoolCmdOptions { llvm::cl::desc("Emit bindfiles for private modules"), llvm::cl::init(false)}; + llvm::cl::opt inferDomains{ + "infer-domains", + llvm::cl::desc("Enable domain inference and checking"), + llvm::cl::init(false)}; + //===----------------------------------------------------------------------=== // Lint options //===----------------------------------------------------------------------=== @@ -808,7 +816,7 @@ circt::firtool::FirtoolOptions::FirtoolOptions() disableCSEinClasses(false), selectDefaultInstanceChoice(false), symbolicValueLowering(verif::SymbolicValueLowering::ExtModule), disableWireElimination(false), lintStaticAsserts(true), - lintXmrsInDesign(true), emitAllBindFiles(false) { + lintXmrsInDesign(true), emitAllBindFiles(false), inferDomains(false) { if (!clOptions.isConstructed()) return; outputFilename = clOptions->outputFilename; @@ -861,4 +869,5 @@ circt::firtool::FirtoolOptions::FirtoolOptions() lintStaticAsserts = clOptions->lintStaticAsserts; lintXmrsInDesign = clOptions->lintXmrsInDesign; emitAllBindFiles = clOptions->emitAllBindFiles; + inferDomains = clOptions->inferDomains; } diff --git a/test/Dialect/FIRRTL/infer-domains-errors.mlir b/test/Dialect/FIRRTL/infer-domains-errors.mlir new file mode 100644 index 000000000000..da419dcac465 --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains-errors.mlir @@ -0,0 +1,128 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s --verify-diagnostics --split-input-file + +// Port annotated with same domain type twice. +firrtl.circuit "DomainCrossOnPort" { + firrtl.domain @ClockDomain + firrtl.module @DomainCrossOnPort( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-error @below {{illegal "ClockDomain" crossing in port "p"}} + // expected-note @below {{1st instance: A}} + // expected-note @below {{2nd instance: B}} + in %p: !firrtl.uint<1> domains [%A, %B] + ) {} +} + +// ----- + +// Illegal domain crossing - connect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.connect %b, %a : !firrtl.uint<1> + } +} + +// ----- + +// Illegal domain crossing at matchingconnect op. +firrtl.circuit "IllegalDomainCrossing" { + firrtl.domain @ClockDomain + firrtl.module @IllegalDomainCrossing( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + // expected-note @below {{2nd operand has domains: [ClockDomain: A]}} + in %a: !firrtl.uint<1> domains [%A], + // expected-note @below {{1st operand has domains: [ClockDomain: B]}} + out %b: !firrtl.uint<1> domains [%B] + ) { + // expected-error @below {{illegal domain crossing in operation}} + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// ----- + +// Unable to infer domain of port, when port is driven by constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstant" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<1>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstant() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %c0_ui1 : !firrtl.uint<1> + } +} + +// ----- + +// Unable to infer domain of port, when port is driven by arithmetic on constant. +firrtl.circuit "UnableToInferDomainOfPortDrivenByConstantExpr" { + firrtl.domain @ClockDomain + firrtl.module @Foo(in %i: !firrtl.uint<2>) {} + + firrtl.module @UnableToInferDomainOfPortDrivenByConstantExpr() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %0 = firrtl.add %c0_ui1, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<2> + // expected-error @below {{unable to infer value for domain port "ClockDomain"}} + // expected-note @below {{associated with hardware port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<2>) + firrtl.matchingconnect %foo_i, %0 : !firrtl.uint<2> + } +} + +// ----- + +// Incomplete extmodule domain information. + +firrtl.circuit "IncompleteDomainInfoForExtModule" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in i: !firrtl.uint<1>) + + firrtl.module @IncompleteDomainInfoForExtModule(in %i: !firrtl.uint<1>) { + // expected-error @below {{'firrtl.instance' op missing "ClockDomain" association for port "i"}} + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.matchingconnect %foo_i, %i : !firrtl.uint<1> + } +} + +// ----- + +// Domain not exported like it should be. + +// ----- + +// Domain exported multiple times. Which do we choose? + +firrtl.circuit "DoubleExportOfDomain" { + firrtl.domain @ClockDomain + + firrtl.module @DoubleExportOfDomain( + // expected-note @below {{candidate association "DI"}} + in %DI : !firrtl.domain of @ClockDomain, + // expected-note @below {{candidate association "DO"}} + out %DO : !firrtl.domain of @ClockDomain, + in %i : !firrtl.uint<1> domains [%DO], + // expected-error @below {{ambiguous "ClockDomain" association for port "o"}} + out %o : !firrtl.uint<1> domains [] + ) { + // DI and DO are aliases + firrtl.domain.define %DO, %DI + + // o is on same domain as i + firrtl.matchingconnect %o, %i : !firrtl.uint<1> + } +} + diff --git a/test/Dialect/FIRRTL/infer-domains.mlir b/test/Dialect/FIRRTL/infer-domains.mlir new file mode 100644 index 000000000000..c2cc0d96a83e --- /dev/null +++ b/test/Dialect/FIRRTL/infer-domains.mlir @@ -0,0 +1,250 @@ +// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-infer-domains))' %s | FileCheck %s + +// Legal domain usage - no crossing. +firrtl.circuit "LegalDomains" { + firrtl.domain @ClockDomain + firrtl.module @LegalDomains( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // Connecting within the same domain is legal. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "LegalDomains" + +// Domain inference through connections. +firrtl.circuit "DomainInference" { + firrtl.domain @ClockDomain + firrtl.module @DomainInference( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> + ) { + %b = firrtl.wire : !firrtl.uint<1> // No explicit domain + + // This should infer that %b is in domain %A. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + + // This should be legal since %b is now inferred to be in domain %A. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "DomainInference" +// CHECK: out %c: !firrtl.uint<1> domains [%A] + +// Unsafe domain cast +firrtl.circuit "UnsafeDomainCast" { + firrtl.domain @ClockDomain + firrtl.module @UnsafeDomainCast( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A], + out %c: !firrtl.uint<1> domains [%B] + ) { + // Unsafe cast from domain A to domain B. + %b = firrtl.unsafe_domain_cast %a domains %B : !firrtl.uint<1> + + // This should be legal since we explicitly cast. + firrtl.matchingconnect %c, %b : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "UnsafeDomainCast" + +// Domain sequence matching. +firrtl.circuit "LegalSequences" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @LegalSequences( + in %C: !firrtl.domain of @ClockDomain, + in %P: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%C, %P], + out %b: !firrtl.uint<1> domains [%C, %P] + ) { + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Domain sequence order equivalence - should be legal +firrtl.circuit "SequenceOrderEquivalence" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceOrderEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %b: !firrtl.uint<1> domains [%B, %A] + ) { + // This should be legal since domain order doesn't matter in canonical representation + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} +// CHECK-LABEL: firrtl.circuit "SequenceOrderEquivalence" + +// Domain sequence inference +firrtl.circuit "SequenceInference" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + firrtl.module @SequenceInference( + in %A: !firrtl.domain of @ClockDomain, + in %B: !firrtl.domain of @PowerDomain, + in %a: !firrtl.uint<1> domains [%A, %B], + out %d: !firrtl.uint<1> + ) { + %c = firrtl.wire : !firrtl.uint<1> + + // %c should infer domain sequence [%A, %B] + firrtl.matchingconnect %c, %a : !firrtl.uint<1> + + // This should be legal since %c has inferred [%A, %B] + firrtl.matchingconnect %d, %c : !firrtl.uint<1> + } +} + +// Domain duplicate equivalence - should be legal. +firrtl.circuit "DuplicateDomainEquivalence" { + firrtl.domain @ClockDomain + firrtl.module @DuplicateDomainEquivalence( + in %A: !firrtl.domain of @ClockDomain, + in %a: !firrtl.uint<1> domains [%A, %A], + out %b: !firrtl.uint<1> domains [%A] + ) { + // This should be legal since duplicate domains are canonicalized. + firrtl.matchingconnect %b, %a : !firrtl.uint<1> + } +} + +// Unsafe domain cast with sequences +firrtl.circuit "UnsafeSequenceCast" { + firrtl.domain @ClockDomain + firrtl.domain @PowerDomain + + firrtl.module @UnsafeSequenceCast( + in %C1: !firrtl.domain of @ClockDomain, + in %C2: !firrtl.domain of @ClockDomain, + in %P1: !firrtl.domain of @PowerDomain, + in %i: !firrtl.uint<1> domains [%C1, %P1], + out %o: !firrtl.uint<1> domains [%C2, %P1] + ) { + %0 = firrtl.unsafe_domain_cast %i domains %C2 : !firrtl.uint<1> + firrtl.matchingconnect %o, %0 : !firrtl.uint<1> + } +} + +// Different port types domain inference. + +// CHECK-LABEL: DifferentPortTypes +firrtl.circuit "DifferentPortTypes" { + firrtl.domain @ClockDomain + firrtl.module @DifferentPortTypes( + in %A: !firrtl.domain of @ClockDomain, + in %uint_input: !firrtl.uint<8> domains [%A], + in %sint_input: !firrtl.sint<4> domains [%A], + out %uint_output: !firrtl.uint<8>, + out %sint_output: !firrtl.sint<4> + ) { + firrtl.matchingconnect %uint_output, %uint_input : !firrtl.uint<8> + firrtl.matchingconnect %sint_output, %sint_input : !firrtl.sint<4> + } +} + +// Domain inference through wires. + +// CHECK-LABEL: DomainInferenceThroughWires +firrtl.circuit "DomainInferenceThroughWires" { + firrtl.domain @ClockDomain + firrtl.module @DomainInferenceThroughWires( + in %A: !firrtl.domain of @ClockDomain, + in %input: !firrtl.uint<1> domains [%A], + // CHECK: out %output: !firrtl.uint<1> domains [%A] + out %output: !firrtl.uint<1> + ) { + %wire1 = firrtl.wire : !firrtl.uint<1> + %wire2 = firrtl.wire : !firrtl.uint<1> + + firrtl.matchingconnect %wire1, %input : !firrtl.uint<1> + firrtl.matchingconnect %wire2, %wire1 : !firrtl.uint<1> + firrtl.matchingconnect %output, %wire2 : !firrtl.uint<1> + } +} + +// Register inference. + +// CHECK-LABEL: RegisterInference +firrtl.circuit "RegisterInference" { + firrtl.domain @ClockDomain + firrtl.module @RegisterInference( + in %A: !firrtl.domain of @ClockDomain, + in %clock: !firrtl.clock domains [%A], + // CHECK: in %d: !firrtl.uint<1> domains [%A] + in %d: !firrtl.uint<1>, + // CHECK: out %q: !firrtl.uint<1> domains [%A] + out %q: !firrtl.uint<1> + ) { + %r = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1> + firrtl.matchingconnect %r, %d : !firrtl.uint<1> + firrtl.matchingconnect %q, %r : !firrtl.uint<1> + } +} + +// Update domain on instance. + +// CHECK-LABEL: InstanceUpdate +firrtl.circuit "InstanceUpdate" { + firrtl.domain @ClockDomain + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %foo_ClockDomain, %foo_i = firrtl.instance foo @Foo(in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %foo_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %foo_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceUpdate(in %i : !firrtl.uint<1>) { + %foo_i = firrtl.instance foo @Foo(in i: !firrtl.uint<1>) + firrtl.connect %foo_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: InstanceChoiceUpdate +firrtl.circuit "InstanceChoiceUpdate" { + firrtl.domain @ClockDomain + + firrtl.option @Option { + firrtl.option_case @X + firrtl.option_case @Y + } + + firrtl.module @Foo(in %i : !firrtl.uint<1>) {} + firrtl.module @Bar(in %i : !firrtl.uint<1>) {} + firrtl.module @Baz(in %i : !firrtl.uint<1>) {} + + // CHECK: firrtl.module @InstanceChoiceUpdate(in %ClockDomain: !firrtl.domain of @ClockDomain, in %i: !firrtl.uint<1> domains [%ClockDomain]) { + // CHECK: %inst_ClockDomain, %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in ClockDomain: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [ClockDomain]) + // CHECK: firrtl.domain.define %inst_ClockDomain, %ClockDomain + // CHECK: firrtl.connect %inst_i, %i : !firrtl.uint<1> + // CHECK: } + firrtl.module @InstanceChoiceUpdate(in %i : !firrtl.uint<1>) { + %inst_i = firrtl.instance_choice inst @Foo alternatives @Option { @X -> @Bar, @Y -> @Baz } (in i : !firrtl.uint<1>) + firrtl.connect %inst_i, %i : !firrtl.uint<1>, !firrtl.uint<1> + } +} + +// CHECK-LABEL: ConstantInMultipleDomains +firrtl.circuit "ConstantInMultipleDomains" { + firrtl.domain @ClockDomain + + firrtl.extmodule @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + + firrtl.module @ConstantInMultipleDomains(in %A: !firrtl.domain of @ClockDomain, in %B: !firrtl.domain of @ClockDomain) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %x_A, %x_i = firrtl.instance x @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %x_A, %A + firrtl.matchingconnect %x_i, %c0_ui1 : !firrtl.uint<1> + + %y_A, %y_i = firrtl.instance y @Foo(in A: !firrtl.domain of @ClockDomain, in i: !firrtl.uint<1> domains [A]) + firrtl.domain.define %y_A, %B + firrtl.matchingconnect %y_i, %c0_ui1 : !firrtl.uint<1> + } +}