Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Concatenate and Generic with ParamSpec substitution #489

Merged
merged 22 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 131 additions & 2 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3705,6 +3705,10 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
self.assertEqual(Y.__parameters__, ())
self.assertEqual(Y.__args__, ((int, str, str), bytes, memoryview))

# Regression test; fixing #126 might cause an error here
with self.assertRaisesRegex(TypeError, "not a generic class"):
Y[int]

def test_protocol_generic_over_typevartuple(self):
Ts = TypeVarTuple("Ts")
T = TypeVar("T")
Expand Down Expand Up @@ -5259,6 +5263,7 @@ class X(Generic[T, P]):
class Y(Protocol[T, P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in X, Y:
with self.subTest(klass=klass.__name__):
G1 = klass[int, P_2]
Expand All @@ -5273,20 +5278,118 @@ class Y(Protocol[T, P]):
self.assertEqual(G3.__args__, (int, Concatenate[int, ...]))
self.assertEqual(G3.__parameters__, ())

with self.assertRaisesRegex(
TypeError,
f"Too few {things} for {klass}"
):
klass[int]

# The following are some valid uses cases in PEP 612 that don't work:
# These do not work in 3.9, _type_check blocks the list and ellipsis.
# G3 = X[int, [int, bool]]
# G4 = X[int, ...]
# G5 = Z[[int, str, bool]]
# Not working because this is special-cased in 3.10.
# G6 = Z[int, str, bool]

def test_single_argument_generic(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
# Note: For 3.10+ __args__ are nested tuples here ((int, ),) instead of (int, )
G6 = klass[int, str, T]
G6args = G6.__args__[0] if sys.version_info >= (3, 10) else G6.__args__
self.assertEqual(G6args, (int, str, T))
self.assertEqual(G6.__parameters__, (T,))

# P = [int]
G7 = klass[int]
G7args = G7.__args__[0] if sys.version_info >= (3, 10) else G7.__args__
self.assertEqual(G7args, (int,))
self.assertEqual(G7.__parameters__, ())

G8 = klass[Concatenate[T, ...]]
self.assertEqual(G8.__args__, (Concatenate[T, ...], ))
self.assertEqual(G8.__parameters__, (T,))

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__args__, (Concatenate[T, P_2], ))

# This is an invalid form but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
G10args = G10.__args__[0] if sys.version_info >= (3, 10) else G10.__args__
self.assertEqual(G10args, (int, Concatenate[str, P], ))

def test_single_argument_generic_with_parameter_expressions(self):
P = ParamSpec("P")
T = TypeVar("T")
P_2 = ParamSpec("P_2")

class Z(Generic[P]):
pass

class ProtoZ(Protocol[P]):
pass

things = "arguments" if sys.version_info >= (3, 10) else "parameters"
for klass in Z, ProtoZ:
with self.subTest(klass=klass.__name__):
G8 = klass[Concatenate[T, ...]]

H8_1 = G8[int]
self.assertEqual(H8_1.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_1[str]

H8_2 = G8[T][int]
self.assertEqual(H8_2.__parameters__, ())
with self.assertRaisesRegex(TypeError, "not a generic class"):
H8_2[str]

G9 = klass[Concatenate[T, P_2]]
self.assertEqual(G9.__parameters__, (T, P_2))

with self.assertRaisesRegex(TypeError,
"The last parameter to Concatenate should be a ParamSpec variable or ellipsis."
if sys.version_info < (3, 10) else
# from __typing_subst__
"Expected a list of types, an ellipsis, ParamSpec, or Concatenate"
):
G9[int, int]

with self.assertRaisesRegex(TypeError, f"Too few {things}"):
G9[int]

with self.subTest("Check list as parameter expression", klass=klass.__name__):
if sys.version_info < (3, 10):
self.skipTest("Cannot pass non-types")
G5 = klass[[int, str, T]]
self.assertEqual(G5.__parameters__, (T,))
self.assertEqual(G5.__args__, ((int, str, T),))

H9 = G9[int, [T]]
self.assertEqual(H9.__parameters__, (T,))

# This is an invalid parameter expression but useful for testing correct subsitution
G10 = klass[int, Concatenate[str, P]]
with self.subTest("Check invalid form substitution"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a valid parameter expression, it is nice for debugging though, should keep or remove it?

self.assertEqual(G10.__parameters__, (P, ))
if sys.version_info < (3, 9):
self.skipTest("3.8 typing._type_subst does not support this substitution process")
H10 = G10[int]
if (3, 10) <= sys.version_info < (3, 11, 3):
self.skipTest("3.10-3.11.2 does not substitute Concatenate here")
self.assertEqual(H10.__parameters__, ())
H10args = H10.__args__[0] if sys.version_info >= (3, 10) else H10.__args__
self.assertEqual(H10args, (int, (str, int)))

def test_pickle(self):
global P, P_co, P_contra, P_default
P = ParamSpec('P')
Expand Down Expand Up @@ -5468,6 +5571,32 @@ def test_eq(self):
self.assertEqual(hash(C4), hash(C5))
self.assertNotEqual(C4, C6)

def test_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Ts = TypeVarTuple("Ts")

C1 = Concatenate[str, T, ...]
self.assertEqual(C1[int], Concatenate[str, int, ...])

C2 = Concatenate[str, P]
self.assertEqual(C2[...], Concatenate[str, ...])
self.assertEqual(C2[int], (str, int))
U1 = Unpack[Tuple[int, str]]
U2 = Unpack[Ts]
self.assertEqual(C2[U1], (str, int, str))
self.assertEqual(C2[U2], (str, Unpack[Ts]))
self.assertEqual(C2["U2"], (str, typing.ForwardRef("U2")))

if (3, 12, 0) <= sys.version_info < (3, 12, 4):
with self.assertRaises(AssertionError):
C2[Unpack[U2]]
else:
with self.assertRaisesRegex(TypeError, "must be used with a tuple type"):
C2[Unpack[U2]]

C3 = Concatenate[str, T, P]
self.assertEqual(C3[int, [bool]], (str, int, bool))

class TypeGuardTests(BaseTestCase):
def test_basics(self):
Expand Down
Loading
Loading