Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Ensure InlineCasts does not inline complex Expressions (bp #2130) #2142

Open
wants to merge 1 commit into
base: 1.2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/main/scala/firrtl/transforms/InlineCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ object InlineCastsTransform {
* @param expr the Expression being transformed
* @return Returns expr with [[WRef]]s replaced by values found in replace
*/
<<<<<<< HEAD
def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
// Anything that may generate a part-select should not be inlined!
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
Expand All @@ -39,6 +40,32 @@ object InlineCastsTransform {
.getOrElse(e)
case other => other // Not a candidate
}
=======
def onExpr(replace: NodeMap)(expr: Expression): Expression = {
// Keep track if we've seen any non-cast expressions while recursing
def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match {
// Skip pads to avoid inlining literals into pads which results in invalid Verilog
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
case e =>
e.map(rec(hasNonCastParent || !isCast(e))) match {
case e @ WRef(name, _, _, _) =>
replace
.get(name)
.filter(isSimpleCast(castSeen = false))
.getOrElse(e)
case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
replace
.get(name)
// Only inline the Expression if there is no non-cast parent in the expression tree OR
// if the subtree contains only casts and references.
.filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x))
.map(value => e.copy(args = Seq(value)))
.getOrElse(e)
case other => other // Not a candidate
}
}
rec(false)(expr)
>>>>>>> b274b319... Ensure InlineCasts does not inline complex Expressions (#2130)
}

/** Inline casts in a Statement
Expand Down
8 changes: 6 additions & 2 deletions src/test/scala/firrtlTests/FirrtlSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,14 @@ trait FirrtlRunners extends BackendCompilationUtilities {

/** Compiles input Firrtl to Verilog */
def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = {
compileToVerilogCircuitState(input, annotations).getEmittedCircuit.value
}

/** Compiles input Firrtl to Verilog */
def compileToVerilogCircuitState(input: String, annotations: AnnotationSeq = Seq.empty): CircuitState = {
val circuit = Parser.parse(input.split("\n").toIterator)
val compiler = new VerilogCompiler
val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
res.getEmittedCircuit.value
compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms)
}
/** Compile a Firrtl file
*
Expand Down
72 changes: 62 additions & 10 deletions src/test/scala/firrtlTests/InlineCastsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,24 @@
package firrtlTests

import firrtl.transforms.InlineCastsTransform
<<<<<<< HEAD
import firrtlTests.FirrtlFlatSpec
=======
import firrtl.testutils.FirrtlFlatSpec
import firrtl.testutils.FirrtlCheckers._
>>>>>>> b274b319... Ensure InlineCasts does not inline complex Expressions (#2130)

/*
* Note: InlineCasts is still part of mverilog, so this test must both:
* - Test that the InlineCasts fix is effective given the current mverilog
* - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
*
* This is why the test passes InlineCasts as a custom transform: to future-proof it so that
* it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
* emitted Verilog is legal, but it will automatically become a more meaningful test when
* InlineCasts is not run in mverilog.
*/
class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
/*
* Note: InlineCasts is still part of mverilog, so this test must both:
* - Test that the InlineCasts fix is effective given the current mverilog
* - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog
*
* This is why the test passes InlineCasts as a custom transform: to future-proof it so that
* it can do real LEC against no-InlineCasts. It currently is just a sanity check that the
* emitted Verilog is legal, but it will automatically become a more meaningful test when
* InlineCasts is not run in mverilog.
*/
"InlineCastsTransform" should "not produce broken Verilog" in {
val input =
s"""circuit literalsel_fir:
Expand All @@ -26,4 +31,51 @@ class InlineCastsEquivalenceSpec extends FirrtlFlatSpec {
|""".stripMargin
firrtlEquivalenceTest(input, Seq(new InlineCastsTransform))
}

it should "not inline complex expressions into other complex expressions" in {
val input =
"""circuit NeverInlineComplexIntoComplex :
| module NeverInlineComplexIntoComplex :
| input a : SInt<3>
| input b : UInt<2>
| input c : UInt<2>
| input sel : UInt<1>
| output out : SInt<3>
| node diff = sub(b, c)
| out <= mux(sel, a, asSInt(diff))
|""".stripMargin
val expected =
"""module NeverInlineComplexIntoComplexRef(
| input [2:0] a,
| input [1:0] b,
| input [1:0] c,
| input sel,
| output [2:0] out
|);
| wire [2:0] diff = b - c;
| assign out = sel ? $signed(a) : $signed(diff);
|endmodule
|""".stripMargin
firrtlEquivalenceWithVerilog(input, expected)
}

it should "inline casts on both sides of a more complex expression" in {
val input =
"""circuit test :
| module test :
| input clock : Clock
| input in : UInt<8>
| output out : UInt<8>
|
| node _T_1 = asUInt(clock)
| node _T_2 = not(_T_1)
| node clock_n = asClock(_T_2)
| reg r : UInt<8>, clock_n
| r <= in
| out <= r
|""".stripMargin
val verilog = compileToVerilogCircuitState(input)
verilog should containLine("always @(posedge clock_n) begin")

}
}