diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index f016391cf8..260be8e035 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -1241,19 +1241,19 @@ class FormatOps( // look for arrow before body, if any, else after params def getFuncArrow(term: Term.FunctionTerm): Option[FT] = tokens .tokenBeforeOpt(term.body) - .orElse(tokenAfterOpt(term.paramClause).map(getArrowAfter)) + .orElse(tokenAfterOpt(term.paramClause).map(getArrowAfter[T.FunctionArrow])) // look for arrow before body, if any, else after cond/pat - def getCaseArrow(term: Case): FT = tokenBeforeOpt(term.body) - .getOrElse(getArrowAfter(tokenAfter(term.cond.getOrElse(term.pat)))) + def getCaseArrow(term: Case): FT = tokenBeforeOpt(term.body).getOrElse { + getArrowAfter[T.RightArrow](tokenAfter(term.cond.getOrElse(term.pat))) + } // look for arrow before body, if any, else after cond/pat def getCaseArrow(term: TypeCase): FT = next(tokenAfter(term.pat)) - private def getArrowAfter(ft: FT): FT = { + private def getArrowAfter[A](ft: FT)(implicit f: Classifier[T, A]): FT = { val maybeArrow = next(ft) - if (maybeArrow.left.is[T.RightArrow]) maybeArrow - else nextAfterNonComment(maybeArrow) + if (f(maybeArrow.left)) maybeArrow else nextAfterNonComment(maybeArrow) } @tailrec