Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context Bounds for Polymorphic Functions #21643

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
a76470f
Implement basic version of desugaring context bounds for poly functions
KacperFKorban Sep 24, 2024
8dc68c3
Handle named context bounds in poly function context bound desugaring
KacperFKorban Sep 24, 2024
3ac5cec
Correctly-ish desugar poly function context bounds in function types
KacperFKorban Sep 24, 2024
408aa74
Fix pickling issue
KacperFKorban Sep 24, 2024
134c015
Hide context bounds expansion for poly functions under modularity fea…
KacperFKorban Sep 24, 2024
309034e
Small cleanup
KacperFKorban Sep 25, 2024
64bd03e
Add more test cases
KacperFKorban Sep 25, 2024
5196efd
Change the implementation of context bound expansion for poly functio…
KacperFKorban Oct 2, 2024
5f0d4a7
Add support for some type aliases, when expanding context bounds for …
KacperFKorban Oct 3, 2024
a736592
Make the expandion of context bounds for poly types slightly more ele…
KacperFKorban Oct 4, 2024
0af8397
Add more aliases tests for context bounds with poly functions
KacperFKorban Oct 8, 2024
9c66069
Bring back the restriction for requiring value parameters in poly fun…
KacperFKorban Oct 14, 2024
ec6d7ef
Cleanup dead code
KacperFKorban Oct 18, 2024
dfa9240
Reuse addEvidenceParams logic, but no aliases
KacperFKorban Nov 14, 2024
7755e3b
Cleanup context bounds for poly functions implementation, make the im…
KacperFKorban Nov 14, 2024
24e3fa0
More cleanup of poly context bound desugaring
KacperFKorban Nov 14, 2024
f292ac5
Short circuit adding evidence params to poly functions, when there ar…
KacperFKorban Nov 14, 2024
f9db9fa
Add a run test for poly context bounds; cleanup typer changes
KacperFKorban Nov 15, 2024
952eff7
Cleanup context bounds for poly functions implementation after review
KacperFKorban Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 139 additions & 63 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ object desugar {
*/
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()

/** Marks a poly fcuntion apply method, so that we can handle adding evidence parameters to them in a special way
*/
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()

/** What static check should be applied to a Match? */
enum MatchCheck {
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
Expand Down Expand Up @@ -242,7 +246,7 @@ object desugar {
* def f$default$2[T](x: Int) = x + "m"
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor).asInstanceOf[DefDef])

/** Drop context bounds in given TypeDef, replacing them with evidence ValDefs that
* get added to a buffer.
Expand Down Expand Up @@ -304,10 +308,8 @@ object desugar {
tdef1
end desugarContextBounds

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth
def elimContextBounds(meth: Tree, isPrimaryConstructor: Boolean = false)(using Context): Tree =
val evidenceParamBuf = mutable.ListBuffer[ValDef]()

var seenContextBounds: Int = 0
def freshName(unused: Tree) =
seenContextBounds += 1 // Start at 1 like FreshNameCreator.
Expand All @@ -317,7 +319,7 @@ object desugar {
// parameters of the method since shadowing does not affect
// implicit resolution in Scala 3.

val paramssNoContextBounds =
def paramssNoContextBounds(paramss: List[ParamClause]): List[ParamClause] =
val iflag = paramss.lastOption.flatMap(_.headOption) match
case Some(param) if param.mods.isOneOf(GivenOrImplicit) =>
param.mods.flags & GivenOrImplicit
Expand All @@ -329,15 +331,32 @@ object desugar {
tparam => desugarContextBounds(tparam, evidenceParamBuf, flags, freshName, paramss)
}(identity)

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList)
meth match
case meth @ DefDef(_, paramss, tpt, rhs) =>
val newParamss = paramssNoContextBounds(paramss)
rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = newParamss
),
evidenceParamBuf.toList
)
case meth @ PolyFunction(tparams, fun) =>
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = meth: @unchecked
val Function(vparams: List[untpd.ValDef] @unchecked, rhs) = fun: @unchecked
val newParamss = paramssNoContextBounds(tparams :: vparams :: Nil)
val params = evidenceParamBuf.toList
if params.isEmpty then
meth
else
val boundNames = getBoundNames(params, newParamss)
val recur = fitEvidenceParams(params, nme.apply, boundNames)
val (paramsFst, paramsSnd) = recur(newParamss)
functionsOf((paramsFst ++ paramsSnd).filter(_.nonEmpty), rhs)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -465,6 +484,74 @@ object desugar {
case _ =>
(Nil, tree)

private def referencesName(vdef: ValDef, names: Set[TermName])(using Context): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => names.contains(name)
case _ => false

/** Fit evidence `params` into the `mparamss` parameter lists, making sure
* that all parameters referencing `params` are after them.
* - for methods the final parameter lists are := result._1 ++ result._2
* - for poly functions, each element of the pair contains at most one term
* parameter list
*
* @param params the evidence parameters list that should fit into `mparamss`
* @param methName the name of the method that `mparamss` belongs to
* @param boundNames the names of the evidence parameters
* @param mparamss the original parameter lists of the method
* @return a pair of parameter lists containing all parameter lists in a
* reference-correct order; make sure that `params` is always at the
* intersection of the pair elements; this is relevant, for poly functions
* where `mparamss` is guaranteed to have exectly one term parameter list,
* then each pair element will have at most one term parameter list
*/
private def fitEvidenceParams(
params: List[ValDef],
methName: Name,
boundNames: Set[TermName]
)(mparamss: List[ParamClause])(using Context): (List[ParamClause], List[ParamClause]) = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesName(_, boundNames)) =>
(params :: Nil) -> mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${methName}", Printers.desugar)
else params
((normParams ++ mparams) :: Nil) -> Nil
case mparams :: mparamss1 =>
val (fst, snd) = fitEvidenceParams(params, methName, boundNames)(mparamss1)
(mparams :: fst) -> snd
case Nil =>
Nil -> (params :: Nil)

