Skip to content

Commit

Permalink
AvoidInfix: process scala3 match operator, too
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Nov 20, 2024
1 parent 437d693 commit 14821e3
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ object AvoidInfix extends RewriteFactory {
class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {

private val cfg = ctx.style.rewrite.avoidInfix
private val allowMatchAsOperator = dialect.allowMatchAsOperator

// In a perfect world, we could just use
// Tree.transform {
Expand All @@ -30,18 +31,32 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
// we will do these dangerous rewritings by hand.

override def rewrite(tree: Tree): Unit = tree match {
case x: Term.ApplyInfix => rewriteImpl(x.lhs, x.op, x.arg, x.targClause)
case x: Term.ApplyInfix =>
rewriteImpl(x.lhs, Right(x.op), x.arg, x.targClause)
case x: Term.Select if !cfg.excludePostfix && noDot(x.name.tokens.head) =>
rewriteImpl(x.qual, x.name)
rewriteImpl(x.qual, Right(x.name))
case x: Term.Match => noDotMatch(x)
.foreach(op => rewriteImpl(x.expr, Left(op), null))
case _ =>
}

private def noDot(opToken: T): Boolean =
!ctx.tokenTraverser.prevNonTrivialToken(opToken).forall(_.is[T.Dot])

private def noDotMatch(t: Term.Match): Either[Boolean, T] =
if (allowMatchAsOperator && t.mods.isEmpty)
ctx.tokenTraverser.prevNonTrivialToken(t.casesBlock.tokens.head) match {
case Some(kw) =>
if (!noDot(kw)) Left(true)
else if (cfg.excludeMatch) Left(false)
else Right(kw)
case _ => Left(false)
}
else Left(false)

private def rewriteImpl(
lhs: Term,
op: Name,
op: Either[T, Name],
rhs: Tree = null,
targs: Member.SyntaxValuesClause = null,
): Unit = {
Expand All @@ -52,19 +67,24 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
val lhsIsOK = lhsIsWrapped ||
(lhs match {
case t: Term.ApplyInfix => checkMatchingInfix(t.lhs, t.op.value, t.arg)
case t: Term.Match => noDotMatch(t) match {
case Left(ok) => ok
case Right(kw) => checkMatchingInfix(t.expr, kw.text, t.casesBlock)
}
case _ => false
})

if (!checkMatchingInfix(lhs, op.value, rhs, Some(lhsIsOK))) return
if (!checkMatchingInfix(lhs, op.fold(_.text, _.value), rhs, Some(lhsIsOK)))
return
if (!ctx.dialect.allowTryWithAnyExpr)
if (beforeLhsHead.exists(_.is[T.KwTry])) return

val builder = Seq.newBuilder[TokenPatch]

val (opHead, opLast) = ends(op)
val (opHead, opLast) = op.fold((_, null), ends)
builder += TokenPatch.AddLeft(opHead, ".", keepTok = true)

if (rhs ne null) {
if ((rhs ne null) && (opLast ne null)) {
def moveOpenDelim(prev: T, open: T): Unit = {
// move delimiter (before comment or newline)
builder += TokenPatch.AddRight(prev, open.text, keepTok = true)
Expand Down Expand Up @@ -94,7 +114,8 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
case _: Term.ApplyInfix | _: Term.Match => !lhsIsOK
// foo _ compose bar => (foo _).compose(bar)
// new Foo compose bar => (new Foo).compose(bar)
case _: Term.Eta | _: Term.New => true
case _: Term.Eta | _: Term.New | _: Term.Annotate => true
case t: Term.Select if rhs eq null => !noDot(t.name.tokens.head)
case _ => false
})
if (shouldWrapLhs) {
Expand Down Expand Up @@ -124,6 +145,15 @@ class AvoidInfix(implicit ctx: RewriteCtx) extends RewriteSession {
case None => isWrapped(lhs) ||
checkMatchingInfix(lhs.lhs, lhs.op.value, lhs.arg)
}
case lhs: Term.Match if hasPlaceholder(lhs, includeArg = true) =>
lhsIsOK match {
case Some(x) => x
case None if isWrapped(lhs) => true
case None => noDotMatch(lhs) match {
case Left(ok) => ok
case Right(op) => checkMatchingInfix(lhs.expr, op.text, null)
}
}
case _ => true
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7587,8 +7587,9 @@ object a {
}
>>>
object a {
a.b(c) match
case _ =>
a.b(c)
.match
case _ =>
}
<<< AvoidInfix with match within applyinfix, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7616,9 +7617,10 @@ object a {
}
>>>
object a {
(a.b(c) match {
case _ =>
}).d(e)
a.b(c)
.match { case _ =>
}
.d(e)
}
<<< AvoidInfix with match, with dot, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7671,7 +7673,7 @@ object a {
}
>>>
object a {
a(b): @c match
(a(b): @c).match
case _ =>
}
<<< match with dot, trailing case comment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7298,7 +7298,7 @@ object a {
}
>>>
object a {
a.b(c) match
a.b(c).match
case _ =>
}
<<< AvoidInfix with match within applyinfix, excludeMatch
Expand All @@ -7325,7 +7325,7 @@ object a {
}
>>>
object a {
(a.b(c) match { case _ => }).d(e)
a.b(c).match { case _ => }.d(e)
}
<<< AvoidInfix with match, with dot, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7376,7 +7376,7 @@ object a {
}
>>>
object a {
a(b): @c match
(a(b): @c).match
case _ =>
}
<<< match with dot, trailing case comment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7615,7 +7615,7 @@ object a {
}
>>>
object a {
a.b(c) match
a.b(c).match
case _ =>
}
<<< AvoidInfix with match within applyinfix, excludeMatch
Expand Down Expand Up @@ -7644,9 +7644,9 @@ object a {
}
>>>
object a {
(a.b(c) match {
a.b(c).match {
case _ =>
}).d(e)
}.d(e)
}
<<< AvoidInfix with match, with dot, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7697,7 +7697,7 @@ object a {
}
>>>
object a {
a(b): @c match
(a(b): @c).match
case _ =>
}
<<< match with dot, trailing case comment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7902,8 +7902,9 @@ object a {
}
>>>
object a {
a.b(c) match
case _ =>
a.b(c)
.match
case _ =>
}
<<< AvoidInfix with match within applyinfix, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7933,11 +7934,10 @@ object a {
}
>>>
object a {
(
a.b(c) match {
case _ =>
a.b(c)
.match { case _ =>
}
).d(e)
.d(e)
}
<<< AvoidInfix with match, with dot, excludeMatch
rewrite.rules = [AvoidInfix]
Expand Down Expand Up @@ -7988,7 +7988,7 @@ object a {
}
>>>
object a {
a(b): @c match
(a(b): @c).match
case _ =>
}
<<< match with dot, trailing case comment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
val explored = Debug.explored.get()
logger.debug(s"Total explored: $explored")
if (!onlyUnit && !onlyManual)
assertEquals(explored, 1084554, "total explored")
assertEquals(explored, 1084630, "total explored")
val results = debugResults.result()
// TODO(olafur) don't block printing out test results.
// I don't want to deal with scalaz's Tasks :'(
Expand Down

0 comments on commit 14821e3

Please sign in to comment.