From 662b1bd354e1b1813ea4566c673cb06c074633f8 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 9 Oct 2024 13:04:56 +0100 Subject: [PATCH 1/2] Fix union callees with functools.partial Fixes #17741. --- mypy/plugins/functools.py | 16 +++++++++++++++- test-data/unit/check-functools.test | 21 +++++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index 6650af637519..b415805c1ec5 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -18,6 +18,7 @@ Type, TypeOfAny, UnboundType, + UnionType, get_proper_type, ) @@ -130,7 +131,20 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): # TODO: handle overloads, just fall back to whatever the non-plugin code does return ctx.default_return_type - fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type) + callee = ctx.arg_types[0][0] + if isinstance(callee_proper := get_proper_type(callee), UnionType): + return UnionType.make_union( + [handle_partial_with_callee(ctx, item) for item in callee_proper.items] + ) + else: + return handle_partial_with_callee(ctx, callee) + + +def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type: + if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals + return ctx.default_return_type + + fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type) if fn_type is None: return ctx.default_return_type diff --git a/test-data/unit/check-functools.test b/test-data/unit/check-functools.test index 50de3789ebd2..bee30931a92b 100644 --- a/test-data/unit/check-functools.test +++ b/test-data/unit/check-functools.test @@ -346,15 +346,32 @@ fn1: Union[Callable[[int], int], Callable[[int], int]] reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int" fn2: Union[Callable[[int], int], Callable[[int], str]] -reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object" +reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "Union[builtins.int, builtins.str]" fn3: Union[Callable[[int], int], str] reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \ - # E: "Union[Callable[[int], int], str]" not callable \ # N: Revealed type is "builtins.int" \ # E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]" [builtins fixtures/tuple.pyi] +[case testFunctoolsPartialUnionOfTypeAndCallable] +import functools +from typing import Callable, Union, Type +from typing_extensions import TypeAlias + +class FooBar: + def __init__(self, arg1: str) -> None: + pass + +def f1(t: Union[Type[FooBar], Callable[..., 'FooBar']]) -> None: + val = functools.partial(t) + +FooBarFunc: TypeAlias = Callable[..., 'FooBar'] + +def f2(t: Union[Type[FooBar], FooBarFunc]) -> None: + val = functools.partial(t) +[builtins fixtures/tuple.pyi] + [case testFunctoolsPartialExplicitType] from functools import partial from typing import Type, TypeVar, Callable From d4afd6df2bb30685e0f395b98c09af8f27eb1159 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 9 Oct 2024 13:24:37 +0100 Subject: [PATCH 2/2] Fix nested unions --- mypy/plugins/functools.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index b415805c1ec5..f09ea88f7162 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -131,19 +131,18 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded): # TODO: handle overloads, just fall back to whatever the non-plugin code does return ctx.default_return_type - callee = ctx.arg_types[0][0] - if isinstance(callee_proper := get_proper_type(callee), UnionType): - return UnionType.make_union( - [handle_partial_with_callee(ctx, item) for item in callee_proper.items] - ) - else: - return handle_partial_with_callee(ctx, callee) + return handle_partial_with_callee(ctx, callee=ctx.arg_types[0][0]) def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type: if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals return ctx.default_return_type + if isinstance(callee_proper := get_proper_type(callee), UnionType): + return UnionType.make_union( + [handle_partial_with_callee(ctx, item) for item in callee_proper.items] + ) + fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type) if fn_type is None: return ctx.default_return_type