Skip to content

Commit

Permalink
Fix eta expand at erasure to take erased CFTs into account
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Nov 3, 2021
1 parent 88a2477 commit f50b55c
Show file tree
Hide file tree
Showing 12 changed files with 154 additions and 99 deletions.
19 changes: 17 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
Closure(Nil, call, targetTpt))
}

/** A closure whole anonymous function has the given method type */
/** A closure whose anonymous function has the given method type */
def Lambda(tpe: MethodType, rhsFn: List[Tree] => Tree)(using Context): Block = {
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe)
val meth = newAnonFun(ctx.owner, tpe)
Closure(meth, tss => rhsFn(tss.head).changeOwner(ctx.owner, meth))
}

Expand Down Expand Up @@ -1104,6 +1104,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
if (sym.exists) sym.defTree = tree
tree
}

def etaExpandCFT(using Context): Tree =
def expand(target: Tree, tp: Type)(using Context): Tree = tp match
case defn.ContextFunctionType(argTypes, resType, isErased) =>
val anonFun = newAnonFun(
ctx.owner,
MethodType.companion(isContextual = true, isErased = isErased)(argTypes, resType),
coord = ctx.owner.coord)
def lambdaBody(refss: List[List[Tree]]) =
expand(target.select(nme.apply).appliedToArgss(refss), resType)(
using ctx.withOwner(anonFun))
Closure(anonFun, lambdaBody)
case _ =>
target
expand(tree, tree.tpe.widen)
}

inline val MapRecursionLimit = 10
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,10 @@ object Symbols {
coord: Coord = NoCoord)(using Context): TermSymbol =
newSymbol(cls, nme.CONSTRUCTOR, flags | Method, MethodType(paramNames, paramTypes, cls.typeRef), privateWithin, coord)

/** Create an anonymous function symbol */
def newAnonFun(owner: Symbol, info: Type, coord: Coord = NoCoord)(using Context): TermSymbol =
newSymbol(owner, nme.ANON_FUN, Synthetic | Method, info, coord = coord)

/** Create an empty default constructor symbol for given class `cls`. */
def newDefaultConstructor(cls: ClassSymbol)(using Context): TermSymbol =
newConstructor(cls, EmptyFlags, Nil, Nil)
Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/AccessProxies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ abstract class AccessProxies {
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions
accessRef.becomes(forwardedArgss.head.head)
else
accessRef.appliedToTypeTrees(forwardedTpts).appliedToArgss(forwardedArgss)
accessRef
.appliedToTypeTrees(forwardedTpts)
.appliedToArgss(forwardedArgss)
.etaExpandCFT(using ctx.withOwner(accessor))
rhs.withSpan(accessed.span)
})

Expand Down
53 changes: 49 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/Bridges.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import ast.untpd
import collection.{mutable, immutable}
import util.Spans.Span
import util.SrcPos
import ContextFunctionResults.{contextResultCount, contextFunctionResultTypeAfter}
import StdNames.nme
import Constants.Constant
import TypeErasure.transformInfo
import Erasure.Boxing.adaptClosure

/** A helper class for generating bridge methods in class `root`. */
class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
Expand Down Expand Up @@ -112,12 +117,52 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
toBeRemoved += other
}

def bridgeRhs(argss: List[List[Tree]]) = {
val memberCount = contextResultCount(member)

/** Eta expand application `ref(args)` as needed.
* To do this correctly, we have to look at the member's original pre-erasure
* type and figure out which context function types in its result are
* not yet instantiated.
*/
def etaExpand(ref: Tree, args: List[Tree])(using Context): Tree =
def expand(args: List[Tree], tp: Type, n: Int)(using Context): Tree =
if n <= 0 then
assert(ctx.typer.isInstanceOf[Erasure.Typer])
ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType)
else
val defn.ContextFunctionType(argTypes, resType, isErased) = tp: @unchecked
val anonFun = newAnonFun(ctx.owner,
MethodType(if isErased then Nil else argTypes, resType),
coord = ctx.owner.coord)
anonFun.info = transformInfo(anonFun, anonFun.info)

