Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pattern matching in global init checker, skipping cases when safe to do so #22179

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ class Objects(using Context @constructorOnly):
case (ValueSet(values), b : ValueElement) => ValueSet(values + b)
case (a : ValueElement, b : ValueElement) => ValueSet(ListSet(a, b))

def remove(b: Value): Value = (a, b) match
case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b)
case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2))
case (a: Ref, b: Ref) if a.equals(b) => Bottom
case _ => a

def widen(height: Int)(using Context): Value =
if height == 0 then Cold
else
Expand Down Expand Up @@ -1341,29 +1347,25 @@ class Objects(using Context @constructorOnly):
def getMemberMethod(receiver: Type, name: TermName, tp: Type): Denotation =
receiver.member(name).suchThat(receiver.memberInfo(_) <:< tp)

def evalCase(caseDef: CaseDef): Value =
evalPattern(scrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
eval(caseDef.body, thisV, klass)

/** Abstract evaluation of patterns.
*
* It augments the local environment for bound pattern variables. As symbols are globally
* unique, we can put them in a single environment.
*
* Currently, we assume all cases are reachable, thus all patterns are assumed to match.
*/
def evalPattern(scrutinee: Value, pat: Tree): Value = log("match " + scrutinee.show + " against " + pat.show, printer, (_: Value).show):
def evalPattern(scrutinee: Value, pat: Tree): (Type, Value) = log("match " + scrutinee.show + " against " + pat.show, printer, (_: (Type, Value))._2.show):
val trace2 = Trace.trace.add(pat)
pat match
case Alternative(pats) =>
for pat <- pats do evalPattern(scrutinee, pat)
scrutinee
val (types, values) = pats.map(evalPattern(scrutinee, _)).unzip()
val orType = types.fold(defn.NothingType)(OrType(_, _, false))
(orType, values.join)

case bind @ Bind(_, pat) =>
val value = evalPattern(scrutinee, pat)
val (tpe, value) = evalPattern(scrutinee, pat)
initLocal(bind.symbol, value)
scrutinee
(tpe, value)

case UnApply(fun, implicits, pats) =>
given Trace = trace2
Expand All @@ -1372,6 +1374,10 @@ class Objects(using Context @constructorOnly):
val funRef = fun1.tpe.asInstanceOf[TermRef]
val unapplyResTp = funRef.widen.finalResultType

val receiverType = fun1 match
case ident: Ident => funRef.prefix
case select: Select => select.qualifier.tpe

val receiver = fun1 match
case ident: Ident =>
evalType(funRef.prefix, thisV, klass)
Expand Down Expand Up @@ -1460,17 +1466,18 @@ class Objects(using Context @constructorOnly):
end if
end if
end if
scrutinee
(receiverType, scrutinee.filterType(receiverType))

case Ident(nme.WILDCARD) | Ident(nme.WILDCARD_STAR) =>
scrutinee
(defn.ThrowableType, scrutinee)

case Typed(pat, _) =>
evalPattern(scrutinee, pat)
case Typed(pat, typeTree) =>
val (_, value) = evalPattern(scrutinee.filterType(typeTree.tpe), pat)
(typeTree.tpe, value)

case tree =>
// For all other trees, the semantics is normal.
eval(tree, thisV, klass)
(defn.ThrowableType, eval(tree, thisV, klass))

end evalPattern

Expand All @@ -1494,12 +1501,12 @@ class Objects(using Context @constructorOnly):
if isWildcardStarArgList(pats) then
if pats.size == 1 then
// call .toSeq
val toSeqDenot = getMemberMethod(scrutineeType, nme.toSeq, toSeqType(elemType))
val toSeqDenot = scrutineeType.member(nme.toSeq).suchThat(_.info.isParameterless)
val toSeqRes = call(scrutinee, toSeqDenot.symbol, Nil, scrutineeType, superType = NoType, needResolve = true)
evalPattern(toSeqRes, pats.head)
else
// call .drop
val dropDenot = getMemberMethod(scrutineeType, nme.drop, dropType(elemType))
val dropDenot = getMemberMethod(scrutineeType, nme.drop, applyType(elemType))
val dropRes = call(scrutinee, dropDenot.symbol, ArgInfo(Bottom, summon[Trace], EmptyTree) :: Nil, scrutineeType, superType = NoType, needResolve = true)
for pat <- pats.init do evalPattern(applyRes, pat)
evalPattern(dropRes, pats.last)
Expand All @@ -1510,8 +1517,21 @@ class Objects(using Context @constructorOnly):
end if
end evalSeqPatterns

def canSkipCase(remainingScrutinee: Value, catchValue: Value) =
(remainingScrutinee == Bottom && scrutinee != Bottom) ||
(catchValue == Bottom && remainingScrutinee != Bottom)

cases.map(evalCase).join
var remainingScrutinee = scrutinee
val caseResults: mutable.ArrayBuffer[Value] = mutable.ArrayBuffer()
for caseDef <- cases do
val (tpe, value) = evalPattern(remainingScrutinee, caseDef.pat)
eval(caseDef.guard, thisV, klass)
if !canSkipCase(remainingScrutinee, value) then
caseResults.addOne(eval(caseDef.body, thisV, klass))
if catchesAllOf(caseDef, tpe) then
remainingScrutinee = remainingScrutinee.remove(value)

caseResults.join
end patternMatch

/** Handle semantics of leaf nodes
Expand Down
Loading