Skip to content

Commit

Permalink
Fix narrowing unions in tuple pattern matches
Browse files Browse the repository at this point in the history
Allows sequence pattern matching of tuple types with union members to
use an unmatched pattern to narrow the possible type of union member to
the type that was not matched in the sequence pattern.

For example given a type of `tuple[int, int | None]` a pattern match
like:

```
match tuple_type:
    case a, None:
         return
    case t:
         reveal_type(t) # narrows tuple type to tuple[int, int]
         return
```

The case ..., None sequence pattern match can now narrow the type in
further match statements to rule out the None side of the union.

This is implemented by moving the original implementation of tuple
narrowing in sequence pattern matching since its functionality should be
to the special case where a tuple only has length of one. This
implementation does not hold for tuples of length greater than one since
it does not account all combinations of alternative types.

This replace that implementation with a new one that  builds the rest
type by iterating over the potential rest type members preserving
narrowed types if they are available and replacing any uninhabited types
with the original type of tuple member since these matches are only
exhaustive if all members of the tuple are matched.

Fixes python#14731
  • Loading branch information
edpaget committed Apr 9, 2024
1 parent 8019010 commit 8454267
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 15 deletions.
48 changes: 33 additions & 15 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,24 +302,42 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
new_type: Type
rest_type: Type = current_type
if isinstance(current_type, TupleType) and unpack_index is None:
narrowed_inner_types = []
inner_rest_types = []
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
(narrowed_inner_type, inner_rest_type) = (
self.chk.conditional_types_with_intersection(
new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
)
)
narrowed_inner_types.append(narrowed_inner_type)
inner_rest_types.append(inner_rest_type)
if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
if all(not is_uninhabited(typ) for typ in new_inner_types):
new_type = TupleType(new_inner_types, current_type.partial_fallback)
else:
new_type = UninhabitedType()

if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
if all(is_uninhabited(typ) for typ in rest_inner_types):
# If all types are uninhabited there is no other pattern that can
# match this tuple
rest_type = UninhabitedType()
elif any(is_uninhabited(typ) for typ in rest_inner_types):
# If at least one rest type is uninhabited the rest type can be narrowed
narrowed_types: list[Type] = []
for inner_type, rest_type in zip(inner_types, rest_inner_types):
# if the narrowed rest type is Uninhabited that means that that the next
# pattern could match any of the original inner types of the tuple.
if is_uninhabited(rest_type):
narrowed_types.append(inner_type)
else:
narrowed_types.append(rest_type)
rest_type = TupleType(narrowed_types, current_type.partial_fallback)
elif len(rest_inner_types) == 1:
# Otherwise we need can apply negative narrowing if the alternative type
# is totally disjoint from the pattern narrowed type. And there is only
# one field in the tuple.
narrowed_inner_types = []
inner_rest_types = []
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
(narrowed_inner_type, inner_rest_type) = (
self.chk.conditional_types_with_intersection(
new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
)
)
narrowed_inner_types.append(narrowed_inner_type)
inner_rest_types.append(inner_rest_type)
if all(is_uninhabited(typ) for typ in inner_rest_types):
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
elif isinstance(current_type, TupleType):
# For variadic tuples it is too tricky to match individual items like for fixed
# tuples, so we instead try to narrow the entire type.
Expand Down
60 changes: 60 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,66 @@ match m, (n, o):
pass
[builtins fixtures/tuple.pyi]

[case testPartialMatchTuplePatternNarrowsUnion]
from typing import Tuple, Union

m: Tuple[Union[int, None], Union[int, None]]

match m:
case a, None:
reveal_type(a) # N: Revealed type is "Union[builtins.int, None]"
reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], None]"
case None, b:
reveal_type(b) # N: Revealed type is "builtins.int"
reveal_type(m) # N: Revealed type is "Tuple[None, builtins.int]"
case t:
reveal_type(t) # N: Revealed type is "Tuple[builtins.int, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testPartialMatchTuplePatternNarrowsUnionWithFullMatch]
from typing import Tuple, Union

m: Tuple[Union[int, None], Union[int, None]]

match m:
case None, None:
reveal_type(m) # N: Revealed type is "Tuple[None, None]"
case a, None:
reveal_type(a) # N: Revealed type is "Union[builtins.int, None]"
reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], None]"
case None, b:
reveal_type(b) # N: Revealed type is "builtins.int"
reveal_type(m) # N: Revealed type is "Tuple[None, builtins.int]"
case t:
reveal_type(t) # N: Revealed type is "Tuple[builtins.int, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testPartialMatchTuplePatternNarrowsUnionWithUninhibitedMatch]
from typing import Tuple, Union

m: Tuple[Union[int, None], Union[int, None]]

match m:
case a, b:
reveal_type(m) # N: Revealed type is "Tuple[Union[builtins.int, None], Union[builtins.int, None]]"
case t:
reveal_type(t)
[builtins fixtures/tuple.pyi]

[case testMatchTuplePatterNarrowsWithOrMatch]
from typing import Tuple, Union

m: Tuple[Union[int, None], Union[int, None]]

match m:
case (None, _) | (_, None):
reveal_type(m) # N: Revealed type is "Union[Tuple[None, Union[builtins.int, None]], Tuple[builtins.int, None]]"
case a, b:
reveal_type(a) # N: Revealed type is "builtins.int"
reveal_type(b) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]


-- Mapping Pattern --

[case testMatchMappingPatternCaptures]
Expand Down

0 comments on commit 8454267

Please sign in to comment.