/** Create a chain of possibly contextual functions from the parameter lists */
private def functionsOf(paramss: List[ParamClause], rhs: Tree)(using Context): Tree = paramss match
case Nil => rhs
case ValDefs(head @ (fst :: _)) :: rest if fst.mods.isOneOf(GivenOrImplicit) =>
val paramTpts = head.map(_.tpt)
val paramNames = head.map(_.name)
val paramsErased = head.map(_.mods.flags.is(Erased))
makeContextualFunction(paramTpts, paramNames, functionsOf(rest, rhs), paramsErased).withSpan(rhs.span)
case ValDefs(head) :: rest =>
Function(head, functionsOf(rest, rhs))
case TypeDefs(head) :: rest =>
PolyFunction(head, functionsOf(rest, rhs))
case _ =>
assert(false, i"unexpected paramss $paramss")
EmptyTree

private def getBoundNames(params: List[ValDef], paramss: List[ParamClause])(using Context): Set[TermName] =
var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
boundNames

/** Add all evidence parameters in `params` as implicit parameters to `meth`.
* The position of the added parameters is determined as follows:
*
Expand All @@ -479,36 +566,23 @@ object desugar {
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
if params.isEmpty then return meth

var boundNames = params.map(_.name).toSet // all evidence parameter + context bound proxy names
for mparams <- meth.paramss; mparam <- mparams do
mparam match
case tparam: TypeDef if tparam.mods.annotations.exists(WitnessNamesAnnot.unapply(_).isDefined) =>
boundNames += tparam.name.toTermName
case _ =>
val boundNames = getBoundNames(params, meth.paramss)

def referencesBoundName(vdef: ValDef): Boolean =
vdef.tpt.existsSubTree:
case Ident(name: TermName) => boundNames.contains(name)
case _ => false
val fitParams = fitEvidenceParams(params, meth.name, boundNames)

def recur(mparamss: List[ParamClause]): List[ParamClause] = mparamss match
case ValDefs(mparams) :: _ if mparams.exists(referencesBoundName) =>
params :: mparamss
case ValDefs(mparams @ (mparam :: _)) :: Nil if mparam.mods.isOneOf(GivenOrImplicit) =>
val normParams =
if params.head.mods.flags.is(Given) != mparam.mods.flags.is(Given) then
params.map: param =>
val normFlags = param.mods.flags &~ GivenOrImplicit | (mparam.mods.flags & (GivenOrImplicit))
param.withMods(param.mods.withFlags(normFlags))
.showing(i"adapted param $result ${result.mods.flags} for ${meth.name}", Printers.desugar)
else params
(normParams ++ mparams) :: Nil
case mparams :: mparamss1 =>
mparams :: recur(mparamss1)
case Nil =>
params :: Nil

cpy.DefDef(meth)(paramss = recur(meth.paramss))
if meth.removeAttachment(PolyFunctionApply).isDefined then
// for PolyFunctions we are limited to a single term param list, so we
// reuse the fitEvidenceParams logic to compute the new parameter lists
// and then we add the other parameter lists as function types to the
// return type
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
if ctx.mode.is(Mode.Type) then
cpy.DefDef(meth)(paramss = paramsFst, tpt = functionsOf(paramsSnd, meth.tpt))
else
cpy.DefDef(meth)(paramss = paramsFst, rhs = functionsOf(paramsSnd, meth.rhs))
else
val (paramsFst, paramsSnd) = fitParams(meth.paramss)
cpy.DefDef(meth)(paramss = paramsFst ++ paramsSnd)
end addEvidenceParams

/** The parameters generated from the contextual bounds of `meth`, as generated by `desugar.defDef` */
Expand Down Expand Up @@ -1224,27 +1298,29 @@ object desugar {
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) = tree: @unchecked
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree).withFlags(Synthetic)
)).withSpan(tree.span)
def makePolyFunctionType(tree: PolyFunction)(using Context): RefinedTypeTree = (tree: @unchecked) match
case PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ untpd.Function(vparamTypes, res)) =>
val paramFlags = fun match
case fun: FunctionWithMods =>
// TODO: make use of this in the desugaring when pureFuns is enabled.
// val isImpure = funFlags.is(Impure)

