diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala index 04da7e6fd4..a42e26064e 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/AvoidInfix.scala @@ -21,6 +21,7 @@ object AvoidInfix extends RewriteFactory { class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { private val cfg = ctx.style.rewrite.avoidInfix + private val allowMatchAsOperator = dialect.allowMatchAsOperator // In a perfect world, we could just use // Tree.transform { @@ -30,18 +31,32 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { // we will do these dangerous rewritings by hand. override def rewrite(tree: Tree): Unit = tree match { - case x: Term.ApplyInfix => rewriteImpl(x.lhs, x.op, x.arg, x.targClause) + case x: Term.ApplyInfix => + rewriteImpl(x.lhs, Right(x.op), x.arg, x.targClause) case x: Term.Select if !cfg.excludePostfix && noDot(x.name.tokens.head) => - rewriteImpl(x.qual, x.name) + rewriteImpl(x.qual, Right(x.name)) + case x: Term.Match => noDotMatch(x) + .foreach(op => rewriteImpl(x.expr, Left(op), null)) case _ => } private def noDot(opToken: T): Boolean = !ctx.tokenTraverser.prevNonTrivialToken(opToken).forall(_.is[T.Dot]) + private def noDotMatch(t: Term.Match): Either[Boolean, T] = + if (allowMatchAsOperator && t.mods.isEmpty) + ctx.tokenTraverser.prevNonTrivialToken(t.casesBlock.tokens.head) match { + case Some(kw) => + if (!noDot(kw)) Left(true) + else if (cfg.excludeMatch) Left(false) + else Right(kw) + case _ => Left(false) + } + else Left(false) + private def rewriteImpl( lhs: Term, - op: Name, + op: Either[T, Name], rhs: Tree = null, targs: Member.SyntaxValuesClause = null, ): Unit = { @@ -52,19 +67,24 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { val lhsIsOK = lhsIsWrapped || (lhs match { case t: Term.ApplyInfix => checkMatchingInfix(t.lhs, t.op.value, t.arg) + case t: Term.Match => noDotMatch(t) match { + case Left(ok) => ok + case Right(kw) => checkMatchingInfix(t.expr, kw.text, t.casesBlock) + } case _ => false }) - if (!checkMatchingInfix(lhs, op.value, rhs, Some(lhsIsOK))) return + if (!checkMatchingInfix(lhs, op.fold(_.text, _.value), rhs, Some(lhsIsOK))) + return if (!ctx.dialect.allowTryWithAnyExpr) if (beforeLhsHead.exists(_.is[T.KwTry])) return val builder = Seq.newBuilder[TokenPatch] - val (opHead, opLast) = ends(op) + val (opHead, opLast) = op.fold((_, null), ends) builder += TokenPatch.AddLeft(opHead, ".", keepTok = true) - if (rhs ne null) { + if ((rhs ne null) && (opLast ne null)) { def moveOpenDelim(prev: T, open: T): Unit = { // move delimiter (before comment or newline) builder += TokenPatch.AddRight(prev, open.text, keepTok = true) @@ -94,7 +114,8 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { case _: Term.ApplyInfix | _: Term.Match => !lhsIsOK // foo _ compose bar => (foo _).compose(bar) // new Foo compose bar => (new Foo).compose(bar) - case _: Term.Eta | _: Term.New => true + case _: Term.Eta | _: Term.New | _: Term.Annotate => true + case t: Term.Select if rhs eq null => !noDot(t.name.tokens.head) case _ => false }) if (shouldWrapLhs) { @@ -124,6 +145,15 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession { case None => isWrapped(lhs) || checkMatchingInfix(lhs.lhs, lhs.op.value, lhs.arg) } + case lhs: Term.Match if hasPlaceholder(lhs, includeArg = true) => + lhsIsOK match { + case Some(x) => x + case None if isWrapped(lhs) => true + case None => noDotMatch(lhs) match { + case Left(ok) => ok + case Right(op) => checkMatchingInfix(lhs.expr, op.text, null) + } + } case _ => true }) diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat index a4fc2f7fc4..703f0ad39e 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces.stat @@ -7587,8 +7587,9 @@ object a { } >>> object a { - a.b(c) match - case _ => + a.b(c) + .match + case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch rewrite.rules = [AvoidInfix] @@ -7616,9 +7617,10 @@ object a { } >>> object a { - (a.b(c) match { - case _ => - }).d(e) + a.b(c) + .match { case _ => + } + .d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] @@ -7671,7 +7673,7 @@ object a { } >>> object a { - a(b): @c match + (a(b): @c).match case _ => } <<< match with dot, trailing case comment diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat index 1fe09931df..1ad57b477c 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_fold.stat @@ -7298,7 +7298,7 @@ object a { } >>> object a { - a.b(c) match + a.b(c).match case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch @@ -7325,7 +7325,7 @@ object a { } >>> object a { - (a.b(c) match { case _ => }).d(e) + a.b(c).match { case _ => }.d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] @@ -7376,7 +7376,7 @@ object a { } >>> object a { - a(b): @c match + (a(b): @c).match case _ => } <<< match with dot, trailing case comment diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat index 4e4a4829c5..606a66c147 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_keep.stat @@ -7615,7 +7615,7 @@ object a { } >>> object a { - a.b(c) match + a.b(c).match case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch @@ -7644,9 +7644,9 @@ object a { } >>> object a { - (a.b(c) match { + a.b(c).match { case _ => - }).d(e) + }.d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] @@ -7697,7 +7697,7 @@ object a { } >>> object a { - a(b): @c match + (a(b): @c).match case _ => } <<< match with dot, trailing case comment diff --git a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat index aad697be7b..f553bec717 100644 --- a/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat +++ b/scalafmt-tests/shared/src/test/resources/scala3/OptionalBraces_unfold.stat @@ -7902,8 +7902,9 @@ object a { } >>> object a { - a.b(c) match - case _ => + a.b(c) + .match + case _ => } <<< AvoidInfix with match within applyinfix, excludeMatch rewrite.rules = [AvoidInfix] @@ -7933,11 +7934,10 @@ object a { } >>> object a { - ( - a.b(c) match { - case _ => + a.b(c) + .match { case _ => } - ).d(e) + .d(e) } <<< AvoidInfix with match, with dot, excludeMatch rewrite.rules = [AvoidInfix] @@ -7988,7 +7988,7 @@ object a { } >>> object a { - a(b): @c match + (a(b): @c).match case _ => } <<< match with dot, trailing case comment diff --git a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala index 27176d530b..168719666e 100644 --- a/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala +++ b/scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions { val explored = Debug.explored.get() logger.debug(s"Total explored: $explored") if (!onlyUnit && !onlyManual) - assertEquals(explored, 1084554, "total explored") + assertEquals(explored, 1084630, "total explored") val results = debugResults.result() // TODO(olafur) don't block printing out test results. // I don't want to deal with scalaz's Tasks :'(