def lambdaBody(refss: List[List[Tree]]) =
val refs :: Nil = refss: @unchecked
val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
argTypes.indices.toList.map(n =>
bunchedParam
.select(nme.primitive.arrayApply)
.appliedTo(Literal(Constant(n))))
case refs1 => refs1
expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun))

val unadapted = Closure(anonFun, lambdaBody)
cpy.Block(unadapted)(unadapted.stats,
adaptClosure(unadapted.expr.asInstanceOf[Closure]))
end expand

val otherCount = contextResultCount(other)
val start = contextFunctionResultTypeAfter(member, otherCount)(using preErasureCtx)
expand(args, start, memberCount - otherCount)(using ctx.withOwner(bridge))
end etaExpand

def bridgeRhs(argss: List[List[Tree]]) =
assert(argss.tail.isEmpty)
val ref = This(root).select(member)
if (member.info.isParameterless) ref // can happen if `member` is a module
else Erasure.partialApply(ref, argss.head)
}
if member.info.isParameterless then ref // can happen if `member` is a module
else if memberCount == 0 then ref.appliedToTermArgs(argss.head)
else etaExpand(ref, argss.head)

bridges += DefDef(bridge, bridgeRhs(_).withSpan(bridge.span))
}
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/ByNameClosures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ class ByNameClosures extends TransformByNameApply with IdentityDenotTransformer
// ExpanSAMs applied to partial functions creates methods that need
// to be fully defined before converting. Test case is pos/i9391.scala.

override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree = {
val meth = newSymbol(
ctx.owner, nme.ANON_FUN, Synthetic | Method, MethodType(Nil, Nil, argType))
override def mkByNameClosure(arg: Tree, argType: Type)(using Context): Tree =
val meth = newAnonFun(ctx.owner, MethodType(Nil, argType))
Closure(meth, _ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase)).withSpan(arg.span)
}
}

object ByNameClosures {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,33 +99,13 @@ object ContextFunctionResults:
normalParamCount(sym.info)
end totalParamCount

/** The rightmost context function type in the result type of `meth`
* that represents `paramCount` curried, non-erased parameters that
* are included in the `contextResultCount` of `meth`.
* Example:
*
* Say we have `def m(x: A): B ?=> (C1, C2, C3) ?=> D ?=> E ?=> F`,
* paramCount == 4, and the contextResultCount of `m` is 3.
* Then we return the type `(C1, C2, C3) ?=> D ?=> E ?=> F`, since this
* type covers the 4 rightmost parameters C1, C2, C3 and D before the
* contextResultCount runs out at E ?=> F.
* Erased parameters are ignored; they contribute nothing to the
* parameter count.
*/
def contextFunctionResultTypeCovering(meth: Symbol, paramCount: Int)(using Context) =
atPhase(erasurePhase) {
// Recursive instances return pairs of context types and the
// # of parameters they represent.
def missingCR(tp: Type, crCount: Int): (Type, Int) =
if crCount == 0 then (tp, 0)
else
val defn.ContextFunctionType(formals, resTpe, isErased) = tp: @unchecked
val result @ (rt, nparams) = missingCR(resTpe, crCount - 1)
assert(nparams <= paramCount)
if nparams == paramCount || isErased then result
else (tp, nparams + formals.length)
missingCR(meth.info.finalResultType, contextResultCount(meth))._1
}
/** The `depth` levels nested context function type in the result type of `meth` */
def contextFunctionResultTypeAfter(meth: Symbol, depth: Int)(using Context) =
def recur(tp: Type, n: Int): Type =
if n == 0 then tp
else tp match
case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1)
recur(meth.info.finalResultType, depth)

