diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala index df98f72f2a..82c00ca5c7 100644 --- a/src/main/scala/firrtl/RenameMap.scala +++ b/src/main/scala/firrtl/RenameMap.scala @@ -4,7 +4,9 @@ package firrtl import annotations._ import firrtl.RenameMap.IllegalRenameException +import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.TargetToken.{Field, Index, Instance, OfModule} +import TargetUtils.{instKeyPathToTarget, unfoldInstanceTargets} import scala.collection.mutable @@ -21,6 +23,58 @@ object RenameMap { rm } + /** RenameMap factory for simple renaming of instances + * + * @param graph [[InstanceKeyGraph]] from *before* renaming + * @param renames Mapping of old instance name to new within Modules + */ + private[firrtl] def fromInstanceRenames( + graph: InstanceKeyGraph, + renames: Map[OfModule, Map[Instance, Instance]] + ): RenameMap = { + def renameAll(it: InstanceTarget): InstanceTarget = { + var prevMod = OfModule(it.module) + val pathx = it.path.map { + case (inst, of) => + val instx = renames + .get(prevMod) + .flatMap(_.get(inst)) + .getOrElse(inst) + prevMod = of + instx -> of + } + // Sanity check, the last one should always be a rename (or we wouldn't be calling this method) + val instx = renames(prevMod)(Instance(it.instance)) + it.copy(path = pathx, instance = instx.value) + } + val underlying = new mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] + val instOf: String => Map[String, String] = + graph.getChildInstances.toMap + // Laziness here is desirable, we only access each key once, some we don't access + .mapValues(_.map(k => k.name -> k.module).toMap) + for ((OfModule(module), instMapping) <- renames) { + val modLookup = instOf(module) + val parentInstances = graph.findInstancesInHierarchy(module) + for { + // For every instance of the Module where the renamed instance resides + parent <- parentInstances + parentTarget = instKeyPathToTarget(parent) + // Create the absolute InstanceTarget to be renamed + (Instance(from), _) <- instMapping // The to is given by renameAll + instMod = modLookup(from) + fromTarget = parentTarget.instOf(from, instMod) + // Ensure all renames apply to the InstanceTarget + toTarget = renameAll(fromTarget) + // RenameMap only allows 1 hit when looking up InstanceTargets, so rename all possible + // paths to this instance + (fromx, tox) <- unfoldInstanceTargets(fromTarget).zip(unfoldInstanceTargets(toTarget)) + } yield { + underlying(fromx) = List(tox) + } + } + new RenameMap(underlying) + } + /** Initialize a new RenameMap */ def apply(): RenameMap = new RenameMap diff --git a/src/main/scala/firrtl/annotations/TargetUtils.scala b/src/main/scala/firrtl/annotations/TargetUtils.scala new file mode 100644 index 0000000000..164c430b55 --- /dev/null +++ b/src/main/scala/firrtl/annotations/TargetUtils.scala @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.annotations + +import firrtl._ +import firrtl.analyses.InstanceKeyGraph +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations.TargetToken._ + +object TargetUtils { + + /** Turns an instance path into a corresponding [[IsModule]] + * + * @note First InstanceKey is treated as the [[CircuitTarget]] + * @param path Instance path + * @param start Module in instance path to be starting [[ModuleTarget]] + * @return [[IsModule]] corresponding to Instance path + */ + def instKeyPathToTarget(path: Seq[InstanceKey], start: Option[String] = None): IsModule = { + val head = path.head + val startx = start.getOrElse(head.module) + val top: IsModule = CircuitTarget(head.module).module(startx) // ~Top|Start + val pathx = path.dropWhile(_.module != startx) + if (pathx.isEmpty) top + else pathx.tail.foldLeft(top) { case (acc, key) => acc.instOf(key.name, key.module) } + } + + /** Calculates all [[InstanceTarget]]s that refer to the given [[IsModule]] + * + * {{{ + * ~Top|Top/a:A/b:B/c:C unfolds to: + * * ~Top|Top/a:A/b:B/c:C + * * ~Top|A/b:B/c:C + * * ~Top|B/c:C + * }}} + * @note [[ModuleTarget]] arguments return an empty Iterable + */ + def unfoldInstanceTargets(ismod: IsModule): Iterable[InstanceTarget] = { + // concretely use List which is fast in practice + def rec(im: IsModule): List[InstanceTarget] = im match { + case inst: InstanceTarget => inst :: rec(inst.stripHierarchy(1)) + case _ => Nil + } + rec(ismod) + } +} diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 592caf5d17..0bd44a8ca3 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -8,8 +8,10 @@ import firrtl.annotations.{ MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, - ReferenceTarget + ReferenceTarget, + TargetToken } +import TargetToken.{Instance, OfModule} import firrtl.{ CircuitForm, CircuitState, @@ -73,16 +75,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) val c = CircuitTarget(state.circuit.main) - val resultAndRenames = state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()))) + val refRenameMap = RenameMap() + val resultAndRenames = + state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq()), refRenameMap)) val result = state.circuit.copy(modules = resultAndRenames.map(_._1)) // memory init annotations could have been modified val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3) - // chain module renames in topological order - val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap - val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator - val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b)) + // Build RenameMap for instances + val moduleRenames = resultAndRenames.map { case (m, r, _) => OfModule(m.name) -> r }.toMap + val instRenameMap = RenameMap.fromInstanceRenames(InstanceKeyGraph(state.circuit), moduleRenames) + val renames = instRenameMap.andThen(refRenameMap) state.copy(circuit = result, renames = Some(renames), annotations = newAnnos) } @@ -90,9 +94,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { private def onModule( c: CircuitTarget, m: DefModule, - memoryInit: Seq[MemoryInitAnnotation] - ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { - val renameMap = RenameMap() + memoryInit: Seq[MemoryInitAnnotation], + renameMap: RenameMap + ): (DefModule, Map[Instance, Instance], Seq[MemoryInitAnnotation]) = { val ref = c.module(m.name) // first we lower the ports in order to ensure that their names are independent of the module body @@ -105,7 +109,9 @@ object LowerTypes extends Transform with DependencyAPIMigration { implicit val memInit: Seq[MemoryInitAnnotation] = memoryInit val newMod = mLoweredPorts.mapStmt(onStatement) - (newMod, renameMap, memInit) + val instRenames = symbols.getInstanceRenames.toMap + + (newMod, instRenames, memInit) } // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names. @@ -221,6 +227,7 @@ private class LoweringTable( private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames // Serialized old access string to new ground type reference. private val nameToExprs = mutable.HashMap[String, Seq[RefLikeExpression]]() ++ portNameToExprs + private val instRenames = mutable.ListBuffer[(Instance, Instance)]() def lower(mem: DefMemory): Seq[DefMemory] = { val (mems, refs) = DestructTypes.destructMemory(m, mem, namespace, renameMap, portNames) @@ -228,7 +235,7 @@ private class LoweringTable( mems } def lower(inst: DefInstance): DefInstance = { - val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, renameMap, portNames) + val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, instRenames, portNames) nameToExprs ++= refs.map { case (name, r) => name -> List(r) } newInst } @@ -245,6 +252,7 @@ private class LoweringTable( } def getReferences(expr: RefLikeExpression): Seq[RefLikeExpression] = nameToExprs(serialize(expr)) + def getInstanceRenames: List[(Instance, Instance)] = instRenames.toList // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left. private def serialize(expr: RefLikeExpression): String = expr match { @@ -296,11 +304,11 @@ private object DestructTypes { * instead of a flat Reference when turning them into access expressions. */ def destructInstance( - m: ModuleTarget, - instance: DefInstance, - namespace: Namespace, - renameMap: RenameMap, - reserved: Set[String] + m: ModuleTarget, + instance: DefInstance, + namespace: Namespace, + instRenames: mutable.ListBuffer[(Instance, Instance)], + reserved: Set[String] ): (DefInstance, Seq[(String, SubField)]) = { val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved) val newName = rename.map(_.name).getOrElse(instance.name) @@ -314,7 +322,7 @@ private object DestructTypes { // rename all references to the instance if necessary if (newName != instance.name) { - renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module)) + instRenames += Instance(instance.name) -> Instance(newName) } // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered. diff --git a/src/test/scala/firrtl/RenameMapPrivateSpec.scala b/src/test/scala/firrtl/RenameMapPrivateSpec.scala new file mode 100644 index 0000000000..d735e6c876 --- /dev/null +++ b/src/test/scala/firrtl/RenameMapPrivateSpec.scala @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl + +import firrtl.annotations.Target +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph +import firrtl.testutils.FirrtlFlatSpec + +class RenameMapPrivateSpec extends FirrtlFlatSpec { + "RenameMap.fromInstanceRenames" should "handle instance renames" in { + def tar(str: String): Target = Target.deserialize(str) + val circuit = parse( + """circuit Top : + | module Bar : + | skip + | module Foo : + | inst bar of Bar + | module Top : + | inst foo1 of Foo + | inst foo2 of Foo + | inst bar of Bar + |""".stripMargin + ) + val graph = InstanceKeyGraph(circuit) + val renames = Map( + OfModule("Foo") -> Map(Instance("bar") -> Instance("bbb")), + OfModule("Top") -> Map(Instance("foo1") -> Instance("ffff")) + ) + val rm = RenameMap.fromInstanceRenames(graph, renames) + rm.get(tar("~Top|Top/foo1:Foo")) should be(Some(Seq(tar("~Top|Top/ffff:Foo")))) + rm.get(tar("~Top|Top/foo2:Foo")) should be(None) + // Check of nesting + rm.get(tar("~Top|Top/foo1:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/ffff:Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/foo2:Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Top/foo2:Foo/bbb:Bar")))) + rm.get(tar("~Top|Foo/bar:Bar")) should be(Some(Seq(tar("~Top|Foo/bbb:Bar")))) + rm.get(tar("~Top|Top/bar:Bar")) should be(None) + } +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 70fa51fdc1..7ca9854496 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -2,10 +2,13 @@ package firrtl.passes import firrtl.annotations.{CircuitTarget, IsMember} +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.{CircuitState, RenameMap, Utils} import firrtl.options.Dependency import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency +import firrtl.testutils.FirrtlMatchers import org.scalatest.flatspec.AnyFlatSpec /** Unit test style tests for [[LowerTypes]]. @@ -228,22 +231,35 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { } /** Instances are a special case since they do not get completely destructed but instead become a 1-deep bundle. */ -class LowerTypesOfInstancesSpec extends AnyFlatSpec { +class LowerTypesOfInstancesSpec extends AnyFlatSpec with FirrtlMatchers { import LowerTypesSpecUtils._ private case class Lower(inst: firrtl.ir.DefInstance, fields: Seq[String], renameMap: RenameMap) private val m = CircuitTarget("m").module("m") + private val igraph = InstanceKeyGraph( + parse( + """circuit m: + | module c: + | skip + | module m: + | inst i of c + |""".stripMargin + ) + ) def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") private def lower( - n: String, - tpe: String, - module: String, - namespace: Set[String], - renames: RenameMap = RenameMap() + n: String, + tpe: String, + module: String, + namespace: Set[String], + otherRenames: RenameMap = RenameMap() ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) + val instRenames = scala.collection.mutable.ListBuffer[(Instance, Instance)]() + val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, instRenames, Set()) + val instMap = Map(OfModule("m") -> instRenames.toMap) + val renames = RenameMap.fromInstanceRenames(igraph, instMap).andThen(otherRenames) Lower(newInstance, resultToFieldSeq(res), renames) } private def get(l: Lower, m: IsMember): Set[IsMember] = l.renameMap.get(m).get.toSet @@ -305,7 +321,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(get(l, i) == Set(i_)) // the ports renaming is also noted - val r = portRenames.andThen(otherRenames) + val r = portRenames.andThen(l.renameMap) assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 29466c72ed..bebeb0bf8e 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -5,6 +5,8 @@ package firrtlTests import firrtl.RenameMap import firrtl.RenameMap.IllegalRenameException import firrtl.annotations._ +import firrtl.annotations.TargetToken.{Instance, OfModule} +import firrtl.analyses.InstanceKeyGraph import firrtl.testutils._ class RenameMapSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala new file mode 100644 index 0000000000..38266efeb1 --- /dev/null +++ b/src/test/scala/firrtlTests/annotationTests/TargetUtilsSpec.scala @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtlTests.annotationTests + +import firrtl.analyses.InstanceKeyGraph.InstanceKey +import firrtl.annotations._ +import firrtl.annotations.TargetToken._ +import firrtl.annotations.TargetUtils._ +import firrtl.testutils.FirrtlFlatSpec + +class TargetUtilsSpec extends FirrtlFlatSpec { + + behavior.of("instKeyPathToTarget") + + it should "create a ModuleTarget for the top module" in { + val input = InstanceKey("Top", "Top") :: Nil + val expected = ModuleTarget("Top", "Top") + instKeyPathToTarget(input) should be(expected) + } + + it should "create absolute InstanceTargets" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + Nil + val expected = InstanceTarget("Top", "Top", Seq((Instance("foo"), OfModule("Foo"))), "bar", "Bar") + instKeyPathToTarget(input) should be(expected) + } + + it should "support starting somewhere down the path" in { + val input = InstanceKey("Top", "Top") :: + InstanceKey("foo", "Foo") :: + InstanceKey("bar", "Bar") :: + InstanceKey("fizz", "Fizz") :: + Nil + val expected = InstanceTarget("Top", "Bar", Seq(), "fizz", "Fizz") + instKeyPathToTarget(input, Some("Bar")) should be(expected) + } + + behavior.of("unfoldInstanceTargets") + + it should "return nothing for ModuleTargets" in { + val input = ModuleTarget("Top", "Foo") + unfoldInstanceTargets(input) should be(Iterable()) + } + + it should "return all other InstanceTargets to the same instance" in { + val input = ModuleTarget("Top", "Top").instOf("foo", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") + val expected = + input :: + ModuleTarget("Top", "Foo").instOf("bar", "Bar").instOf("fizz", "Fizz") :: + ModuleTarget("Top", "Bar").instOf("fizz", "Fizz") :: + Nil + unfoldInstanceTargets(input) should be(expected) + } +}