diff --git a/guppylang/std/_internal/checker.py b/guppylang/std/_internal/checker.py index 4166cded..55d5613f 100644 --- a/guppylang/std/_internal/checker.py +++ b/guppylang/std/_internal/checker.py @@ -29,6 +29,7 @@ from guppylang.nodes import ( DesugaredArrayComp, DesugaredGeneratorExpr, + GenericParamValue, GlobalCall, MakeIter, ResultExpr, @@ -360,12 +361,23 @@ class RangeChecker(CustomCallChecker): def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: check_num_args(1, len(args), self.node) [stop] = args - stop, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument") - range_iter, range_ty = self.make_range(stop) - if isinstance(stop, ast.Constant): - return to_sized_iter(range_iter, range_ty, stop.value, self.ctx) + stop_checked, _ = ExprChecker(self.ctx).check(stop, int_type(), "argument") + range_iter, range_ty = self.make_range(stop_checked) + # Check if `stop` is a statically known value. Note that we need to do this on + # the original `stop` instead of `stop_checked` to avoid any previously inserted + # `int` coercions. + if (static_stop := self.check_static(stop)) is not None: + return to_sized_iter(range_iter, range_ty, static_stop, self.ctx) return range_iter, range_ty + def check_static(self, stop: ast.expr) -> "int | Const | None": + stop, _ = ExprSynthesizer(self.ctx).synthesize(stop, allow_free_vars=True) + if isinstance(stop, ast.Constant) and isinstance(stop.value, int): + return stop.value + if isinstance(stop, GenericParamValue) and stop.param.ty == nat_type(): + return stop.param.to_bound().const + return None + def range_ty(self) -> StructType: from guppylang.std.builtins import Range @@ -382,7 +394,7 @@ def make_range(self, stop: ast.expr) -> tuple[ast.expr, Type]: def to_sized_iter( - iterator: ast.expr, range_ty: Type, size: int, ctx: Context + iterator: ast.expr, range_ty: Type, size: "int | Const", ctx: Context ) -> tuple[ast.expr, Type]: """Adds a static size annotation to an iterator.""" sized_iter_ty = sized_iter_type(range_ty, size) diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 2babd982..092db623 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -43,3 +43,15 @@ def negative() -> SizedIter[Range, 10]: return range(10) validate(module.compile()) + + +def test_static_generic_size(validate): + module = GuppyModule("test") + n = guppy.nat_var("n", module=module) + + @guppy(module) + def negative() -> SizedIter[Range, n]: + return range(n) + + validate(module.compile()) +