Skip to content

Commit

Permalink
Add support for functools.partial
Browse files Browse the repository at this point in the history
Fixes #1484

This is currently the most popular mypy issue that does not need a PEP.
I'm sure there's stuff missing, but this should handle most cases.
  • Loading branch information
hauntsaninja committed Feb 23, 2024
1 parent 790e8a7 commit 5b56460
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 24 deletions.
34 changes: 17 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,14 @@ def apply_function_plugin(
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)
else:
Expand All @@ -1233,15 +1233,15 @@ def apply_function_plugin(
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(
object_type,
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
type=object_type,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)

Expand Down
8 changes: 8 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
elif fullname == "functools.partial":
import mypy.plugins.functools

return mypy.plugins.functools.partial_new_callback

return None

Expand Down Expand Up @@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
import mypy.plugins.functools

return mypy.plugins.functools.partial_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down
132 changes: 130 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@

from typing import Final, NamedTuple

import mypy.checker
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
from mypy.types import (
AnyType,
CallableType,
Instance,
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

functools_total_ordering_makers: Final = {"functools.total_ordering"}

Expand Down Expand Up @@ -102,3 +113,120 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
comparison_methods[name] = None

return comparison_methods


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.partial"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
return ctx.default_return_type
if len(ctx.arg_types[0]) != 1:
return ctx.default_return_type

fn_type = get_proper_type(ctx.arg_types[0][0])
if not isinstance(fn_type, CallableType):
return ctx.default_return_type

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
return ctx.default_return_type

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
# We need to fully apply any positional arguments (they cannot be respecified)
# However, keyword arguments can be respecified, so just give them a default
for i, actuals in enumerate(formal_to_actual):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
arg_type = fn_type.arg_types[i] # bit of a hack

if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
elif actuals:
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
continue
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
ret_type = fn_type.ret_type # same kind of hack as above

partially_applied = fn_type.copy_modified(
arg_types=partial_types,
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
return ret


def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Infer a more precise return type for functools.partial.__call__."""
if (
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
or not isinstance(ctx.type, Instance)
or ctx.type.type.fullname != "functools.partial"
or not ctx.type.extra_attrs
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]

result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=ctx.context,
)
return result[0]
9 changes: 5 additions & 4 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,13 +1495,14 @@ def copy_modified(
last_known_value: Bogus[LiteralType | None] = _dummy,
) -> Instance:
new = Instance(
self.type,
args if args is not _dummy else self.args,
self.line,
self.column,
typ=self.type,
args=args if args is not _dummy else self.args,
line=self.line,
column=self.column,
last_known_value=(
last_known_value if last_known_value is not _dummy else self.last_known_value
),
extra_attrs=self.extra_attrs,
)
# We intentionally don't copy the extra_attrs here, so they will be erased.
new.can_be_true = self.can_be_true
Expand Down
86 changes: 86 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,89 @@ def f(d: D[C]) -> None:

d: D[int] # E: Type argument "int" of "D" must be a subtype of "C"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialBasic]
from typing import Callable
import functools

def foo(a: int, b: str, c: int = 5) -> int: ... # N: "foo" defined here

p1 = functools.partial(foo)
p1(1, "a", 3) # OK
p1(1, "a", c=3) # OK
p1(1, b="a", c=3) # OK

def takes_callable_int(f: Callable[..., int]) -> None: ...
def takes_callable_str(f: Callable[..., str]) -> None: ...
takes_callable_int(p1)
takes_callable_str(p1) # E: Argument 1 to "takes_callable_str" has incompatible type "partial[int]"; expected "Callable[..., str]" \
# N: "partial[int].__call__" has type "Callable[[VarArg(Any), KwArg(Any)], int]"

p2 = functools.partial(foo, 1)
p2("a") # OK
p2("a", 3) # OK
p2("a", c=3) # OK
p2(1, 3) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
p2(1, "a", 3) # E: Too many arguments for "foo" \
# E: Argument 1 to "foo" has incompatible type "int"; expected "str" \
# E: Argument 2 to "foo" has incompatible type "str"; expected "int"
p2(a=1, b="a", c=3) # E: Unexpected keyword argument "a" for "foo"

p3 = functools.partial(foo, b="a")
p3(1) # OK
p3(1, c=3) # OK
p3(a=1) # OK
p3(1, b="a", c=3) # OK, keywords can be clobbered
p3(1, 3) # E: Too many positional arguments for "foo" \
# E: Argument 2 to "foo" has incompatible type "int"; expected "str"

functools.partial(foo, "a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"
functools.partial(foo, b=1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
functools.partial(1) # E: Argument 1 to "partial" has incompatible type "int"; expected "Callable[..., Never]"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialStar]
import functools

def foo(a: int, b: str, *args: int, d: str, **kwargs: int) -> int: ...

p1 = functools.partial(foo, 1, d="a", x=9)
p1("a", 2, 3, 4) # OK
p1("a", 2, 3, 4, d="a") # OK
p1("a", 2, 3, 4, "a") # E: Argument 5 to "foo" has incompatible type "str"; expected "int"
p1("a", 2, 3, 4, x="a") # E: Argument "x" to "foo" has incompatible type "str"; expected "int"

p2 = functools.partial(foo, 1, "a")
p2(2, 3, 4, d="a") # OK
p2("a") # E: Missing named argument "d" for "foo" \
# E: Argument 1 to "foo" has incompatible type "str"; expected "int"
p2(2, 3, 4) # E: Missing named argument "d" for "foo"

functools.partial(foo, 1, "a", "b", "c", d="a") # E: Argument 3 to "foo" has incompatible type "str"; expected "int" \
# E: Argument 4 to "foo" has incompatible type "str"; expected "int"

[builtins fixtures/dict.pyi]

[case testFunctoolsPartialGeneric]
from typing import TypeVar
import functools

T = TypeVar("T")
U = TypeVar("U")

def foo(a: T, b: T) -> T: ...

p1 = functools.partial(foo, 1)
reveal_type(p1(2)) # N: Revealed type is "builtins.int"
p1("a") # E: Argument 1 to "foo" has incompatible type "str"; expected "int"

p2 = functools.partial(foo, "a")
p2(1) # E: Argument 1 to "foo" has incompatible type "int"; expected "str"
reveal_type(p2("a")) # N: Revealed type is "builtins.str"

def bar(a: T, b: U) -> U: ...

p3 = functools.partial(bar, 1)
reveal_type(p3(2)) # N: Revealed type is "builtins.int"
reveal_type(p3("a")) # N: Revealed type is "builtins.str"
[builtins fixtures/dict.pyi]
6 changes: 5 additions & 1 deletion test-data/unit/lib-stub/functools.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar, Callable, Any, Mapping, overload
from typing import Generic, TypeVar, Callable, Any, Mapping, Self, overload

_T = TypeVar("_T")

Expand Down Expand Up @@ -33,3 +33,7 @@ class cached_property(Generic[_T]):
def __get__(self, instance: object, owner: type[Any] | None = ...) -> _T: ...
def __set_name__(self, owner: type[Any], name: str) -> None: ...
def __class_getitem__(cls, item: Any) -> Any: ...

class partial(Generic[_T]):
def __new__(cls, __func: Callable[..., _T], *args: Any, **kwargs: Any) -> Self: ...
def __call__(__self, *args: Any, **kwargs: Any) -> _T: ...

0 comments on commit 5b56460

Please sign in to comment.