Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix daemon crash on malformed NamedTuple (#14119) #1

Merged
merged 10 commits into from
Nov 21, 2022
Merged
10 changes: 5 additions & 5 deletions docs/source/generics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ Before parameter specifications, here's how one might have annotated the decorat

.. code-block:: python

from typing import Callable, TypeVar
from typing import Any, Callable, TypeVar, cast

F = TypeVar('F', bound=Callable[..., Any])

Expand All @@ -650,8 +650,8 @@ and that would enable the following type checks:

.. code-block:: python

reveal_type(a) # str
add_forty_two('x') # Type check error: incompatible type "str"; expected "int"
reveal_type(a) # Revealed type is "builtins.int"
add_forty_two('x') # Argument 1 to "add_forty_two" has incompatible type "str"; expected "int"


Note that the ``wrapper()`` function is not type-checked. Wrapper
Expand Down Expand Up @@ -724,7 +724,7 @@ achieved by combining with :py:func:`@overload <typing.overload>`:

.. code-block:: python

from typing import Any, Callable, TypeVar, overload
from typing import Any, Callable, Optional, TypeVar, overload

F = TypeVar('F', bound=Callable[..., Any])

Expand All @@ -736,7 +736,7 @@ achieved by combining with :py:func:`@overload <typing.overload>`:
def atomic(*, savepoint: bool = True) -> Callable[[F], F]: ...

# Implementation
def atomic(__func: Callable[..., Any] = None, *, savepoint: bool = True):
def atomic(__func: Optional[Callable[..., Any]] = None, *, savepoint: bool = True):
def decorator(func: Callable[..., Any]):
... # Code goes here
if __func is not None:
Expand Down
1 change: 1 addition & 0 deletions misc/sync-typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def main() -> None:
commits_to_cherry_pick = [
"780534b13722b7b0422178c049a1cbbf4ea4255b", # LiteralString reverts
"5319fa34a8004c1568bb6f032a07b8b14cc95bed", # sum reverts
"0062994228fb62975c6cef4d2c80d00c7aa1c545", # ctypes reverts
]
for commit in commits_to_cherry_pick:
subprocess.run(["git", "cherry-pick", commit], check=True)
Expand Down
1 change: 1 addition & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ class FreezeTypeVarsVisitor(TypeTraverserVisitor):
def visit_callable_type(self, t: CallableType) -> None:
for v in t.variables:
v.id.meta_level = 0
super().visit_callable_type(t)


def lookup_member_var_or_accessor(info: TypeInfo, name: str, is_lvalue: bool) -> SymbolNode | None:
Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,7 @@ class ClassDef(Statement):
"analyzed",
"has_incompatible_baseclass",
"deco_line",
"removed_statements",
)

__match_args__ = ("name", "defs")
Expand All @@ -1086,6 +1087,8 @@ class ClassDef(Statement):
keywords: dict[str, Expression]
analyzed: Expression | None
has_incompatible_baseclass: bool
# Used by special forms like NamedTuple and TypedDict to store invalid statements
removed_statements: list[Statement]

