Skip to content

Commit

Permalink
[SignatureTyping] Fix signature typing
Browse files Browse the repository at this point in the history
Header term-solving should be delayed until the body is elaborated.

Also fix broken imports.
  • Loading branch information
tgeng committed Nov 11, 2024
1 parent e401c24 commit 4c874d0
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 91 deletions.
99 changes: 51 additions & 48 deletions src/main/scala/com/github/tgeng/archon/core/ir/elaboration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ import com.github.tgeng.archon.core.ir.CTerm.*
import com.github.tgeng.archon.core.ir.CaseTree.*
import com.github.tgeng.archon.core.ir.CoPattern.*
import com.github.tgeng.archon.core.ir.Declaration.*
import com.github.tgeng.archon.core.ir.DeclarationPart.*
import com.github.tgeng.archon.core.ir.IrError.*
import com.github.tgeng.archon.core.ir.Pattern.*
import com.github.tgeng.archon.core.ir.PreDeclaration.*
import com.github.tgeng.archon.core.ir.SourceInfo.*
import com.github.tgeng.archon.core.ir.UnificationResult.*
import com.github.tgeng.archon.core.ir.VTerm.*
import com.github.tgeng.archon.core.ir.unifyAll

import scala.collection.immutable.SeqMap
import scala.collection.mutable
Expand Down Expand Up @@ -55,8 +57,6 @@ def elaborate
enum DeclarationPart:
case HEAD, BODY

import com.github.tgeng.archon.core.ir.DeclarationPart.* import com.github.tgeng.archon.core.ir.unifyAll

@throws(classOf[IrError])
def sortPreDeclarations
(declarations: Seq[PreDeclaration])
Expand Down Expand Up @@ -192,17 +192,15 @@ private def elaborateDataHead
case _ => throw NotDataTypeType(ty)

val (tIndices, level) = elaborateTy(preData.ty)
// level and eqDecidability should not depend on index arguments
// level should not depend on index arguments
val strengthenedLevel =
try level.strengthen(tIndices.size, 0)
catch case e: StrengthenException => throw DataLevelCannotDependOnIndexArgument(preData)
val data = checkData(
Data(
preData.qn,
Γ.zip(Iterator.continually(Variance.INVARIANT)) ++ tParamTys,
tIndices,
strengthenedLevel,
),
val data = Data(
preData.qn,
Γ.zip(Iterator.continually(Variance.INVARIANT)) ++ tParamTys,
tIndices,
strengthenedLevel,
)
Σ.addDeclaration(data)

Expand All @@ -216,12 +214,12 @@ private def elaborateDataBody

@throws(classOf[IrError])
def elaborateTy
(ty: CTerm)
(ty: CTerm, level: VTerm)
(using Γ: Context)
(using Signature)
(using TypingContext)
: (Telescope, /* constructor tArgs */ List[VTerm]) =
checkIsCType(ty).normalized(None) match
checkIsCType(ty, Some(level)).normalized(None) match
// Here and below we do not care the declared effect types because data type constructors
// are always total. Declaring non-total signature is not necessary (nor desirable) but
// acceptable.
Expand All @@ -234,23 +232,25 @@ private def elaborateDataBody
(Nil, args.drop(data.context.size))
case F(t, _, _) => throw ExpectDataType(t, Some(data.qn))
case FunctionType(binding, bodyTy, _, _) =>
val (telescope, level) = elaborateTy(bodyTy)(using Γ :+ binding)
(binding :: telescope, level)
val (telescope, args) = elaborateTy(bodyTy, level.weakened)(using Γ :+ binding)
(binding :: telescope, args)
case _ => throw NotDataTypeType(ty)

// number of index arguments
given Context = data.context.map(_._1)

ctx.trace(s"elaborating data body ${preData.qn}"):
preData.constructors.foldLeft[Signature](Σ) { case (_Σ, constructor) =>
given Signature =
ctx.trace(s"elaborating constructor ${constructor.name}"):
val ty = constructor.ty
val (paramTys, tArgs) = elaborateTy(ty)
val con =
checkDataConstructor(preData.qn, Constructor(constructor.name, paramTys, tArgs))
_Σ.addConstructor(preData.qn, con)
}
preData.constructors
.foldLeft[Signature](Σ) { case (_Σ, constructor) =>
given Signature =
ctx.trace(s"elaborating constructor ${constructor.name}"):
val ty = constructor.ty
val (paramTys, tArgs) = elaborateTy(ty, data.level)
val con =
checkDataConstructor(preData.qn, Constructor(constructor.name, paramTys, tArgs))
_Σ.addConstructor(preData.qn, con)
}
.replaceDeclaration(checkData(data))

@throws(classOf[IrError])
private def elaborateRecordHead
Expand Down Expand Up @@ -280,7 +280,7 @@ private def elaborateRecordHead
),
)
case t => throw ExpectCType(t)
Σ.addDeclaration(checkRecord(r))
Σ.addDeclaration(r)

