Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pattern matching on case refactor #491

Merged
merged 2 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private[compiletime] trait ExprPromisesPlatform extends ExprPromises { this: Def
protected object PatternMatchCase extends PatternMatchCaseModule {

def matchOn[From: Type, To: Type](src: Expr[From], cases: List[PatternMatchCase[To]]): Expr[To] = {
val casesTrees = cases.map { case PatternMatchCase(someFrom, usage, fromName, _) =>
val casesTrees = cases.map { case PatternMatchCase(someFrom, usage, fromName) =>
import someFrom.Underlying as SomeFrom
val markUsed = Expr.suppressUnused(c.Expr[someFrom.Underlying](q"$fromName"))
cq"""$fromName : $SomeFrom => { $markUsed; $usage }"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ private[compiletime] trait TypesPlatform extends Types { this: DefinitionsPlatfo
// everyone has the same baseClasses, everyone reports to have public primaryConstructor (which is <none>).
// The only different in behavior is that one prints com.my.Enum and another com.my.Enum(MyValue).
val javaEnumRegexpFormat = raw"^(.+)\((.+)\)$$".r
def isJavaEnumValue(tpe: c.Type): Boolean =
tpe.typeSymbol.isJavaEnum && javaEnumRegexpFormat.pattern
.matcher(tpe.toString)
.matches() // 2.12 doesn't have .matches
}

import platformSpecific.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,17 @@ trait ProductTypesPlatform extends ProductTypes { this: DefinitionsPlatform =>
isPOJO[A] && A.tpe.typeSymbol.asClass.isCaseClass
def isCaseObject[A](implicit A: Type[A]): Boolean = {
val sym = A.tpe.typeSymbol
def isScala2Enum = sym.asClass.isCaseClass
def isScala3Enum = sym.isStatic && sym.isFinal // parameterless case in S3 cannot be checked for "case"
def isScalaEnum = sym.isModuleClass && (isScala2Enum || isScala3Enum)
sym.isPublic && (isScalaEnum || isJavaEnumValue(A.tpe))
sym.isPublic && sym.isModuleClass && sym.asClass.isCaseClass
}
def isCaseVal[A](implicit A: Type[A]): Boolean = {
val sym = A.tpe.typeSymbol
sym.isPublic && sym.isModuleClass && sym.isStatic && sym.isFinal // parameterless case in S3 cannot be checked for "case"
}
def isJavaEnumValue[A](implicit A: Type[A]): Boolean = {
val sym = A.tpe.typeSymbol
sym.isPublic && sym.isJavaEnum && javaEnumRegexpFormat.pattern
.matcher(A.tpe.toString)
.matches() // 2.12 doesn't have .matches
}
def isJavaBean[A](implicit A: Type[A]): Boolean = {
val mem = A.tpe.members
Expand Down Expand Up @@ -103,9 +110,9 @@ trait ProductTypesPlatform extends ProductTypes { this: DefinitionsPlatform =>
val A = Type[A].tpe
val sym = A.typeSymbol

if (isJavaEnumValue(A)) {
if (isJavaEnumValue[A]) {
Some(Product.Constructor(ListMap.empty, _ => c.Expr[A](q"$A")))
} else if (isCaseObject[A]) {
} else if (isCaseObject[A] || isCaseVal[A]) {
Some(Product.Constructor(ListMap.empty, _ => c.Expr[A](q"${sym.asClass.module}")))
} else if (isPOJO[A]) {
val primaryConstructor =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ private[compiletime] trait ExprPromisesPlatform extends ExprPromises { this: Def

def matchOn[From: Type, To: Type](src: Expr[From], cases: List[PatternMatchCase[To]]): Expr[To] = Match(
'{ ${ src.asTerm.changeOwner(Symbol.spliceOwner).asExprOf[From] }: @scala.unchecked }.asTerm,
cases.map { case PatternMatchCase(someFrom, usage, fromName, isCaseObject) =>
cases.map { case PatternMatchCase(someFrom, usage, fromName) =>
import someFrom.Underlying as SomeFrom
// Unfortunately, we cannot do
// case $fromName: $SomeFrom => $using
Expand All @@ -126,7 +126,7 @@ private[compiletime] trait ExprPromisesPlatform extends ExprPromises { this: Def
// Scala 3's enums' parameterless cases are vals with type erased, so w have to match them by value
// case arg @ Enum.Value => ...
CaseDef(Bind(bindName, Ident(TypeRepr.of[someFrom.Underlying].typeSymbol.termRef)), None, body)
else if isCaseObject then
else if TypeRepr.of[someFrom.Underlying].typeSymbol.flags.is(Flags.Module) then
// case objects are also matched by value but the tree for them is generated in a slightly different way
// case arg @ Enum.Value => ...
CaseDef(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ trait ProductTypesPlatform extends ProductTypes { this: DefinitionsPlatform =>
def isJavaSetterOrVar(setter: Symbol): Boolean =
isJavaSetter(setter) || isVar(setter)

def isJavaEnumValue[A: Type]: Boolean =
Type[A] <:< scala.quoted.Type.of[java.lang.Enum[?]] && !TypeRepr.of[A].typeSymbol.isAbstract
}

import platformSpecific.*
Expand All @@ -61,10 +59,14 @@ trait ProductTypesPlatform extends ProductTypes { this: DefinitionsPlatform =>
}
def isCaseObject[A](implicit A: Type[A]): Boolean = {
val sym = TypeRepr.of(using A).typeSymbol
def isScala2Enum = sym.flags.is(Flags.Case | Flags.Module)
def isScala3Enum = sym.flags.is(Flags.Case | Flags.Enum | Flags.JavaStatic)
sym.isPublic && (isScala2Enum || isScala3Enum || isJavaEnumValue[A])
sym.isPublic && sym.flags.is(Flags.Case | Flags.Module)
}
def isCaseVal[A](implicit A: Type[A]): Boolean = {
val sym = TypeRepr.of(using A).typeSymbol
sym.isPublic && sym.flags.is(Flags.Case | Flags.Enum | Flags.JavaStatic)
}
def isJavaEnumValue[A: Type]: Boolean =
Type[A] <:< scala.quoted.Type.of[java.lang.Enum[?]] && !TypeRepr.of[A].typeSymbol.isAbstract
def isJavaBean[A](implicit A: Type[A]): Boolean = {
val sym = TypeRepr.of(using A).typeSymbol
val mem = sym.declarations
Expand Down Expand Up @@ -129,14 +131,11 @@ trait ProductTypesPlatform extends ProductTypes { this: DefinitionsPlatform =>
)

def parseConstructor[A: Type]: Option[Product.Constructor[A]] =
if isCaseObject[A] then {
if isCaseObject[A] || isCaseVal[A] || isJavaEnumValue[A] then {
val A = TypeRepr.of[A]
val sym = A.typeSymbol

// TODO: actually a duplication of isScala3Enum from isCaseObject
def isScala3Enum = sym.flags.is(Flags.Case | Flags.Enum | Flags.JavaStatic)

if isScala3Enum || isJavaEnumValue[A] then Some(Product.Constructor(ListMap.empty, _ => Ref(sym).asExprOf[A]))
if isCaseVal[A] || isJavaEnumValue[A] then Some(Product.Constructor(ListMap.empty, _ => Ref(sym).asExprOf[A]))
else Some(Product.Constructor(ListMap.empty, _ => Ref(sym.companionModule).asExprOf[A]))
} else if isPOJO[A] then {
val A = TypeRepr.of[A]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,11 @@ private[compiletime] trait ExprPromises { this: Definitions =>
): Expr[(From, From2) => To] =
ExprPromise.createLambda2(fromName, promise.fromName, combine(usage, promise.usage))

def fulfillAsPatternMatchCase[To](isCaseObject: Boolean)(implicit ev: A <:< Expr[To]): PatternMatchCase[To] =
def fulfillAsPatternMatchCase[To](implicit ev: A <:< Expr[To]): PatternMatchCase[To] =
new PatternMatchCase(
someFrom = ExistentialType(Type[From]),
someFrom = Type[From].as_??,
usage = ev(usage),
fromName = fromName,
isCaseObject = isCaseObject
fromName = fromName
)

def partition[L, R](implicit
Expand Down Expand Up @@ -218,18 +217,13 @@ private[compiletime] trait ExprPromises { this: Definitions =>
final protected class PatternMatchCase[To](
val someFrom: ??,
val usage: Expr[To],
val fromName: ExprPromiseName,
val isCaseObject: Boolean
val fromName: ExprPromiseName
)
protected val PatternMatchCase: PatternMatchCaseModule
protected trait PatternMatchCaseModule { this: PatternMatchCase.type =>

final def unapply[To](
patternMatchCase: PatternMatchCase[To]
): Some[(??, Expr[To], ExprPromiseName, Boolean)] =
Some(
(patternMatchCase.someFrom, patternMatchCase.usage, patternMatchCase.fromName, patternMatchCase.isCaseObject)
)
final def unapply[To](patternMatchCase: PatternMatchCase[To]): Some[(??, Expr[To], ExprPromiseName)] =
Some((patternMatchCase.someFrom, patternMatchCase.usage, patternMatchCase.fromName))

def matchOn[From: Type, To: Type](src: Expr[From], cases: List[PatternMatchCase[To]]): Expr[To]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,22 @@ trait ProductTypes { this: Definitions =>
protected val ProductType: ProductTypesModule
protected trait ProductTypesModule { this: ProductType.type =>

/** Any class with a public constructor... explicitly excluding: primitives, String and Java enums */
def isPOJO[A](implicit A: Type[A]): Boolean

/** Class defined with "case class" */
def isCaseClass[A](implicit A: Type[A]): Boolean

/** Class defined with "case object" */
def isCaseObject[A](implicit A: Type[A]): Boolean

/** Scala 3 enum's case without parameters (a "val" under the hood, NOT an "object") */
def isCaseVal[A](implicit A: Type[A]): Boolean

/** Java enum value - not the abstract enum type, but the concrete enum value */
def isJavaEnumValue[A](implicit A: Type[A]): Boolean

/** Any POJO with a public DEFAULT constructor... and at least 1 setter or var */
def isJavaBean[A](implicit A: Type[A]): Boolean

def parseExtraction[A: Type]: Option[Product.Extraction[A]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ private[compiletime] trait TransformSealedHierarchyToSealedHierarchyRuleModule {
// 2 or more matches - ambiguous coproduct instances
case toSubtypes =>
DerivationResult.ambiguousSubtypeTargets[From, To, Existential[ExprPromise[*, TransformationExpr[To]]]](
ExistentialType(fromSubtype.Underlying),
toSubtypes.map(to => ExistentialType(to.Underlying))
fromSubtype.Underlying.as_??,
toSubtypes.map(to => to.Underlying.as_??)
)
}
}
Expand Down Expand Up @@ -216,7 +216,7 @@ private[compiletime] trait TransformSealedHierarchyToSealedHierarchyRuleModule {
subtypeMappings
.map { subtype =>
subtype.value.ensurePartial
.fulfillAsPatternMatchCase[partial.Result[To]](isCaseObject = subtype.Underlying.isCaseObject)
.fulfillAsPatternMatchCase[partial.Result[To]]
}
.matchOn(ctx.src)
)
Expand All @@ -225,8 +225,7 @@ private[compiletime] trait TransformSealedHierarchyToSealedHierarchyRuleModule {
DerivationResult.expandedTotal(
subtypeMappings
.map { subtype =>
subtype.value.ensureTotal
.fulfillAsPatternMatchCase[To](isCaseObject = subtype.Underlying.isCaseObject)
subtype.value.ensureTotal.fulfillAsPatternMatchCase[To]
}
.matchOn(ctx.src)
)
Expand Down
Loading