def __init__(
self,
Expand All @@ -1111,6 +1114,7 @@ def __init__(
self.has_incompatible_baseclass = False
# Used for error reporting (to keep backwad compatibility with pre-3.8)
self.deco_line: int | None = None
self.removed_statements = []

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_class_def(self)
Expand Down
8 changes: 7 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,13 @@ def visit_decorator(self, dec: Decorator) -> None:
dec.var.is_classmethod = True
self.check_decorated_function_is_method("classmethod", dec)
elif refers_to_fullname(
d, ("builtins.property", "abc.abstractproperty", "functools.cached_property")
d,
(
"builtins.property",
"abc.abstractproperty",
"functools.cached_property",
"enum.property",
),
):
removed.append(i)
dec.func.is_property = True
Expand Down
19 changes: 14 additions & 5 deletions mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
NameExpr,
PassStmt,
RefExpr,
Statement,
StrExpr,
SymbolTable,
SymbolTableNode,
Expand Down Expand Up @@ -111,7 +112,7 @@ def analyze_namedtuple_classdef(
if result is None:
# This is a valid named tuple, but some types are incomplete.
return True, None
items, types, default_items = result
items, types, default_items, statements = result
if is_func_scope and "@" not in defn.name:
defn.name += "@" + str(defn.line)
existing_info = None
Expand All @@ -123,31 +124,35 @@ def analyze_namedtuple_classdef(
defn.analyzed = NamedTupleExpr(info, is_typed=True)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
defn.defs.body = statements
# All done: this is a valid named tuple with all types known.
return True, info
# This can't be a valid named tuple.
return False, None

def check_namedtuple_classdef(
self, defn: ClassDef, is_stub_file: bool
) -> tuple[list[str], list[Type], dict[str, Expression]] | None:
) -> tuple[list[str], list[Type], dict[str, Expression], list[Statement]] | None:
"""Parse and validate fields in named tuple class definition.

Return a three tuple:
Return a four tuple:
* field names
* field types
* field default values
* valid statements
or None, if any of the types are not ready.
"""
if self.options.python_version < (3, 6) and not is_stub_file:
self.fail("NamedTuple class syntax is only supported in Python 3.6", defn)
return [], [], {}
return [], [], {}, []
if len(defn.base_type_exprs) > 1:
self.fail("NamedTuple should be a single base", defn)
items: list[str] = []
types: list[Type] = []
default_items: dict[str, Expression] = {}
statements: list[Statement] = []
for stmt in defn.defs.body:
statements.append(stmt)
if not isinstance(stmt, AssignmentStmt):
# Still allow pass or ... (for empty namedtuples).
if isinstance(stmt, PassStmt) or (
Expand All @@ -160,9 +165,13 @@ def check_namedtuple_classdef(
# And docstrings.
if isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
continue
statements.pop()
defn.removed_statements.append(stmt)
self.fail(NAMEDTUP_CLASS_ERROR, stmt)
elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr):
# An assignment, but an invalid one.
statements.pop()
defn.removed_statements.append(stmt)
self.fail(NAMEDTUP_CLASS_ERROR, stmt)
else:
# Append name and type in this case...
Expand Down Expand Up @@ -199,7 +208,7 @@ def check_namedtuple_classdef(
)
else:
default_items[name] = stmt.rvalue
return items, types, default_items
return items, types, default_items, statements

def check_namedtuple(
self, node: Expression, var_name: str | None, is_func_scope: bool
Expand Down
2 changes: 2 additions & 0 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,11 @@ def analyze_typeddict_classdef_fields(
):
statements.append(stmt)
else:
defn.removed_statements.append(stmt)
self.fail(TPDICT_CLASS_ERROR, stmt)
elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr):
# An assignment, but an invalid one.
defn.removed_statements.append(stmt)
self.fail(TPDICT_CLASS_ERROR, stmt)
else:
name = stmt.lvalues[0].name
Expand Down
2 changes: 2 additions & 0 deletions mypy/server/aststrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def visit_class_def(self, node: ClassDef) -> None:
]
with self.enter_class(node.info):
super().visit_class_def(node)
node.defs.body.extend(node.removed_statements)
node.removed_statements = []
TypeState.reset_subtype_caches_for(node.info)
# Kill the TypeInfo, since there is none before semantic analysis.
node.info = CLASSDEF_NO_INFO
Expand Down
7 changes: 7 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UninhabitedType,
UnionType,
get_proper_type,
has_recursive_types,
)


Expand Down Expand Up @@ -157,6 +158,12 @@ def test_type_alias_expand_all(self) -> None:
[self.fx.a, self.fx.a], Instance(self.fx.std_tuplei, [self.fx.a])
)

def test_recursive_nested_in_non_recursive(self) -> None:
A, _ = self.fx.def_alias_1(self.fx.a)
NA = self.fx.non_rec_alias(Instance(self.fx.gi, [UnboundType("T")]), ["T"], [A])
assert not NA.is_recursive
assert has_recursive_types(NA)

def test_indirection_no_infinite_recursion(self) -> None:
A, _ = self.fx.def_alias_1(self.fx.a)
visitor = TypeIndirectionVisitor()
Expand Down
10 changes: 7 additions & 3 deletions mypy/test/typefixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,13 @@ def def_alias_2(self, base: Instance) -> tuple[TypeAliasType, Type]:
A.alias = AN
return A, target

def non_rec_alias(self, target: Type) -> TypeAliasType:
AN = TypeAlias(target, "__main__.A", -1, -1)
return TypeAliasType(AN, [])
def non_rec_alias(
self, target: Type, alias_tvars: list[str] | None = None, args: list[Type] | None = None
) -> TypeAliasType:
AN = TypeAlias(target, "__main__.A", -1, -1, alias_tvars=alias_tvars)
if args is None:
args = []
return TypeAliasType(AN, args)


class InterfaceTypeFixture(TypeFixture):
Expand Down
24 changes: 8 additions & 16 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,24 +404,16 @@ def visit_placeholder_type(self, t: PlaceholderType) -> T:
return self.query_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> T:
# Skip type aliases already visited types to avoid infinite recursion.
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
# about duplicates.
if t in self.seen_aliases:
return self.strategy([])
self.seen_aliases.add(t)
if self.skip_alias_target:
return self.query_types(t.args)
return get_proper_type(t).accept(self)

def query_types(self, types: Iterable[Type]) -> T:
"""Perform a query for a list of types.

Use the strategy to combine the results.
Skip type aliases already visited types to avoid infinite recursion.
"""
res: list[T] = []
for t in types:
if isinstance(t, TypeAliasType):
# Avoid infinite recursion for recursive type aliases.
# TODO: Ideally we should fire subvisitors here (or use caching) if we care
# about duplicates.
if t in self.seen_aliases:
continue
self.seen_aliases.add(t)
res.append(t.accept(self))
return self.strategy(res)
"""Perform a query for a list of types using the strategy to combine the results."""
return self.strategy([t.accept(self) for t in types])
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
if fullname == "builtins.None":
return NoneType()
elif fullname == "typing.Any" or fullname == "builtins.Any":
return AnyType(TypeOfAny.explicit)
return AnyType(TypeOfAny.explicit, line=t.line, column=t.column)
elif fullname in FINAL_TYPE_NAMES:
self.fail(
"Final can be only used as an outermost qualifier in a variable annotation",
Expand Down
26 changes: 19 additions & 7 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,42 @@ def _expand_once(self) -> Type:
self.alias.target, self.alias.alias_tvars, self.args, self.line, self.column
)

def _partial_expansion(self) -> tuple[ProperType, bool]:
def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
# Private method mostly for debugging and testing.
unroller = UnrollAliasVisitor(set())
unrolled = self.accept(unroller)
if nothing_args:
alias = self.copy_modified(args=[UninhabitedType()] * len(self.args))
else:
alias = self
unrolled = alias.accept(unroller)
assert isinstance(unrolled, ProperType)
return unrolled, unroller.recursed

def expand_all_if_possible(self) -> ProperType | None:
def expand_all_if_possible(self, nothing_args: bool = False) -> ProperType | None:
"""Attempt a full expansion of the type alias (including nested aliases).

If the expansion is not possible, i.e. the alias is (mutually-)recursive,
return None.
return None. If nothing_args is True, replace all type arguments with an
UninhabitedType() (used to detect recursively defined aliases).
"""
unrolled, recursed = self._partial_expansion()
unrolled, recursed = self._partial_expansion(nothing_args=nothing_args)
if recursed:
return None
return unrolled

@property
def is_recursive(self) -> bool:
"""Whether this type alias is recursive.

Note this doesn't check generic alias arguments, but only if this alias
*definition* is recursive. The property value thus can be cached on the
underlying TypeAlias node. If you want to include all nested types, use
has_recursive_types() function.
"""
assert self.alias is not None, "Unfixed type alias"
is_recursive = self.alias._is_recursive
if is_recursive is None:
is_recursive = self.expand_all_if_possible() is None
is_recursive = self.expand_all_if_possible(nothing_args=True) is None
# We cache the value on the underlying TypeAlias node as an optimization,
# since the value is the same for all instances of the same alias.
self.alias._is_recursive = is_recursive
Expand Down Expand Up @@ -3259,7 +3271,7 @@ def __init__(self) -> None:
super().__init__(any)

def visit_type_alias_type(self, t: TypeAliasType) -> bool:
return t.is_recursive
return t.is_recursive or self.query_types(t.args)


def has_recursive_types(typ: Type) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion mypy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,11 @@ def parse_gray_color(cup: bytes) -> str:


def should_force_color() -> bool:
return bool(int(os.getenv("MYPY_FORCE_COLOR", os.getenv("FORCE_COLOR", "0"))))
env_var = os.getenv("MYPY_FORCE_COLOR", os.getenv("FORCE_COLOR", "0"))
try:
return bool(int(env_var))
except ValueError:
return bool(env_var)


class FancyFormatter:
Expand Down
2 changes: 0 additions & 2 deletions test-data/unit/check-class-namedtuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,6 @@ class X(typing.NamedTuple):
[out]
main:6: error: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]"
main:7: error: Invalid statement in NamedTuple definition; expected "field_name: field_type [= default]"
main:7: error: Type cannot be declared in assignment to non-self attribute
main:7: error: "int" has no attribute "x"
main:9: error: Non-default NamedTuple fields cannot follow default fields

[builtins fixtures/list.pyi]
Expand Down
14 changes: 14 additions & 0 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -6334,3 +6334,17 @@ reveal_type(D().meth)
[out2]
tmp/m.py:4: note: Revealed type is "def [Self <: lib.C] (self: Self`0, other: Self`0) -> Self`0"
tmp/m.py:5: note: Revealed type is "def (other: m.D) -> m.D"

[case testIncrementalNestedGenericCallableCrash]
from typing import TypeVar, Callable

T = TypeVar("T")

class B:
def foo(self) -> Callable[[T], T]: ...

class C(B):
def __init__(self) -> None:
self.x = self.foo()
[out]
[out2]
Loading