From 3b16caf07d777e297f30c8220d678ef30dd8b752 Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Thu, 18 Mar 2021 23:31:51 -0700 Subject: [PATCH] Ensure InlineCasts does not inline complex Expressions (#2130) Previously, InlineCasts could inline complex (ie. non-cast) Expressions into other complex Expressions. Now it will only inline so long as there no more than 1 complex Expression in the current nested Expression. Co-authored-by: Albert Magyar (cherry picked from commit b274b319d4a4014c154f06bfc174beba461d6fce) # Conflicts: # src/main/scala/firrtl/transforms/InlineCasts.scala --- .../scala/firrtl/transforms/InlineCasts.scala | 27 ++++++++ .../scala/firrtl/testutils/FirrtlSpec.scala | 8 ++- .../scala/firrtlTests/InlineCastsSpec.scala | 68 ++++++++++++++++--- 3 files changed, 91 insertions(+), 12 deletions(-) diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 3dac938e6d..59a4ce1ee2 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -28,6 +28,7 @@ object InlineCastsTransform { * @param expr the Expression being transformed * @return Returns expr with [[WRef]]s replaced by values found in replace */ +<<<<<<< HEAD def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match { // Anything that may generate a part-select should not be inlined! case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr @@ -42,6 +43,32 @@ object InlineCastsTransform { .getOrElse(e) case other => other // Not a candidate } +======= + def onExpr(replace: NodeMap)(expr: Expression): Expression = { + // Keep track if we've seen any non-cast expressions while recursing + def rec(hasNonCastParent: Boolean)(expr: Expression): Expression = expr match { + // Skip pads to avoid inlining literals into pads which results in invalid Verilog + case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr + case e => + e.map(rec(hasNonCastParent || !isCast(e))) match { + case e @ WRef(name, _, _, _) => + replace + .get(name) + .filter(isSimpleCast(castSeen = false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => + replace + .get(name) + // Only inline the Expression if there is no non-cast parent in the expression tree OR + // if the subtree contains only casts and references. + .filter(x => !hasNonCastParent || isSimpleCast(castSeen = true)(x)) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } + } + rec(false)(expr) +>>>>>>> b274b319... Ensure InlineCasts does not inline complex Expressions (#2130) } /** Inline casts in a Statement diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala index 8f0241fe99..aaa857f196 100644 --- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala +++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala @@ -117,10 +117,14 @@ trait FirrtlRunners extends BackendCompilationUtilities { /** Compiles input Firrtl to Verilog */ def compileToVerilog(input: String, annotations: AnnotationSeq = Seq.empty): String = { + compileToVerilogCircuitState(input, annotations).getEmittedCircuit.value + } + + /** Compiles input Firrtl to Verilog */ + def compileToVerilogCircuitState(input: String, annotations: AnnotationSeq = Seq.empty): CircuitState = { val circuit = Parser.parse(input.split("\n").toIterator) val compiler = new VerilogCompiler - val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) - res.getEmittedCircuit.value + compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) } /** Compile a Firrtl file * diff --git a/src/test/scala/firrtlTests/InlineCastsSpec.scala b/src/test/scala/firrtlTests/InlineCastsSpec.scala index 1bf36b332a..29366de9b5 100644 --- a/src/test/scala/firrtlTests/InlineCastsSpec.scala +++ b/src/test/scala/firrtlTests/InlineCastsSpec.scala @@ -4,18 +4,19 @@ package firrtlTests import firrtl.transforms.InlineCastsTransform import firrtl.testutils.FirrtlFlatSpec +import firrtl.testutils.FirrtlCheckers._ -/* - * Note: InlineCasts is still part of mverilog, so this test must both: - * - Test that the InlineCasts fix is effective given the current mverilog - * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog - * - * This is why the test passes InlineCasts as a custom transform: to future-proof it so that - * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the - * emitted Verilog is legal, but it will automatically become a more meaningful test when - * InlineCasts is not run in mverilog. - */ class InlineCastsEquivalenceSpec extends FirrtlFlatSpec { + /* + * Note: InlineCasts is still part of mverilog, so this test must both: + * - Test that the InlineCasts fix is effective given the current mverilog + * - Provide a test that will be robust if and when InlineCasts is no longer run in mverilog + * + * This is why the test passes InlineCasts as a custom transform: to future-proof it so that + * it can do real LEC against no-InlineCasts. It currently is just a sanity check that the + * emitted Verilog is legal, but it will automatically become a more meaningful test when + * InlineCasts is not run in mverilog. + */ "InlineCastsTransform" should "not produce broken Verilog" in { val input = s"""circuit literalsel_fir: @@ -26,4 +27,51 @@ class InlineCastsEquivalenceSpec extends FirrtlFlatSpec { |""".stripMargin firrtlEquivalenceTest(input, Seq(new InlineCastsTransform)) } + + it should "not inline complex expressions into other complex expressions" in { + val input = + """circuit NeverInlineComplexIntoComplex : + | module NeverInlineComplexIntoComplex : + | input a : SInt<3> + | input b : UInt<2> + | input c : UInt<2> + | input sel : UInt<1> + | output out : SInt<3> + | node diff = sub(b, c) + | out <= mux(sel, a, asSInt(diff)) + |""".stripMargin + val expected = + """module NeverInlineComplexIntoComplexRef( + | input [2:0] a, + | input [1:0] b, + | input [1:0] c, + | input sel, + | output [2:0] out + |); + | wire [2:0] diff = b - c; + | assign out = sel ? $signed(a) : $signed(diff); + |endmodule + |""".stripMargin + firrtlEquivalenceWithVerilog(input, expected) + } + + it should "inline casts on both sides of a more complex expression" in { + val input = + """circuit test : + | module test : + | input clock : Clock + | input in : UInt<8> + | output out : UInt<8> + | + | node _T_1 = asUInt(clock) + | node _T_2 = not(_T_1) + | node clock_n = asClock(_T_2) + | reg r : UInt<8>, clock_n + | r <= in + | out <= r + |""".stripMargin + val verilog = compileToVerilogCircuitState(input) + verilog should containLine("always @(posedge clock_n) begin") + + } }