/** Should selection `tree` be eliminated since it refers to an `apply`
* node of a context function type whose parameters will end up being
Expand Down
60 changes: 3 additions & 57 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Erasure extends Phase with DenotTransformer {
case _ => false
}
}

def erasedName =
if ref.is(Flags.Method)
&& contextResultsAreErased(ref.symbol)
Expand Down Expand Up @@ -383,8 +384,8 @@ object Erasure {
case _: FunProto | AnyFunctionProto => tree
case _ => tree.tpe.widen match
case mt: MethodType if tree.isTerm =>
if mt.paramInfos.isEmpty then adaptToType(tree.appliedToNone, pt)
else etaExpand(tree, mt, pt)
assert(mt.paramInfos.isEmpty)
adaptToType(tree.appliedToNone, pt)
case tpw =>
if (pt.isInstanceOf[ProtoType] || tree.tpe <:< pt)
tree
Expand Down Expand Up @@ -533,61 +534,6 @@ object Erasure {
else
tree
end adaptClosure

/** Eta expand given `tree` that has the given method type `mt`, so that
* it conforms to erased result type `pt`.
* To do this correctly, we have to look at the tree's original pre-erasure
* type and figure out which context function types in its result are
* not yet instantiated.
*/
def etaExpand(tree: Tree, mt: MethodType, pt: Type)(using Context): Tree =
report.log(i"eta expanding $tree")
val defs = new mutable.ListBuffer[Tree]
val tree1 = LiftErased.liftApp(defs, tree)
val xmt = if tree.isInstanceOf[Apply] then mt else expandedMethodType(mt, tree)
val targetLength = xmt.paramInfos.length
val origOwner = ctx.owner

// The original type from which closures should be constructed
val origType = contextFunctionResultTypeCovering(tree.symbol, targetLength)

def abstracted(args: List[Tree], tp: Type, pt: Type)(using Context): Tree =
if args.length < targetLength then
try
val defn.ContextFunctionType(argTpes, resTpe, isErased) = tp: @unchecked
if isErased then abstracted(args, resTpe, pt)
else
val anonFun = newSymbol(
ctx.owner, nme.ANON_FUN, Flags.Synthetic | Flags.Method,
MethodType(argTpes, resTpe), coord = tree.span.endPos)
anonFun.info = transformInfo(anonFun, anonFun.info)
def lambdaBody(refss: List[List[Tree]]) =
val refs :: Nil = refss: @unchecked
val expandedRefs = refs.map(_.withSpan(tree.span.endPos)) match
case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil =>
argTpes.indices.toList.map(n =>
bunchedParam
.select(nme.primitive.arrayApply)
.appliedTo(Literal(Constant(n))))
case refs1 => refs1
abstracted(args ::: expandedRefs, resTpe, anonFun.info.finalResultType)(
using ctx.withOwner(anonFun))

val unadapted = Closure(anonFun, lambdaBody)
cpy.Block(unadapted)(unadapted.stats, adaptClosure(unadapted.expr.asInstanceOf[Closure]))
catch case ex: MatchError =>
println(i"error while abstracting tree = $tree | mt = $mt | args = $args%, % | tp = $tp | pt = $pt")
throw ex
else
assert(args.length == targetLength, i"wrong # args tree = $tree | args = $args%, % | mt = $mt | tree type = ${tree.tpe}")
val app = untpd.cpy.Apply(tree1)(tree1, args)
assert(ctx.typer.isInstanceOf[Erasure.Typer])
ctx.typer.typed(app, pt)
.changeOwnerAfter(origOwner, ctx.owner, erasurePhase.asInstanceOf[Erasure])

seq(defs.toList, abstracted(Nil, origType, pt))
end etaExpand

end Boxing

