Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Fix RemoveAccesses, delete CSESubAccesses
Browse files Browse the repository at this point in the history
CSESubAccesses was intended to be a simple workaround for a quadratic
performance bug in RemoveAccesses but ended up having tricky corner
cases and was hard to get right. The solution to the RemoveAccesses
bug--quadratic expansion of dynamic indexes of vecs of aggreate
type--turned out to be quite simple and makes CSESubAccesses much less
useful and not worth fixing.
  • Loading branch information
jackkoenig committed Mar 26, 2021
1 parent 67ce97a commit 549992e
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 428 deletions.
63 changes: 35 additions & 28 deletions src/main/scala/firrtl/passes/RemoveAccesses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import firrtl.Mappers._
import firrtl.Utils._
import firrtl.WrappedExpression._
import firrtl.options.Dependency
import firrtl.transforms.CSESubAccesses

import scala.collection.mutable

Expand All @@ -22,8 +21,7 @@ object RemoveAccesses extends Pass {
Dependency(PullMuxes),
Dependency(ZeroLengthVecs),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects),
Dependency[CSESubAccesses]
Dependency(ExpandConnects)
) ++ firrtl.stage.Forms.Deduped

override def invalidates(a: Transform): Boolean = a match {
Expand Down Expand Up @@ -122,26 +120,26 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given source expression
*/
val stmts = mutable.ArrayBuffer[Statement]()
def removeSource(e: Expression): Expression = e match {
case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
val rs = getLocations(e)
rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
stmts += IsInvalid(get_info(s), getTemp(i))
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
case _ => e
// Only called on RefLikes that definitely have a SubAccess
// Must accept Expression because that's the output type of fixIndices
def removeSource(e: Expression): Expression = {
val rs = getLocations(e)
rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
stmts += IsInvalid(get_info(s), getTemp(i))
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
}

/** Replaces a subaccess in a given sink expression
Expand All @@ -162,14 +160,23 @@ object RemoveAccesses extends Pass {
case _ => loc
}

/** Recurse until find SubAccess and call fixSource on its index
* @note this only accepts [[RefLikeExpression]]s but we can't enforce it because map
* requires Expression => Expression
*/
def fixIndices(e: Expression): Expression = e match {
case e: SubAccess => e.copy(index = fixSource(e.index))
case other => other.map(fixIndices)
}

/** Recursively walks a source expression and fixes all subaccesses
* If we see a sub-access, replace it.
* Otherwise, map to children.
*
* If we see a RefLikeExpression that contains a SubAccess, we recursively remove
* subaccesses from the indices of any SubAccesses, then process modified RefLikeExpression
*/
def fixSource(e: Expression): Expression = e match {
case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
//case w: WSubIndex => removeSource(w)
//case w: WSubField => removeSource(w)
case ref: RefLikeExpression =>
if (hasAccess(ref)) removeSource(fixIndices(ref)) else ref
case x => x.map(fixSource)
}

Expand Down
1 change: 0 additions & 1 deletion src/main/scala/firrtl/stage/Forms.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ object Forms {
val MidForm: Seq[TransformDependency] = HighForm ++
Seq(
Dependency(passes.PullMuxes),
Dependency[firrtl.transforms.CSESubAccesses],
Dependency(passes.ReplaceAccesses),
Dependency(passes.ExpandConnects),
Dependency(passes.RemoveAccesses),
Expand Down
168 changes: 0 additions & 168 deletions src/main/scala/firrtl/transforms/CSESubAccesses.scala

This file was deleted.

3 changes: 1 addition & 2 deletions src/test/scala/firrtlTests/LowerTypesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| out <= in0[in1[in2[0]]][in1[in2[1]]]
|""".stripMargin
val expected = Seq(
"node _in0_in1_in1 = _in0_in1_in1_in2_1",
"out <= _in0_in1_in1"
"out <= _in0_in1_in1_in2_1"
)

executeTest(input, expected)
Expand Down
1 change: 0 additions & 1 deletion src/test/scala/firrtlTests/LoweringCompilersSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.MidForm, Forms.Deduped)
val patches = Seq(
Add(2, Seq(Dependency[firrtl.transforms.CSESubAccesses])),
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
// Uniquify is now part of [[firrtl.passes.LowerTypes]]
Expand Down
10 changes: 5 additions & 5 deletions src/test/scala/firrtlTests/UnitTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,14 @@ class UnitTests extends FirrtlFlatSpec {
//TODO(azidar): I realize this is brittle, but unfortunately there
// isn't a better way to test this pass
val check = Seq(
"""wire _table_1 : { a : UInt<8>}""",
"""_table_1.a is invalid""",
"""wire _table_1_a : UInt<8>""",
"""_table_1_a is invalid""",
"""when UInt<1>("h1") :""",
"""_table_1.a <= table[1].a""",
"""_table_1_a <= table[1].a""",
"""wire _otherTable_table_1_a_a : UInt<8>""",
"""when eq(UInt<1>("h0"), _table_1.a) :""",
"""when eq(UInt<1>("h0"), _table_1_a) :""",
"""otherTable[0].a <= _otherTable_table_1_a_a""",
"""when eq(UInt<1>("h1"), _table_1.a) :""",
"""when eq(UInt<1>("h1"), _table_1_a) :""",
"""otherTable[1].a <= _otherTable_table_1_a_a""",
"""_otherTable_table_1_a_a <= UInt<1>("h0")"""
)
Expand Down
Loading

0 comments on commit 549992e

Please sign in to comment.