// Function flags to be propagated to each parameter in the desugared method type.
val givenFlag = fun.mods.flags.toTermFlags & Given
fun.erasedParams.map(isErased => if isErased then givenFlag | Erased else givenFlag)
case _ =>
vparamTypes.map(_ => EmptyFlags)

val vparams = vparamTypes.lazyZip(paramFlags).zipWithIndex.map {
case ((p: ValDef, paramFlags), n) => p.withAddedFlags(paramFlags)
case ((p, paramFlags), n) => makeSyntheticParameter(n + 1, p).withAddedFlags(paramFlags)
}.toList

RefinedTypeTree(ref(defn.PolyFunctionType), List(
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
.withFlags(Synthetic)
.withAttachment(PolyFunctionApply, ())
)).withSpan(tree.span)
end makePolyFunctionType

/** Invent a name for an anonympus given of type or template `impl`. */
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3429,7 +3429,7 @@ object Parsers {
*
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
* TypTypeParam ::= {Annotation}
* (id | ‘_’) [HkTypeParamClause] TypeBounds
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
*
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
Expand Down Expand Up @@ -3460,7 +3460,9 @@ object Parsers {
else ident().toTypeName
val hkparams = typeParamClauseOpt(ParamOwner.Hk)
val bounds =
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name) else typeBounds()
if paramOwner.acceptsCtxBounds then typeAndCtxBounds(name)
else if in.featureEnabled(Feature.modularity) && paramOwner == ParamOwner.Type then typeAndCtxBounds(name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does this need the modularity import? isn't it part of SIP-64 so should be enabled without a feature?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I assumed that, since it was meant to be pushed to the next minor, it should be under a language import.
If it's not, can we still do another RC/backport?
@WojciechMazur

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we should fix that in the followup and backport to RC2.
This PR was explicitly marked as one of the things that should be included in 3.6

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

else typeBounds()
TypeDef(name, lambdaAbstract(hkparams, bounds)).withMods(mods)
}
}
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1920,7 +1920,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)
else typedPolyFunctionValue(desugar.elimContextBounds(tree1).asInstanceOf[untpd.PolyFunction], pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand Down Expand Up @@ -2474,7 +2474,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
val refinements1 = impl.body
val seen = mutable.Set[Symbol]()
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
typr.println(s"adding refinement $refinement")
checkRefinementNonCyclic(refinement, refineCls, seen)
val rsym = refinement.symbol
Expand All @@ -2488,7 +2488,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val member = refineCls.info.member(rsym.name)
if (member.isOverloaded)
report.error(OverloadInRefinement(rsym), refinement.srcPos)
}
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
}

Expand Down
Loading
Loading