Skip to content

Commit

Permalink
New footprint calculation scheme (scala#19639)
Browse files Browse the repository at this point in the history
Since match type reduction is expensive, it is cached. If a type is
reduced (or not reduced) in `tryNormalize` we remember that decision and
return the previous result - unless something in the context changed
since the last attempt which could lead to a different outcome. Relevant
here are:

 - constraints (regular and GADT) over type parameters
 - instantations of type variables

We keep track of these things in a so-called footprint calculation. 

The old calculation clearly did not work. It either never worked or was
broken by the changes to matchtype reduction.

I now changed it to a more straightforward scheme that computes the
footprint directly instead of relying on TypeComparer to produce the
right trace.
  • Loading branch information
sjrd authored Feb 14, 2024
2 parents b160bbb + d6ba9b2 commit f95b57c
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 57 deletions.
48 changes: 13 additions & 35 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3054,7 +3054,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
end provablyDisjointTypeArgs

protected def explainingTypeComparer(short: Boolean) = ExplainingTypeComparer(comparerContext, short)
protected def trackingTypeComparer = TrackingTypeComparer(comparerContext)
protected def matchReducer = MatchReducer(comparerContext)

private def inSubComparer[T, Cmp <: TypeComparer](comparer: Cmp)(op: Cmp => T): T =
val saved = myInstance
Expand All @@ -3068,8 +3068,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
inSubComparer(cmp)(op)
cmp.lastTrace(header)

def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
inSubComparer(trackingTypeComparer)(op)
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
inSubComparer(matchReducer)(op)
}