@throws(classOf[IrError])
private def elaborateRecordBody
Expand All @@ -297,13 +297,16 @@ private def elaborateRecordBody
record.selfBinding.name,
)

val level = record.level.weakened // weakened for self
ctx.trace(s"elaborating record body ${preRecord.qn}"):
preRecord.fields.foldLeft[Signature](Σ) { case (_Σ, field) =>
ctx.trace(s"elaborating field ${field.name}"):
val ty = checkIsCType(field.ty).normalized(None)
val f = checkRecordField(preRecord.qn, Field(field.name, ty))
_Σ.addField(preRecord.qn, f)
}
preRecord.fields
.foldLeft[Signature](Σ) { case (_Σ, field) =>
ctx.trace(s"elaborating field ${field.name}"):
val ty = checkIsCType(field.ty, Some(level)).normalized(None)
val f = checkRecordField(preRecord.qn, Field(field.name, ty))
_Σ.addField(preRecord.qn, f)
}
.replaceDeclaration(checkRecord(record))

@throws(classOf[IrError])
private def elaborateDefHead
Expand All @@ -323,7 +326,7 @@ private def elaborateDefHead
given Context = newEContext.map(_._1)
val ty = checkIsCType(definition.ty).normalized(None)
val d: Definition = Definition(definition.qn, newEContext, ty)
Σ.addDeclaration(checkDef(d))
Σ.addDeclaration(d)

@throws(classOf[IrError])
private def elaborateDefBody
Expand Down Expand Up @@ -690,7 +693,6 @@ private def elaborateDefBody
val ρ2t = ρ2.toTermSubstitutor
given ΓSplit: Context =
1 ++ Δ.subst(ρ.toTermSubstitutor) ++2.subst(ρ1t)
ctx.debug(ΓSplit)

val newProblem = subst(problem, ρ2t)
if newProblem.isEmpty then
Expand Down Expand Up @@ -785,12 +787,12 @@ private def elaborateDefBody
},
)
_Σ.addCaseTree(preDefinition.qn, _Q)
// We didn't solve the meta-variables inside the definition type until now because we want to
// have the call-site context ready before solving them. Delay solving during body elaboration
// works as long as lambda definitions are APPENDED after the call-site declaration. This is
// because our topological sort preserves original order of definitions when possible. Hence,
// the call-site function body would be elaborated before the lambda body.
_Σ.replaceDeclaration(definition.copy(ty = ctx.solveTerm(definition.ty)))
// We didn't solve the meta-variables inside the definition type until now because we want to
// have the call-site context ready before solving them. Delay solving during body elaboration
// works as long as lambda definitions are APPENDED after the call-site declaration. This is
// because our topological sort preserves original order of definitions when possible. Hence,
// the call-site function body would be elaborated before the lambda body.
.replaceDeclaration(checkDef(definition))

@throws(classOf[IrError])
private def elaborateEffectHead
Expand All @@ -808,7 +810,7 @@ private def elaborateEffectHead
case Return(continuationUsage, _) => continuationUsage
case c => throw ExpectReturnAValue(c)
val e: Effect = Effect(effect.qn, Γ2, continuationUsage)
Σ.addDeclaration(checkEffect(e))
Σ.addDeclaration(e)

@throws(classOf[IrError])
private def elaborateEffectBody
Expand All @@ -831,20 +833,21 @@ private def elaborateEffectBody
// Here and below we do not care the declared effect types because data type constructors
// are always total. Declaring non-total signature is not necessary (nor desirable) but
// acceptable.
case F(ty, _, usage) =>
(Nil, ty, usage)
case F(ty, _, usage) => (Nil, ty, usage.normalized)
case FunctionType(binding, bodyTy, _, _) =>
val (telescope, level, usage) = elaborateTy(bodyTy)(using Γ :+ binding)
(binding :: telescope, level, usage)
case _ => throw ExpectFType(ty)

