Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Commit

Permalink
Ensure InlineCasts does not inline complex Expressions (#2130)
Browse files Browse the repository at this point in the history
Previously, InlineCasts could inline complex (ie. non-cast) Expressions
into other complex Expressions. Now it will only inline so long as there
no more than 1 complex Expression in the current nested Expression.

Co-authored-by: Albert Magyar <[email protected]>
  • Loading branch information
jackkoenig and albert-magyar authored Mar 19, 2021
1 parent 94d1bee commit b274b31
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 29 deletions.
41 changes: 24 additions & 17 deletions src/main/scala/firrtl/transforms/InlineCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,30 @@ object InlineCastsTransform {
* @param expr the Expression being transformed
* @return Returns expr with [[WRef]]s replaced by values found in replace
*/
def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
// Anything that may generate a part-select should not be inlined!
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
case e =>
e.map(onExpr(replace)) match {
case e @ WRef(name, _, _, _) =>
replace
.get(name)
.filter(isSimpleCast(castSeen = false))
.getOrElse(e)
case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
replace
.get(name)
.map(value => e.copy(args = Seq(value)))
.getOrElse(e)
case other => other // Not a candidate
}
def onExpr(replace: NodeMap)(expr: Expression): Expression = {
// Keep track if we've seen any non-cast expressions while recursing
def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match {
// Skip pads to avoid inlining literals into pads which results in invalid Verilog
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
case e =>
e.map(rec(hasNonCastParent || !isCast(e))) match {
case e @ WRef(name, _, _, _) =>
replace
.get(name)
.filter(isSimpleCast(castSeen = false))
.getOrElse(e)
case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
replace
.get(name)
// Only inline the Expression if there is no non-cast parent in the expression tree OR
// if the subtree contains only casts and references.
.filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x))
.map(value => e.copy(args = Seq(value)))
.getOrElse(e)
case other => other // Not a candidate
}
}
rec(false)(expr)
}

/** Inline casts in a Statement
Expand Down
8 changes: 6 additions & 2 deletions src/test/scala/firrtl/testutils/FirrtlSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,14 @@ trait FirrtlRunners extends BackendCompilationUtilities {

/** Compiles input Firrtl to Verilog */
def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = {
compileToVerilogCircuitState(input, annotations).getEmittedCircuit.value
}

/** Compiles input Firrtl to Verilog */
def compileToVerilogCircuitState(input: String, annotations: AnnotationSeq = Seq.empty): CircuitState = {
val circuit = Parser.parse(input.split("\n").toIterator)
val compiler = new VerilogCompiler
val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
res.getEmittedCircuit.value
compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
}

/** Compile a Firrtl file
Expand Down
68 changes: 58 additions & 10 deletions src/test/scala/firrtlTests/InlineCastsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ package firrtlTests

import firrtl.transforms.InlineCastsTransform
import firrtl.testutils.FirrtlFlatSpec
import firrtl.testutils.FirrtlCheckers._

/*
* Note: InlineCasts is still part of mverilog, so this test must both:
* - Test that the InlineCasts fix is effective given the current mverilog
* - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
*
* This is why the test passes InlineCasts as a custom transform: to future-proof it so that
* it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
* emitted Verilog is legal, but it will automatically become a more meaningful test when
* InlineCasts is not run in mverilog.
*/
class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
/*
* Note: InlineCasts is still part of mverilog, so this test must both:
* - Test that the InlineCasts fix is effective given the current mverilog
* - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
*
* This is why the test passes InlineCasts as a custom transform: to future-proof it so that
* it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
* emitted Verilog is legal, but it will automatically become a more meaningful test when
* InlineCasts is not run in mverilog.
*/
"InlineCastsTransform" should "not produce broken Verilog" in {
val input =
s"""circuit literalsel_fir:
Expand All @@ -26,4 +27,51 @@ class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
|""".stripMargin
firrtlEquivalenceTest(input, Seq(new InlineCastsTransform))
}

it should "not inline complex expressions into other complex expressions" in {
val input =
"""circuit NeverInlineComplexIntoComplex :
| module NeverInlineComplexIntoComplex :
| input a : SInt<3>
| input b : UInt<2>
| input c : UInt<2>
| input sel : UInt<1>
| output out : SInt<3>
| node diff = sub(b, c)
| out <= mux(sel, a, asSInt(diff))
|""".stripMargin
val expected =
"""module NeverInlineComplexIntoComplexRef(
| input [2:0] a,
| input [1:0] b,
| input [1:0] c,
| input sel,
| output [2:0] out
|);
| wire [2:0] diff = b - c;
| assign out = sel ? $signed(a) : $signed(diff);
|endmodule
|""".stripMargin
firrtlEquivalenceWithVerilog(input, expected)
}

it should "inline casts on both sides of a more complex expression" in {
val input =
"""circuit test :
| module test :
| input clock : Clock
| input in : UInt<8>
| output out : UInt<8>
|
| node _T_1 = asUInt(clock)
| node _T_2 = not(_T_1)
| node clock_n = asClock(_T_2)
| reg r : UInt<8>, clock_n
| r <= in
| out <= r
|""".stripMargin
val verilog = compileToVerilogCircuitState(input)
verilog should containLine("always @(posedge clock_n) begin")

}
}

0 comments on commit b274b31

Please sign in to comment.