class Typer(erasurePhase: DenotTransformer) extends typer.ReTyper with NoChecking {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
ref(defn.NoneModule))
}
val tpe = MethodType(List(nme.s))(_ => List(tp1), mth => defn.OptionClass.typeRef.appliedTo(mth.newParamRef(0) & tp2))
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe, coord = span)
val meth = newAnonFun(ctx.owner, tpe, coord = span)
val typeTestType = defn.TypeTestClass.typeRef.appliedTo(List(tp1, tp2))
Closure(meth, tss => body(tss.head).changeOwner(ctx.owner, meth), targetType = typeTestType).withSpan(span)
case _ =>
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ object QuoteMatcher {
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, methTpe)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
case t => t
}
val closureTpe = Types.MethodType(mtpe.paramNames, mtpe.paramInfos, closureResType)
val closureMethod = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, closureTpe)
val closureMethod = dotc.core.Symbols.newAnonFun(owner, closureTpe)
tpd.Closure(closureMethod, tss => new tpd.TreeOps(self).appliedToTermArgs(tss.head).etaExpand(closureMethod))
case _ => self
}
Expand Down Expand Up @@ -793,7 +793,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler

object Lambda extends LambdaModule:
def apply(owner: Symbol, tpe: MethodType, rhsFn: (Symbol, List[Tree]) => Tree): Block =
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
val meth = dotc.core.Symbols.newAnonFun(owner, tpe)
withDefaultPos(tpd.Closure(meth, tss => xCheckMacroedOwners(xCheckMacroValidExpr(rhsFn(meth, tss.head.map(withDefaultPos))), meth)))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
Expand Down
53 changes: 53 additions & 0 deletions tests/run/i13691.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import language.experimental.erasedDefinitions

erased class CanThrow[-E <: Exception]
erased class Foo
class Bar

object unsafeExceptions:
given canThrowAny: CanThrow[Exception] = null

object test1:
trait Decoder[+T]:
def apply(): T

def deco: Decoder[CanThrow[Exception] ?=> Int] = new Decoder[CanThrow[Exception] ?=> Int]:
def apply(): CanThrow[Exception] ?=> Int = 1

object test2:
trait Decoder[+T]:
def apply(): T

def deco: Decoder[(CanThrow[Exception], Foo) ?=> Int] = new Decoder[(CanThrow[Exception], Foo) ?=> Int]:
def apply(): (CanThrow[Exception], Foo) ?=> Int = 1

object test3:
trait Decoder[+T]:
def apply(): T

def deco: Decoder[CanThrow[Exception] ?=> Foo ?=> Int] = new Decoder[CanThrow[Exception] ?=> Foo ?=> Int]:
def apply(): CanThrow[Exception] ?=> Foo ?=> Int = 1

object test4:
trait Decoder[+T]:
def apply(): T

def deco: Decoder[CanThrow[Exception] ?=> Bar ?=> Int] = new Decoder[CanThrow[Exception] ?=> Bar ?=> Int]:
def apply(): CanThrow[Exception] ?=> Bar ?=> Int = 1

object test5:
trait Decoder[+T]:
def apply(): T

def deco: Decoder[Bar ?=> CanThrow[Exception] ?=> Int] = new Decoder[Bar ?=> CanThrow[Exception] ?=> Int]:
def apply(): Bar ?=> CanThrow[Exception] ?=> Int = 1

@main def Test(): Unit =
import unsafeExceptions.canThrowAny
given Foo = ???
given Bar = Bar()
test1.deco.apply().apply
test2.deco.apply().apply
test3.deco.apply().apply
test4.deco.apply().apply
test5.deco.apply().apply
11 changes: 11 additions & 0 deletions tests/run/i13961a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import language.experimental.saferExceptions

trait Decoder[+T]:
def apply(): T

given Decoder[Int throws Exception] = new Decoder[Int throws Exception]:
def apply(): Int throws Exception = 1

@main def Test(): Unit =
import unsafeExceptions.canThrowAny
summon[Decoder[Int throws Exception]]()

0 comments on commit f50b55c

Please sign in to comment.