From 54dc3729a0b4ffdfe39136e33a775a0b31190b64 Mon Sep 17 00:00:00 2001 From: Fabian Schuiki Date: Wed, 22 Jan 2025 20:04:00 -0800 Subject: [PATCH] [FIRRTL][IMCP] Overdefine ports of modules with unknown symbol uses If a module is referenced from an unknown top-level operation, i.e. an operation that is not an `hw.hierpath`, mark the module's inputs as overdefined. IMCP cannot reason about how the module is used by such an unknown operation, and therefore should assume that the operation might instantiate the module and apply arbitrary values to its input. As an example, the `firrtl.formal` operation may refer to a private module as to be executed as a formal test, applying symbolic values to the module's inputs. While IMCP could simply special-case the `firrtl.formal` operation, it feels cleaner to make the pass defensive in the presence of _any_ operation which it does not explicitly know how to deal with. --- include/circt/Support/InstanceGraph.h | 27 ++++++++++-- lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp | 44 ++++++++++++++++--- lib/Support/InstanceGraph.cpp | 9 ++-- test/Dialect/FIRRTL/imconstprop.mlir | 37 +++++++++++++++- 4 files changed, 100 insertions(+), 17 deletions(-) diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h index 5a2184b73418..35e6bbaf28e9 100644 --- a/include/circt/Support/InstanceGraph.h +++ b/include/circt/Support/InstanceGraph.h @@ -200,11 +200,30 @@ class InstanceGraph { InstanceGraph(const InstanceGraph &) = delete; virtual ~InstanceGraph() = default; - /// Look up an InstanceGraphNode for a module. - InstanceGraphNode *lookup(ModuleOpInterface op); + /// Lookup an module by name. Returns null if no module with the given name + /// exists in the instance graph. + InstanceGraphNode *lookupOrNull(StringAttr name); + + /// Look up an InstanceGraphNode for a module. Returns null if the module has + /// not been added to the instance graph. + InstanceGraphNode *lookupOrNull(ModuleOpInterface op) { + return lookup(op.getModuleNameAttr()); + } + + /// Look up an InstanceGraphNode for a module. Aborts if the module does not + /// exist. + InstanceGraphNode *lookup(ModuleOpInterface op) { + auto *node = lookupOrNull(op); + assert(node != nullptr && "Module not in InstanceGraph!"); + return node; + } - /// Lookup an module by name. - InstanceGraphNode *lookup(StringAttr name); + /// Lookup an module by name. Aborts if the module does not exist. + InstanceGraphNode *lookup(StringAttr name) { + auto *node = lookupOrNull(name); + assert(node != nullptr && "Module not in InstanceGraph!"); + return node; + } /// Lookup an InstanceGraphNode for a module. InstanceGraphNode *operator[](ModuleOpInterface op) { return lookup(op); } diff --git a/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp b/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp index 92ec2d00e33a..844fc1ea52da 100644 --- a/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp +++ b/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp @@ -369,12 +369,44 @@ void IMConstPropPass::runOnOperation() { instanceGraph = &getAnalysis(); - // Mark the input ports of public modules as being overdefined. - for (auto module : circuit.getBodyBlock()->getOps()) { - if (module.isPublic()) { - markBlockExecutable(module.getBodyBlock()); - for (auto port : module.getBodyBlock()->getArguments()) - markOverdefined(port); + // Mark input ports as overdefined where appropriate. + for (auto &op : circuit.getOps()) { + // Inputs of public modules are overdefined. + if (auto module = dyn_cast(op)) { + if (module.isPublic()) { + markBlockExecutable(module.getBodyBlock()); + for (auto port : module.getBodyBlock()->getArguments()) + markOverdefined(port); + } + continue; + } + + // Otherwise we check whether the top-level operation contains any + // references to modules. Symbol uses in NLAs are ignored. + if (isa(op)) + continue; + + // Inputs of modules referenced by unknown operations are overdefined, since + // we don't know how those operations affect the input port values. This + // handles things like `firrtl.formal`, which may may assign symbolic values + // to input ports of a private module. + auto symbolUses = SymbolTable::getSymbolUses(&op); + if (!symbolUses) + continue; + for (const auto &use : *symbolUses) { + if (auto symRef = dyn_cast(use.getSymbolRef())) { + if (auto *igNode = instanceGraph->lookupOrNull(symRef.getAttr())) { + if (auto module = dyn_cast(*igNode->getModule())) { + LLVM_DEBUG(llvm::dbgs() + << "Unknown use of " << module.getModuleNameAttr() + << " in " << op.getName() + << ", marking inputs as overdefined\n"); + markBlockExecutable(module.getBodyBlock()); + for (auto port : module.getBodyBlock()->getArguments()) + markOverdefined(port); + } + } + } } } diff --git a/lib/Support/InstanceGraph.cpp b/lib/Support/InstanceGraph.cpp index b93184df0aee..e1f14091f3da 100644 --- a/lib/Support/InstanceGraph.cpp +++ b/lib/Support/InstanceGraph.cpp @@ -106,16 +106,13 @@ void InstanceGraph::erase(InstanceGraphNode *node) { nodes.erase(node); } -InstanceGraphNode *InstanceGraph::lookup(StringAttr name) { +InstanceGraphNode *InstanceGraph::lookupOrNull(StringAttr name) { auto it = nodeMap.find(name); - assert(it != nodeMap.end() && "Module not in InstanceGraph!"); + if (it == nodeMap.end()) + return nullptr; return it->second; } -InstanceGraphNode *InstanceGraph::lookup(ModuleOpInterface op) { - return lookup(cast(op).getModuleNameAttr()); -} - void InstanceGraph::replaceInstance(InstanceOpInterface inst, InstanceOpInterface newInst) { assert(inst.getReferencedModuleNamesAttr() == diff --git a/test/Dialect/FIRRTL/imconstprop.mlir b/test/Dialect/FIRRTL/imconstprop.mlir index 552e5d4f3b41..73d11a0aafaa 100644 --- a/test/Dialect/FIRRTL/imconstprop.mlir +++ b/test/Dialect/FIRRTL/imconstprop.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt -pass-pipeline='builtin.module(firrtl.circuit(firrtl-imconstprop))' --split-input-file %s | FileCheck %s +// RUN: circt-opt --firrtl-imconstprop --split-input-file --allow-unregistered-dialect %s | FileCheck %s firrtl.circuit "Test" { @@ -896,3 +896,38 @@ firrtl.circuit "Layers" { } } } + +// ----- + +// CHECK-LABEL: firrtl.circuit "PublicTop" +firrtl.circuit "PublicTop" { + // CHECK-LABEL: firrtl.module private @Foo + firrtl.module private @Foo(in %a: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.constant 0 + // CHECK: firrtl.int.verif.assert %a + firrtl.int.verif.assert %a : !firrtl.uint<1> + } + // CHECK-LABEL: firrtl.module private @Bar + firrtl.module private @Bar(in %a: !firrtl.uint<1>) { + // CHECK-NOT: firrtl.constant 0 + // CHECK: firrtl.int.verif.assert %a + firrtl.int.verif.assert %a : !firrtl.uint<1> + } + firrtl.module @PublicTop() { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + %foo_a = firrtl.instance foo @Foo(in a: !firrtl.uint<1>) + %bar_a = firrtl.instance bar @Bar(in a: !firrtl.uint<1>) + firrtl.matchingconnect %foo_a, %c0_ui1 : !firrtl.uint<1> + firrtl.matchingconnect %bar_a, %c0_ui1 : !firrtl.uint<1> + } + firrtl.module private @PrivateTop1(in %a: !firrtl.uint<1>) { + %foo_a = firrtl.instance foo @Foo(in a: !firrtl.uint<1>) + firrtl.matchingconnect %foo_a, %a : !firrtl.uint<1> + } + firrtl.module private @PrivateTop2(in %a: !firrtl.uint<1>) { + %bar_a = firrtl.instance bar @Bar(in a: !firrtl.uint<1>) + firrtl.matchingconnect %bar_a, %a : !firrtl.uint<1> + } + firrtl.formal @Test, @PrivateTop1 {} + "some_unknown_dialect.op"() { magic = @PrivateTop2 } : () -> () +}