preEffect.operations.foldLeft[Signature](Σ) { case (_Σ, operation) =>
given Signature =
ctx.trace(s"elaborating operation ${operation.name}"):
val (paramTys, resultTy, usage) = elaborateTy(operation.ty)
val o = Operation(operation.name, paramTys, resultTy, usage)
_Σ.addOperation(effect.qn, checkOperation(effect.qn, o))
}
preEffect.operations
.foldLeft[Signature](Σ) { case (_Σ, operation) =>
given Signature =
ctx.trace(s"elaborating operation ${operation.name}"):
val (paramTys, resultTy, usage) = elaborateTy(operation.ty)
val o = Operation(operation.name, paramTys, resultTy, usage)
_Σ.addOperation(effect.qn, checkOperation(effect.qn, o))
}
.replaceDeclaration(checkEffect(effect))

@throws(classOf[IrError])
private def elaborateTTelescope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ enum Declaration:
(
qn: QualifiedName,
context: TContext,
/* binding: + context + tIndexTys */
/* binding += context */
level: VTerm,
selfBinding: Binding[VTerm],
)
Expand Down Expand Up @@ -56,11 +56,11 @@ case class Constructor
tArgs: Arguments = Nil, /* binding += context + paramTys */
)

case class Field(name: Name, /* + tParamTys + 1 for self */ ty: CTerm)
case class Field(name: Name, /* binding += context + 1 for self */ ty: CTerm)

