Skip to content

Commit

Permalink
Fix ParamSpec inference against TypeVarTuple (#17431)
Browse files Browse the repository at this point in the history
Fixes #17278
Fixes #17127
  • Loading branch information
ilevkivskyi authored Jun 24, 2024
1 parent 620e281 commit 6c1d867
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 13 deletions.
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,11 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# (with literal '...').
if not template.is_ellipsis_args:
unpack_present = find_unpack_in_list(template.arg_types)
if unpack_present is not None:
# When both ParamSpec and TypeVarTuple are present, things become messy
# quickly. For now, we only allow ParamSpec to "capture" TypeVarTuple,
# but not vice versa.
# TODO: infer more from prefixes when possible.
if unpack_present is not None and not cactual.param_spec():
# We need to re-normalize args to the form they appear in tuples,
# for callables we always pack the suffix inside another tuple.
unpack = template.arg_types[unpack_present]
Expand Down
14 changes: 13 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
repl = self.variables.get(t.id, t)
if isinstance(repl, TypeVarTupleType):
return repl
elif isinstance(repl, ProperType) and isinstance(repl, (AnyType, UninhabitedType)):
# Some failed inference scenarios will try to set all type variables to Never.
# Instead of being picky and require all the callers to wrap them,
# do this here instead.
# Note: most cases when this happens are handled in expand unpack below, but
# in rare cases (e.g. ParamSpec containing Unpack star args) it may be skipped.
return t.tuple_fallback.copy_modified(args=[repl])
raise NotImplementedError

def visit_unpack_type(self, t: UnpackType) -> Type:
Expand Down Expand Up @@ -348,7 +355,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
# the replacement is ignored.
if isinstance(repl, Parameters):
# We need to expand both the types in the prefix and the ParamSpec itself
return t.copy_modified(
expanded = t.copy_modified(
arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
arg_names=t.arg_names[:-2] + repl.arg_names,
Expand All @@ -358,6 +365,11 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
variables=[*repl.variables, *t.variables],
)
var_arg = expanded.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
# Sometimes we get new unpacks after expanding ParamSpec.
expanded.normalize_trivial_unpack()
return expanded
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
# changes, e.g. substitute Concatenate[int, P] in place of Q.
Expand Down
12 changes: 2 additions & 10 deletions mypy/semanal_typeargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE
from mypy.messages import format_type
from mypy.mixedtraverser import MixedTraverserVisitor
from mypy.nodes import ARG_STAR, Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
from mypy.nodes import Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile
from mypy.options import Options
from mypy.scope import Scope
from mypy.subtypes import is_same_type, is_subtype
Expand Down Expand Up @@ -104,15 +104,7 @@ def visit_tuple_type(self, t: TupleType) -> None:

def visit_callable_type(self, t: CallableType) -> None:
super().visit_callable_type(t)
# Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X
if t.is_var_arg:
star_index = t.arg_kinds.index(ARG_STAR)
star_type = t.arg_types[star_index]
if isinstance(star_type, UnpackType):
p_type = get_proper_type(star_type.type)
if isinstance(p_type, Instance):
assert p_type.type.fullname == "builtins.tuple"
t.arg_types[star_index] = p_type.args[0]
t.normalize_trivial_unpack()

def visit_instance(self, t: Instance) -> None:
super().visit_instance(t)
Expand Down
13 changes: 12 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,17 @@ def param_spec(self) -> ParamSpecType | None:
prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2])
return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix)

def normalize_trivial_unpack(self) -> None:
# Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X in place.
if self.is_var_arg:
star_index = self.arg_kinds.index(ARG_STAR)
star_type = self.arg_types[star_index]
if isinstance(star_type, UnpackType):
p_type = get_proper_type(star_type.type)
if isinstance(p_type, Instance):
assert p_type.type.fullname == "builtins.tuple"
self.arg_types[star_index] = p_type.args[0]

def with_unpacked_kwargs(self) -> NormalizedCallableType:
if not self.unpack_kwargs:
return cast(NormalizedCallableType, self)
Expand Down Expand Up @@ -2113,7 +2124,7 @@ def with_normalized_var_args(self) -> Self:
if not isinstance(unpacked, TupleType):
# Note that we don't normalize *args: *tuple[X, ...] -> *args: X,
# this should be done once in semanal_typeargs.py for user-defined types,
# and we ourselves should never construct such type.
# and we ourselves rarely construct such type.
return self
unpack_index = find_unpack_in_list(unpacked.items)
if unpack_index == 0 and len(unpacked.items) > 1:
Expand Down
53 changes: 53 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -2407,3 +2407,56 @@ reveal_type(x) # N: Revealed type is "__main__.C[builtins.str, builtins.int]"
reveal_type(C(f)) # N: Revealed type is "__main__.C[builtins.str, builtins.int, builtins.int, builtins.int, builtins.int]"
C[()] # E: At least 1 type argument(s) expected, none given
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleAgainstParamSpecActualSuccess]
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
from typing_extensions import ParamSpec

R = TypeVar("R")
P = ParamSpec("P")

class CM(Generic[R]): ...
def cm(fn: Callable[P, R]) -> Callable[P, CM[R]]: ...

Ts = TypeVarTuple("Ts")
@cm
def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ...

reveal_type(test) # N: Revealed type is "def [Ts] (*args: Unpack[Ts`-1]) -> __main__.CM[Tuple[Unpack[Ts`-1]]]"
reveal_type(test(1, 2, 3)) # N: Revealed type is "__main__.CM[Tuple[Literal[1]?, Literal[2]?, Literal[3]?]]"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleAgainstParamSpecActualFailedNoCrash]
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
from typing_extensions import ParamSpec

R = TypeVar("R")
P = ParamSpec("P")

class CM(Generic[R]): ...
def cm(fn: Callable[P, List[R]]) -> Callable[P, CM[R]]: ...

Ts = TypeVarTuple("Ts")
@cm # E: Argument 1 to "cm" has incompatible type "Callable[[VarArg(Unpack[Ts])], Tuple[Unpack[Ts]]]"; expected "Callable[[VarArg(Never)], List[Never]]"
def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ...

reveal_type(test) # N: Revealed type is "def (*args: Never) -> __main__.CM[Never]"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleAgainstParamSpecActualPrefix]
from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List
from typing_extensions import ParamSpec, Concatenate

R = TypeVar("R")
P = ParamSpec("P")
T = TypeVar("T")

class CM(Generic[R]): ...
def cm(fn: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[List[T], P], CM[R]]: ...

Ts = TypeVarTuple("Ts")
@cm
def test(x: T, *args: Unpack[Ts]) -> Tuple[T, Unpack[Ts]]: ...

reveal_type(test) # N: Revealed type is "def [T, Ts] (builtins.list[T`2], *args: Unpack[Ts`-2]) -> __main__.CM[Tuple[T`2, Unpack[Ts`-2]]]"
[builtins fixtures/tuple.pyi]

0 comments on commit 6c1d867

Please sign in to comment.