diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index 5dafb2d6f5..dbe4c7eb2e 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -313,25 +313,6 @@ class Router(formatOps: FormatOps) { val lambdaExpire = if (lambdaArrow eq null) null else nextNonCommentSameLine(lambdaArrow) - def getSingleLinePolicy = { - val needBreak = close.right match { - case _: T.KwElse => close.rightOwner match { - case p: Term.If => p.thenp eq leftOwner - case _ => false - } - case _: T.KwCatch => close.rightOwner match { - case p: Term.TryClause => p.expr eq leftOwner - case _ => false - } - case _: T.KwFinally => close.rightOwner match { - case p: Term.TryClause => (p.expr eq leftOwner) || - p.catchClause.contains(leftOwner) - case _ => false - } - case _ => false - } - Policy ? needBreak && decideNewlinesOnlyAfterClose(close) - } def getClassicSingleLineDecisionOpt = if (noBreak()) Some(true) else None @@ -379,11 +360,28 @@ class Router(formatOps: FormatOps) { val singleLineSplitOpt = { if (slbParensExclude eq null) None else singleLineDecisionOpt }.map { sld => - val sldPolicy = getSingleLinePolicy + val ownerIfNeedBreakAfterClose = close.right match { + case _: T.KwElse => close.rightOwner match { + case p: Term.If if p.thenp eq leftOwner => Some(p) + case _ => None + } + case _: T.KwCatch => close.rightOwner match { + case p: Term.TryClause if p.expr eq leftOwner => Some(p) + case _ => None + } + case _: T.KwFinally => close.rightOwner match { + case p: Term.TryClause + if (p.expr eq leftOwner) || + p.catchClause.contains(leftOwner) => Some(p) + case _ => None + } + case _ => None + } val expire = leftOwner.parent match { case Some(p: Term.ForYield) - if !sld && sldPolicy.isEmpty && style.newlines.fold && - leftOwner.is[Term.EnumeratorsBlock] => getLast(p) + if !sld && ownerIfNeedBreakAfterClose.isEmpty && + style.newlines.fold && leftOwner.is[Term.EnumeratorsBlock] => + getLast(p) case _ if style.newlines.isBeforeOpenParenCallSite => close case _ => getSlbEndOnLeft(close) } @@ -402,6 +400,22 @@ class Router(formatOps: FormatOps) { else Policy.RelayOnSplit((s, _) => s.isNL)(slbParensPolicy)( Policy.onLeft(close, "BracesToParensFailed") { case _ => Nil }, ) + val sldPolicy = ownerIfNeedBreakAfterClose.map { p => + if (style.newlines.fold) { + val pend = getSlbEndOnLeft(getLast(p)) + def pendSlb(s: Split) = s + .withSingleLine(pend, noSyntaxNL = true, extend = true) + Policy.End <= close ==> + Policy.onRight(close, s"RB-ELSE[${pend.idx}]") { + case Decision(`close`, ss) => + if (ss.exists(_.isNL)) ss + .map(s => if (s.isNL) s else pendSlb(s)) + else ss.flatMap { s => + Seq(pendSlb(s), s.withMod(Newline).withPenalty(1)) + } + } + } else decideNewlinesOnlyAfterClose(close) + } Split(slbMod, 0).withSingleLine( expire, exclude = exclude, diff --git a/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala b/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala index c56adb2a9b..f5d2606bb3 100644 --- a/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala +++ b/scalafmt-tests-community/intellij/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala @@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String) class CommunityIntellijScala_2024_2_Suite extends CommunityIntellijScalaSuite("intellij-scala-2024.2") { - override protected def totalStatesVisited: Option[Int] = Some(57952816) + override protected def totalStatesVisited: Option[Int] = Some(57957808) override protected def builds = Seq(getBuild( "2024.2.28", @@ -51,7 +51,7 @@ class CommunityIntellijScala_2024_2_Suite class CommunityIntellijScala_2024_3_Suite extends CommunityIntellijScalaSuite("intellij-scala-2024.3") { - override protected def totalStatesVisited: Option[Int] = Some(58169519) + override protected def totalStatesVisited: Option[Int] = Some(58174471) override protected def builds = Seq(getBuild( "2024.3.4", diff --git a/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala b/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala index a9d13fa422..0c9cda424a 100644 --- a/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala +++ b/scalafmt-tests-community/scala2/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala @@ -9,7 +9,7 @@ abstract class CommunityScala2Suite(name: String) class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") { - override protected def totalStatesVisited: Option[Int] = Some(42556059) + override protected def totalStatesVisited: Option[Int] = Some(42560182) override protected def builds = Seq(getBuild("v2.12.20", dialects.Scala212, 1277)) @@ -18,7 +18,7 @@ class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") { class CommunityScala2_13Suite extends CommunityScala2Suite("scala-2.13") { - override protected def totalStatesVisited: Option[Int] = Some(53056214) + override protected def totalStatesVisited: Option[Int] = Some(53060139) override protected def builds = Seq(getBuild("v2.13.14", dialects.Scala213, 1287)) diff --git a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala index cb631676b3..585758e54f 100644 --- a/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala +++ b/scalafmt-tests-community/scala3/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala @@ -9,7 +9,7 @@ abstract class CommunityScala3Suite(name: String) class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") { - override protected def totalStatesVisited: Option[Int] = Some(39245403) + override protected def totalStatesVisited: Option[Int] = Some(39246875) override protected def builds = Seq(getBuild("3.2.2", dialects.Scala32, 791)) @@ -17,7 +17,7 @@ class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") { class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") { - override protected def totalStatesVisited: Option[Int] = Some(42431405) + override protected def totalStatesVisited: Option[Int] = Some(42432863) override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861)) diff --git a/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala b/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala index 2de84f6f00..ac6f72236b 100644 --- a/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala +++ b/scalafmt-tests-community/spark/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala @@ -9,7 +9,7 @@ abstract class CommunitySparkSuite(name: String) class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") { - override protected def totalStatesVisited: Option[Int] = Some(86350328) + override protected def totalStatesVisited: Option[Int] = Some(86399848) override protected def builds = Seq(getBuild("v3.4.1", dialects.Scala213, 2585)) @@ -18,7 +18,7 @@ class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") { class CommunitySpark3_5Suite extends CommunitySparkSuite("spark-3.5") { - override protected def totalStatesVisited: Option[Int] = Some(91352348) + override protected def totalStatesVisited: Option[Int] = Some(91404596) override protected def builds = Seq(getBuild("v3.5.3", dialects.Scala213, 2756)) diff --git a/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat b/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat index 058dc55a42..90997e952d 100644 --- a/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat @@ -16,8 +16,7 @@ println("bbb") } else c } >>> object a { - if (a) { println("bbb") } - else c + if (a) { println("bbb") } else c } <<< 1.3: block, if, non-egyptian curlies object a { @@ -31,8 +30,7 @@ c }} >>> object a { - if (a) { println("bbb") } - else { c } + if (a) { println("bbb") } else { c } } <<< 1.4: block #1043 object a { @@ -252,12 +250,8 @@ object a { >>> object a { val ok1 = if (a > 10) Some(a) else None - val ok2 = - if (a > 10) { Some(a) } - else { None } - val ok3 = - if (aaaa > 10000) { Some(aaaa) } - else { None } + val ok2 = if (a > 10) { Some(a) } else { None } + val ok3 = if (aaaa > 10000) { Some(aaaa) } else { None } } <<< 2.7 #1747: one line without else maxColumn = 60 @@ -296,12 +290,8 @@ object a { >>> object a { val ok1 = if (a > 10) Some(a) else None - val ok2 = - if (a > 10) { Some(a) } - else { None } - val ok3 = - if (aaaa > 10000) { Some(aaaa) } - else { None } + val ok2 = if (a > 10) { Some(a) } else { None } + val ok3 = if (aaaa > 10000) { Some(aaaa) } else { None } } <<< 2.9 #1747: split on else maxColumn = 60 @@ -319,12 +309,8 @@ object a { >>> object a { val ok1 = if (a > 10) Some(a) else None - val ok2 = - if (a > 10) { Some(a) } - else { None } - val ok3 = - if (aaaa > 10000) { Some(aaaa) } - else { None } + val ok2 = if (a > 10) { Some(a) } else { None } + val ok3 = if (aaaa > 10000) { Some(aaaa) } else { None } } <<< 2.10 #1747: split on then without else maxColumn = 60 @@ -357,12 +343,8 @@ object a { >>> object a { val ok1 = if (a > 10) Some(a) else None - val ok2 = - if (a > 10) { Some(a) } - else { None } - val ok3 = - if (aaaa > 10000) { Some(aaaa) } - else { None } + val ok2 = if (a > 10) { Some(a) } else { None } + val ok3 = if (aaaa > 10000) { Some(aaaa) } else { None } } <<< 2.12 #1747: split on if without else maxColumn = 60 @@ -6008,8 +5990,7 @@ class Foo { >>> class Foo { val foo = - if (true) { "" } - else { "a" } + if (true) { "" } else { "a" } val foo = if (true) "" else "a" val foo = if (true) "" else "a" val foo = @@ -6035,8 +6016,7 @@ class Foo() { >>> class Foo() { def ok: Boolean = - if (1 == 1) { true } - else false + if (1 == 1) { true } else false def notOK: Boolean = if (1 == 1) { @@ -10510,8 +10490,7 @@ object a { } >>> object a { - if (s.costWithPenalty > 0) { preFork = false; false } - else true + if (s.costWithPenalty > 0) { preFork = false; false } else true if (s.costWithPenalty <= 0) true else { preFork = false; false } } diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat index 1ad57b477c..e49f5e5ed9 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat @@ -3606,8 +3606,7 @@ object a { } >>> object a { - if ({ !Yield(2); false }) { !Yield(3) } - else { !Yield(4) } + if ({ !Yield(2); false }) { !Yield(3) } else { !Yield(4) } } <<< #3005 if-then object a { @@ -3620,8 +3619,7 @@ object a { } >>> object a { - if ({ !Yield(2); false }) then { !Yield(3) } - else { !Yield(4) } + if ({ !Yield(2); false }) then { !Yield(3) } else { !Yield(4) } } <<< #3005 while object a { @@ -4107,9 +4105,7 @@ class Foo() { } >>> class Foo() { - def ok: Boolean = - if (1 == 1) { true } - else false + def ok: Boolean = if (1 == 1) { true } else false def notOK: Boolean = if (1 == 1) { diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index e2d2ec65be..1452813ab0 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1086266, "total explored") + assertEquals(explored, 1085954, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(