Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of generic functions in partial plugin #17925

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
Type,
TypeOfAny,
TypeVarType,
UnboundType,
UnionType,
get_proper_type,
Expand Down Expand Up @@ -164,21 +166,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
ctx.api.type_context[-1] = None
wrapped_return = False

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
],
ret_type=ret_type,
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

# Flatten actual to formal mapping, since this is what check_call() expects.
actual_args = []
actual_arg_kinds = []
Expand All @@ -199,6 +186,43 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
actual_arg_names.append(ctx.arg_names[i][j])
actual_types.append(ctx.arg_types[i][j])

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

# We need to remove any type variables that appear only in formals that have
# no actuals, to avoid eagerly binding them in check_call() below.
can_infer_ids = set()
for i, arg_type in enumerate(fn_type.arg_types):
if not formal_to_actual[i]:
continue
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
],
ret_type=ret_type,
variables=[
tv
for tv in fn_type.variables
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
],
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

# Create a valid context for various ad-hoc inspections in check_call().
call_expr = CallExpr(
callee=ctx.args[0][0],
Expand Down Expand Up @@ -231,14 +255,6 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
return ctx.default_return_type
bound = bound.copy_modified(ret_type=ret_type.args[0])

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
Expand Down
31 changes: 29 additions & 2 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ def bar(f: S) -> S:
return f
[builtins fixtures/primitives.pyi]


[case testFunctoolsPartialAbstractType]
# flags: --python-version 3.9
from abc import ABC, abstractmethod
Expand All @@ -597,7 +596,6 @@ def f2() -> None:
partial_cls() # E: Cannot instantiate abstract class "A" with abstract attribute "method"
[builtins fixtures/tuple.pyi]


[case testFunctoolsPartialSelfType]
from functools import partial
from typing_extensions import Self
Expand All @@ -610,3 +608,32 @@ class A:
factory = partial(cls, ts=0)
return factory(msg=msg)
[builtins fixtures/tuple.pyi]

[case testFunctoolsPartialTypeVarValues]
from functools import partial
from typing import TypeVar

T = TypeVar("T", int, str)

def f(x: int, y: T) -> T:
return y

def g(x: T, y: int) -> T:
return x

def h(x: T, y: T) -> T:
return x

fp = partial(f, 1)
reveal_type(fp(1)) # N: Revealed type is "builtins.int"
reveal_type(fp("a")) # N: Revealed type is "builtins.str"
fp(object()) # E: Value of type variable "T" of "f" cannot be "object"

gp = partial(g, 1)
reveal_type(gp(1)) # N: Revealed type is "builtins.int"
gp("a") # E: Argument 1 to "g" has incompatible type "str"; expected "int"

hp = partial(h, 1)
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
[builtins fixtures/tuple.pyi]
Loading