case class Clause
(
// contains def.context
// contains def.context + elaborated variables from lhs co-patterns
context: Context,
lhs: List[CoPattern], /* bindings += clause.context */
rhs: CTerm, /* bindings += clause.context */
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.github.tgeng.archon.core.ir

import com.github.tgeng.archon.common.*
import com.github.tgeng.archon.common.IndentPolicy.*
import com.github.tgeng.archon.core.common.*
import com.github.tgeng.archon.core.ir
import com.github.tgeng.archon.core.ir.CTerm.*
Expand All @@ -18,13 +17,12 @@ def checkData(data: Data)(using Σ: Signature)(using ctx: TypingContext): Data =
given Context = IndexedSeq()
ctx.trace(s"checking data signature ${data.qn}"):
val tParamsTysTelescope =
ctx.solveTerm(checkParameterTyDeclarations(data.context.map(_._1).toTelescope))
ctx.solveTerm(data.context.map(_._1).toTelescope)
{
given tParamTys: Context = Context.fromTelescope(tParamsTysTelescope)
val tIndexTys =
ctx.solveTerm(checkParameterTyDeclarations(data.tIndexTys))
val tIndexTys = ctx.solveTerm(data.tIndexTys)
checkTParamsAreUnrestricted((tParamTys ++ tIndexTys).toTelescope)
val level = ctx.solveTerm(checkLevel(data.level))
val level = ctx.solveTerm(data.level)
Data(data.qn, tParamTys.zip(data.context.map(_._2)), tIndexTys, level)
}

Expand All @@ -39,7 +37,7 @@ def checkDataConstructor
case Some(data) =>
given Γ: Context = data.context.map(_._1)
ctx.trace(s"checking data constructor $qn.${con.name}"):
val paramTys = ctx.solveTerm(checkParameterTyDeclarations(con.paramTys, Some(data.level)))
val paramTys = ctx.solveTerm(con.paramTys)
val tArgsContext = Γ ++ paramTys
val tArgs =
checkTypes(con.tArgs, data.tIndexTys.weaken(con.paramTys.size, 0))(using tArgsContext)
Expand All @@ -54,7 +52,7 @@ def checkRecord(record: Record)(using Σ: Signature)(using ctx: TypingContext):
given Context = IndexedSeq()
ctx.trace(s"checking record signature ${record.qn}"):
val tParams = record.context.map(_._1)
val tParamTysTelescope = ctx.solveTerm(checkParameterTyDeclarations(tParams.toList))
val tParamTysTelescope = ctx.solveTerm(tParams.toList)
{
given tParamTys: Context = Context.fromTelescope(tParamTysTelescope)
checkTParamsAreUnrestricted(tParamTysTelescope)
Expand All @@ -81,7 +79,7 @@ def checkRecordField
given Context = record.context.map(_._1).toIndexedSeq :+ record.selfBinding

ctx.trace(s"checking record field $qn.${field.name}"):
val ty = ctx.solveTerm(checkIsCType(field.ty, Some(record.level.weakened)))
val ty = ctx.solveTerm(field.ty)
val violatingVars =
// 1 is to offset self binding.
VarianceChecker.visitCTerm(field.ty)(using record.context, Variance.COVARIANT, 1)
Expand Down Expand Up @@ -210,24 +208,15 @@ private object VarianceChecker extends Visitor[(TContext, Variance, Nat), Seq[Va

@throws(classOf[IrError])
def checkDef(definition: Definition)(using Signature)(using ctx: TypingContext): Definition =
// TODO[P0]: do escape analysis here. Make sure to automatically derive it for null escape status
given Context = definition.context.map(_._1)
ctx.trace(s"checking def signature ${definition.qn}"):
val ty = checkIsCType(definition.ty)
Definition(definition.qn, definition.context, ty)
definition.copy(ty = ctx.solveTerm(definition.ty)(using definition.context.map(_._1)))

@throws(classOf[IrError])
def checkEffect(effect: Effect)(using Signature)(using ctx: TypingContext): Effect =
given Context = Context.empty
ctx.trace(s"checking effect signature ${effect.qn}"):
val telescope = ctx.solveTerm(checkParameterTyDeclarations(effect.context.toTelescope))
checkTParamsAreUnrestricted(telescope)

{
given Γ: Context = Context.fromTelescope(telescope)
val continuationUsage = ctx.solveTerm(checkType(effect.continuationUsage, UsageType()))
Effect(effect.qn, Γ, continuationUsage)
}
given Γ: Context = Context.fromTelescope(effect.context.toTelescope)
val continuationUsage = ctx.solveTerm(effect.continuationUsage)
Effect(effect.qn, Γ, continuationUsage)

@throws(classOf[IrError])
def checkOperation
Expand All @@ -241,12 +230,11 @@ def checkOperation
given Γ: Context = effect.context

ctx.trace(s"checking effect operation $qn.${operation.name}"):
val paramTys = ctx.solveTerm(checkParameterTyDeclarations(operation.paramTys))
val paramTys = ctx.solveTerm(operation.paramTys)
{
given Context = Γ ++ paramTys
val resultTy = ctx.solveTerm(checkIsType(operation.resultTy))
val resultUsage =
ctx.solveTerm(checkType(operation.resultUsage, UsageType(None)))
val resultTy = ctx.solveTerm(operation.resultTy)
val resultUsage = ctx.solveTerm(operation.resultUsage)
Operation(
operation.name,
paramTys,
Expand All @@ -270,16 +258,3 @@ private def checkTParamsAreUnrestricted
ExpectUnrestrictedTypeParameterBinding(binding),
)
checkTParamsAreUnrestricted(rest)(using Γ :+ binding)

@throws(classOf[IrError])
private def checkParameterTyDeclarations
(tParamTys: Telescope, levelBound: Option[VTerm] = None)
(using Γ: Context)
(using Σ: Signature)
(using TypingContext)
: Telescope = tParamTys match
case Nil => Nil
case binding :: rest =>
val ty = checkIsType(binding.ty, levelBound)
val usage = checkType(binding.usage, UsageType(None))
Binding(ty, usage)(binding.name) :: checkParameterTyDeclarations(rest)(using Γ :+ binding)
2 changes: 1 addition & 1 deletion src/main/scala/com/github/tgeng/archon/core/ir/term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ enum VTerm(val sourceInfo: SourceInfo) extends SourceInfoOwner[VTerm]:
case Con(name: Name, args: Arguments = Nil)(using sourceInfo: SourceInfo)
extends VTerm(sourceInfo)

// Note, `upper` here is in the sense of typing subsumption, not the usage lattice. This is the
// Note, `upper` here is in the sense of typing subsumption, not the usage lattice. This is, the
// lower bound in the usage lattice. Hence Option.None is used to represent unbounded case, as the
// lattice is not bounded below. Note that the semantic of this `upperBound` is different from
// `continuationUsage`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ class TypingContext
if enableDebugging then
val indent = "" * traceLevel + " "
println(
indent + ANSI_CYAN + stringify(t) + " = " + verbosePPrinter
indent + ANSI_CYAN + "[DEBUG] " + stringify(t) + " = " + verbosePPrinter
.apply(t)
.toString
.replace("\n", "\n" + indent) + ANSI_RESET,
Expand Down

0 comments on commit 4c874d0

Please sign in to comment.