Skip to content

Commit

Permalink
Fix union callees with functools.partial (#17903)
Browse files Browse the repository at this point in the history
Fixes #17741.
  • Loading branch information
JukkaL authored Oct 9, 2024
1 parent 706a546 commit eca206d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
15 changes: 14 additions & 1 deletion mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Type,
TypeOfAny,
UnboundType,
UnionType,
get_proper_type,
)

Expand Down Expand Up @@ -130,7 +131,19 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded):
# TODO: handle overloads, just fall back to whatever the non-plugin code does
return ctx.default_return_type
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type)
return handle_partial_with_callee(ctx, callee=ctx.arg_types[0][0])


def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type:
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type

if isinstance(callee_proper := get_proper_type(callee), UnionType):
return UnionType.make_union(
[handle_partial_with_callee(ctx, item) for item in callee_proper.items]
)

fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type)
if fn_type is None:
return ctx.default_return_type

Expand Down
21 changes: 19 additions & 2 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -346,15 +346,32 @@ fn1: Union[Callable[[int], int], Callable[[int], int]]
reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int"

fn2: Union[Callable[[int], int], Callable[[int], str]]
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object"
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "Union[builtins.int, builtins.str]"

fn3: Union[Callable[[int], int], str]
reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
# E: "Union[Callable[[int], int], str]" not callable \
# N: Revealed type is "builtins.int" \
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialUnionOfTypeAndCallable]
import functools
from typing import Callable, Union, Type
from typing_extensions import TypeAlias

class FooBar:
def __init__(self, arg1: str) -> None:
pass

def f1(t: Union[Type[FooBar], Callable[..., 'FooBar']]) -> None:
val = functools.partial(t)

FooBarFunc: TypeAlias = Callable[..., 'FooBar']

def f2(t: Union[Type[FooBar], FooBarFunc]) -> None:
val = functools.partial(t)
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialExplicitType]
from functools import partial
from typing import Type, TypeVar, Callable
Expand Down

0 comments on commit eca206d

Please sign in to comment.