diff --git a/src/main/scala/firrtl/transforms/MustDedup.scala b/src/main/scala/firrtl/transforms/MustDedup.scala index 3e7629cda7..cce6764325 100644 --- a/src/main/scala/firrtl/transforms/MustDedup.scala +++ b/src/main/scala/firrtl/transforms/MustDedup.scala @@ -14,18 +14,21 @@ import firrtl.graph.DiGraph import java.io.{File, FileWriter} /** Marks modules as "must deduplicate" */ -case class MustDeduplicateAnnotation(modules: Seq[IsModule]) extends MultiTargetAnnotation { - def targets: Seq[Seq[IsModule]] = modules.map(Seq(_)) - - def duplicate(n: Seq[Seq[Target]]): MustDeduplicateAnnotation = { - val newModules = n.map { - case Seq(mod: IsModule) => mod - case _ => - val msg = "Something went wrong! This anno should only rename to single IsModules! " + - s"Got: $modules -> $n" - throw new Exception(msg) +case class MustDeduplicateAnnotation(modules: Seq[IsModule]) extends Annotation { + + def update(renames: RenameMap): Seq[MustDeduplicateAnnotation] = { + val newModules: Seq[IsModule] = modules.flatMap { m => + renames.get(m) match { + case None => Seq(m) + case Some(Seq()) => Seq() + case Some(Seq(one: IsModule)) => Seq(one) + case Some(many) => + val msg = "Something went wrong! This anno's targets should only rename to IsModules! " + + s"Got: ${m.serialize} -> ${many.map(_.serialize).mkString(", ")}" + throw new Exception(msg) + } } - MustDeduplicateAnnotation(newModules) + if (newModules.isEmpty) Seq() else Seq(this.copy(newModules)) } } diff --git a/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala index 2f633e0e52..5d25321eb7 100644 --- a/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala +++ b/src/test/scala/firrtlTests/transforms/MustDedupSpec.scala @@ -264,4 +264,38 @@ class MustDedupSpec extends AnyFeatureSpec with FirrtlMatchers with GivenWhenThe (new firrtl.stage.FirrtlPhase).transform(annos) } } + + Feature("When you have unused modules that should dedup, and they do") { + val text = """ + |circuit A : + | module B : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module B_1 : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module A : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst b of B + | inst b_1 of B_1 + | b.io.in <= io.in + | b_1.io.in <= io.in + | io.out <= and(io.in, UInt(123)) + """.stripMargin + val top = CircuitTarget("A") + val bdedup = MustDeduplicateAnnotation(Seq(top.module("B"), top.module("B_1"))) + + Scenario("MustDeduplicateAnnotation should be deleted gracefully") { + val testDir = createTestDirectory("must_dedup") + val annos = Seq( + TargetDirAnnotation(testDir.toString), + FirrtlSourceAnnotation(text), + RunFirrtlTransformAnnotation(new MustDeduplicateTransform), + bdedup + ) + + val resAnnos = (new firrtl.stage.FirrtlPhase).transform(annos) + resAnnos.collectFirst { case a: MustDeduplicateTransform => a } should be(None) + } + } }