Skip to content

Commit

Permalink
feat: Allow generic nat args in statically sized ranges (#706)
Browse files Browse the repository at this point in the history
Closes #663
  • Loading branch information
mark-koch authored Dec 12, 2024
1 parent 6ef6d60 commit f441bb8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
22 changes: 17 additions & 5 deletions guppylang/std/_internal/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from guppylang.nodes import (
DesugaredArrayComp,
DesugaredGeneratorExpr,
GenericParamValue,
GlobalCall,
MakeIter,
ResultExpr,
Expand Down Expand Up @@ -360,12 +361,23 @@ class RangeChecker(CustomCallChecker):
def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]:
check_num_args(1, len(args), self.node)
[stop] = args
stop, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
range_iter, range_ty = self.make_range(stop)
if isinstance(stop, ast.Constant):
return to_sized_iter(range_iter, range_ty, stop.value, self.ctx)
stop_checked, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument")
range_iter, range_ty = self.make_range(stop_checked)
# Check if `stop` is a statically known value. Note that we need to do this on
# the original `stop` instead of `stop_checked` to avoid any previously inserted
# `int` coercions.
if (static_stop := self.check_static(stop)) is not None:
return to_sized_iter(range_iter, range_ty, static_stop, self.ctx)
return range_iter, range_ty

def check_static(self, stop: ast.expr) -> "int | Const | None":
stop, _ = ExprSynthesizer(self.ctx).synthesize(stop, allow_free_vars=True)
if isinstance(stop, ast.Constant) and isinstance(stop.value, int):
return stop.value
if isinstance(stop, GenericParamValue) and stop.param.ty == nat_type():
return stop.param.to_bound().const
return None

def range_ty(self) -> StructType:
from guppylang.std.builtins import Range

Expand All @@ -382,7 +394,7 @@ def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]:


def to_sized_iter(
iterator: ast.expr, range_ty: Type, size: int, ctx: Context
iterator: ast.expr, range_ty: Type, size: "int | Const", ctx: Context
) -> tuple[ast.expr, Type]:
"""Adds a static size annotation to an iterator."""
sized_iter_ty = sized_iter_type(range_ty, size)
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,15 @@ def negative() -> SizedIter[Range, 10]:
return range(10)

validate(module.compile())


def test_static_generic_size(validate):
module = GuppyModule("test")
n = guppy.nat_var("n", module=module)

@guppy(module)
def negative() -> SizedIter[Range, n]:
return range(n)

validate(module.compile())

0 comments on commit f441bb8

Please sign in to comment.