object TypeComparer {
Expand Down Expand Up @@ -3236,14 +3236,14 @@ object TypeComparer {
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:", short: Boolean = false)(using Context): String =
comparing(_.explained(op, header, short))

def tracked[T](op: TrackingTypeComparer => T)(using Context): T =
comparing(_.tracked(op))
def reduceMatchWith[T](op: MatchReducer => T)(using Context): T =
comparing(_.reduceMatchWith(op))

def subCaptures(refs1: CaptureSet, refs2: CaptureSet, frozen: Boolean)(using Context): CaptureSet.CompareResult =
comparing(_.subCaptures(refs1, refs2, frozen))
}

object TrackingTypeComparer:
object MatchReducer:
import printing.*, Texts.*
enum MatchResult extends Showable:
case Reduced(tp: Type)
Expand All @@ -3259,38 +3259,16 @@ object TrackingTypeComparer:
case Stuck => "Stuck"
case NoInstance(fails) => "NoInstance(" ~ Text(fails.map(p.toText(_) ~ p.toText(_)), ", ") ~ ")"

class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
import TrackingTypeComparer.*
/** A type comparer for reducing match types.
* TODO: Not sure this needs to be a type comparer. Can we make it a
* separate class?
*/
class MatchReducer(initctx: Context) extends TypeComparer(initctx) {
import MatchReducer.*

init(initctx)

override def trackingTypeComparer = this

val footprint: mutable.Set[Type] = mutable.Set[Type]()

override def bounds(param: TypeParamRef)(using Context): TypeBounds = {
if (param.binder `ne` caseLambda) footprint += param
super.bounds(param)
}

override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean = {
if (param.binder `ne` caseLambda) footprint += param
super.addOneBound(param, bound, isUpper)
}

override def gadtBounds(sym: Symbol)(using Context): TypeBounds | Null = {
if (sym.exists) footprint += sym.typeRef
super.gadtBounds(sym)
}

override def gadtAddBound(sym: Symbol, b: Type, isUpper: Boolean): Boolean =
if (sym.exists) footprint += sym.typeRef
super.gadtAddBound(sym, b, isUpper)

override def typeVarInstance(tvar: TypeVar)(using Context): Type = {
footprint += tvar
super.typeVarInstance(tvar)
}
override def matchReducer = this

def matchCases(scrut: Type, cases: List[MatchTypeCaseSpec])(using Context): Type = {
// a reference for the type parameters poisoned during matching
Expand Down
75 changes: 53 additions & 22 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5009,6 +5009,8 @@ object Types extends TypeUtils {
case ex: Throwable =>
handleRecursive("normalizing", s"${scrutinee.show} match ..." , ex)

private def thisMatchType = this

def reduced(using Context): Type = {

def contextInfo(tp: Type): Type = tp match {
Expand All @@ -5023,16 +5025,43 @@ object Types extends TypeUtils {
tp.underlying
}

def updateReductionContext(footprint: collection.Set[Type]): Unit =
reductionContext = util.HashMap()
for (tp <- footprint)
reductionContext(tp) = contextInfo(tp)
typr.println(i"footprint for $this $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")

def isUpToDate: Boolean =
reductionContext.keysIterator.forall { tp =>
reductionContext.keysIterator.forall: tp =>
reductionContext(tp) `eq` contextInfo(tp)
}

def setReductionContext(): Unit =
new TypeTraverser:
var footprint: Set[Type] = Set()
var deep: Boolean = true
val seen = util.HashSet[Type]()
def traverse(tp: Type) =
if !seen.contains(tp) then
seen += tp
tp match
case tp: NamedType =>
if tp.symbol.is(TypeParam) then footprint += tp
traverseChildren(tp)
case _: AppliedType | _: RefinedType =>
if deep then traverseChildren(tp)
case TypeBounds(lo, hi) =>
traverse(hi)
case tp: TypeVar =>
footprint += tp
traverse(tp.underlying)
case tp: TypeParamRef =>
footprint += tp
case _ =>
traverseChildren(tp)
end traverse

traverse(scrutinee)
deep = false
cases.foreach(traverse)
reductionContext = util.HashMap()
for tp <- footprint do
reductionContext(tp) = contextInfo(tp)
matchTypes.println(i"footprint for $thisMatchType $hashCode: ${footprint.toList.map(x => (x, contextInfo(x)))}%, %")
end setReductionContext

record("MatchType.reduce called")
if !Config.cacheMatchReduced
Expand All @@ -5043,20 +5072,22 @@ object Types extends TypeUtils {
record("MatchType.reduce computed")
if (myReduced != null) record("MatchType.reduce cache miss")
myReduced =
trace(i"reduce match type $this $hashCode", matchTypes, show = true)(withMode(Mode.Type) {
def matchCases(cmp: TrackingTypeComparer): Type =
val saved = ctx.typerState.snapshot()
try cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
catch case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
finally
updateReductionContext(cmp.footprint)
ctx.typerState.resetTo(saved)
// this drops caseLambdas in constraint and undoes any typevar
// instantiations during matchtype reduction

TypeComparer.tracked(matchCases)
})
trace(i"reduce match type $this $hashCode", matchTypes, show = true):
withMode(Mode.Type):
setReductionContext()
def matchCases(cmp: MatchReducer): Type =
val saved = ctx.typerState.snapshot()
try
cmp.matchCases(scrutinee.normalized, cases.map(MatchTypeCaseSpec.analyze(_)))
catch case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
finally
ctx.typerState.resetTo(saved)
// this drops caseLambdas in constraint and undoes any typevar
// instantiations during matchtype reduction
TypeComparer.reduceMatchWith(matchCases)

//else println(i"no change for $this $hashCode / $myReduced")
myReduced.nn
}

Expand Down
99 changes: 99 additions & 0 deletions tests/pos/bad-footprint.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

object NamedTuple:

opaque type AnyNamedTuple = Any
opaque type NamedTuple[N <: Tuple, +V <: Tuple] >: V <: AnyNamedTuple = V

export NamedTupleDecomposition.{Names, DropNames}

/** The type of the named tuple `X` mapped with the type-level function `F`.
* If `X = (n1 : T1, ..., ni : Ti)` then `Map[X, F] = `(n1 : F[T1], ..., ni : F[Ti])`.
*/
type Map[X <: AnyNamedTuple, F[_ <: Tuple.Union[DropNames[X]]]] =
NamedTuple[Names[X], Tuple.Map[DropNames[X], F]]

end NamedTuple

object NamedTupleDecomposition:
import NamedTuple.*

/** The names of a named tuple, represented as a tuple of literal string values. */
type Names[X <: AnyNamedTuple] <: Tuple = X match
case NamedTuple[n, _] => n

/** The value types of a named tuple represented as a regular tuple. */
type DropNames[NT <: AnyNamedTuple] <: Tuple = NT match
case NamedTuple[_, x] => x
end NamedTupleDecomposition

class Expr[Result]

object Expr:
import NamedTuple.{NamedTuple, AnyNamedTuple}

type Of[A] = Expr[A]

type StripExpr[E] = E match
case Expr.Of[b] => b

case class Ref[A]($name: String = "") extends Expr.Of[A]

case class Join[A <: AnyNamedTuple](a: A)
extends Expr.Of[NamedTuple.Map[A, StripExpr]]
end Expr

trait Query[A]

object Query:
// Extension methods to support for-expression syntax for queries
extension [R](x: Query[R])
def map[B](f: Expr.Ref[R] => Expr.Of[B]): Query[B] = ???

case class City(zipCode: Int, name: String, population: Int)

object Test:
import Expr.StripExpr
import NamedTuple.{NamedTuple, AnyNamedTuple}

val cities: Query[City] = ???
val q6 =
cities.map: city =>
val x: NamedTuple[
("name", "zipCode"),
(Expr.Of[String], Expr.Of[Int])] = ???
Expr.Join(x)

/* Was error:
-- [E007] Type Mismatch Error: bad-footprint.scala:60:16 -----------------------
60 | cities.map: city =>
| ^
|Found: Expr.Ref[City] =>
| Expr[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (String,
| Int)]
| ]
|Required: Expr.Ref[City] =>
| Expr[
| NamedTuple.NamedTuple[
| NamedTupleDecomposition.Names[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
| Expr[String], Expr[Int])]
| ],
| Tuple.Map[
| NamedTupleDecomposition.DropNames[
| NamedTuple.NamedTuple[(("name" : String), ("zipCode" : String)), (
| Expr[String], Expr[Int])]
| ],
| Expr.StripExpr]
| ]
| ]
61 | val x: NamedTuple[
62 | ("name", "zipCode"),
63 | (Expr.Of[String], Expr.Of[Int])] = ???
64 | Expr.Join(x)
|
| longer explanation available when compiling with `-explain`
1 error found
*/

0 comments on commit f95b57c

Please sign in to comment.