diff --git a/ChangeLog b/ChangeLog index 01f7d1bbe9..451c64575d 100644 --- a/ChangeLog +++ b/ChangeLog @@ -363,6 +363,15 @@ Release date: 2022-07-09 Refs PyCQA/pylint#7109 +* Support "is None" constraints from if statements during inference. + + Ref #791 + Ref PyCQA/pylint#157 + Ref PyCQA/pylint#1472 + Ref PyCQA/pylint#2016 + Ref PyCQA/pylint#2631 + Ref PyCQA/pylint#2880 + What's New in astroid 2.11.7? ============================= Release date: 2022-07-09 diff --git a/astroid/bases.py b/astroid/bases.py index 2c0478557c..bf99ddce7a 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -11,7 +11,7 @@ import collections.abc import sys from collections.abc import Sequence -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from astroid import decorators, nodes from astroid.const import PY310_PLUS @@ -35,6 +35,9 @@ else: from typing_extensions import Literal +if TYPE_CHECKING: + from astroid.constraint import Constraint + objectmodel = lazy_import("interpreter.objectmodel") helpers = lazy_import("helpers") manager = lazy_import("manager") @@ -146,11 +149,14 @@ def _infer_stmts( ) -> collections.abc.Generator[InferenceResult, None, None]: """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False + constraint_failed = False if context is not None: name = context.lookupname context = context.clone() + constraints = context.constraints.get(name, {}) else: name = None + constraints = {} context = InferenceContext() for stmt in stmts: @@ -161,16 +167,26 @@ def _infer_stmts( # 'context' is always InferenceContext and Instances get '_infer_name' from ClassDef context.lookupname = stmt._infer_name(frame, name) # type: ignore[union-attr] try: + stmt_constraints: set[Constraint] = set() + for constraint_stmt, potential_constraints in constraints.items(): + if not constraint_stmt.parent_of(stmt): + stmt_constraints.update(potential_constraints) # Mypy doesn't recognize that 'stmt' can't be Uninferable for inf in stmt.infer(context=context): # type: ignore[union-attr] - yield inf - inferred = True + if all(constraint.satisfied_by(inf) for constraint in stmt_constraints): + yield inf + inferred = True + else: + constraint_failed = True except NameInferenceError: continue except InferenceError: yield Uninferable inferred = True - if not inferred: + + if not inferred and constraint_failed: + yield Uninferable + elif not inferred: raise InferenceError( "Inference failed for all members of {stmts!r}.", stmts=stmts, diff --git a/astroid/constraint.py b/astroid/constraint.py new file mode 100644 index 0000000000..deed9ac52b --- /dev/null +++ b/astroid/constraint.py @@ -0,0 +1,137 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt + +"""Classes representing different types of constraints on inference values.""" +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from collections.abc import Iterator +from typing import Union + +from astroid import bases, nodes, util +from astroid.typing import InferenceResult + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] + + +class Constraint(ABC): + """Represents a single constraint on a variable.""" + + def __init__(self, node: nodes.NodeNG, negate: bool) -> None: + self.node = node + """The node that this constraint applies to.""" + self.negate = negate + """True if this constraint is negated. E.g., "is not" instead of "is".""" + + @classmethod + @abstractmethod + def match( + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + If negate is True, negate the constraint. + """ + + @abstractmethod + def satisfied_by(self, inferred: InferenceResult) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + + +class NoneConstraint(Constraint): + """Represents an "is None" or "is not None" constraint.""" + + CONST_NONE: nodes.Const = nodes.Const(None) + + @classmethod + def match( + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + Negate the constraint based on the value of negate. + """ + if isinstance(expr, nodes.Compare) and len(expr.ops) == 1: + left = expr.left + op, right = expr.ops[0] + if op in {"is", "is not"} and ( + _matches(left, node) and _matches(right, cls.CONST_NONE) + ): + negate = (op == "is" and negate) or (op == "is not" and not negate) + return cls(node=node, negate=negate) + + return None + + def satisfied_by(self, inferred: InferenceResult) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + # Assume true if uninferable + if inferred is util.Uninferable: + return True + + # Return the XOR of self.negate and matches(inferred, self.CONST_NONE) + return self.negate ^ _matches(inferred, self.CONST_NONE) + + +def get_constraints( + expr: _NameNodes, frame: nodes.LocalsDictNodeNG +) -> dict[nodes.If, set[Constraint]]: + """Returns the constraints for the given expression. + + The returned dictionary maps the node where the constraint was generated to the + corresponding constraint(s). + + Constraints are computed statically by analysing the code surrounding expr. + Currently this only supports constraints generated from if conditions. + """ + current_node: nodes.NodeNG | None = expr + constraints_mapping: dict[nodes.If, set[Constraint]] = {} + while current_node is not None and current_node is not frame: + parent = current_node.parent + if isinstance(parent, nodes.If): + branch, _ = parent.locate_child(current_node) + constraints: set[Constraint] | None = None + if branch == "body": + constraints = set(_match_constraint(expr, parent.test)) + elif branch == "orelse": + constraints = set(_match_constraint(expr, parent.test, invert=True)) + + if constraints: + constraints_mapping[parent] = constraints + current_node = parent + + return constraints_mapping + + +ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,)) +"""All supported constraint types.""" + + +def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool: + """Returns True if the two nodes match.""" + if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): + return node1.name == node2.name + if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute): + return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr) + if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const): + return node1.value == node2.value + + return False + + +def _match_constraint( + node: _NameNodes, expr: nodes.NodeNG, invert: bool = False +) -> Iterator[Constraint]: + """Yields all constraint patterns for node that match.""" + for constraint_cls in ALL_CONSTRAINT_CLASSES: + constraint = constraint_cls.match(node, expr, invert) + if constraint: + yield constraint diff --git a/astroid/context.py b/astroid/context.py index d7f74778bd..221fd84fbe 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple if TYPE_CHECKING: + from astroid import constraint, nodes from astroid.nodes.node_classes import Keyword, NodeNG _InferenceCache = Dict[ @@ -37,6 +38,7 @@ class InferenceContext: "callcontext", "boundnode", "extra_context", + "constraints", "_nodes_inferred", ) @@ -85,6 +87,9 @@ def __init__( for call arguments """ + self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {} + """The constraints on nodes.""" + @property def nodes_inferred(self) -> int: """ @@ -134,6 +139,7 @@ def clone(self) -> InferenceContext: clone.callcontext = self.callcontext clone.boundnode = self.boundnode clone.extra_context = self.extra_context + clone.constraints = self.constraints.copy() return clone @contextlib.contextmanager diff --git a/astroid/inference.py b/astroid/inference.py index cb79e823e9..b3a0c4a1d0 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union -from astroid import bases, decorators, helpers, nodes, protocols, util +from astroid import bases, constraint, decorators, helpers, nodes, protocols, util from astroid.context import ( CallContext, InferenceContext, @@ -242,6 +242,8 @@ def infer_name( ) context = copy_context(context) context.lookupname = self.name + context.constraints[self.name] = constraint.get_constraints(self, frame) + return bases._infer_stmts(stmts, context, frame) @@ -362,6 +364,11 @@ def infer_attribute( old_boundnode = context.boundnode try: context.boundnode = owner + if isinstance(owner, (nodes.ClassDef, bases.Instance)): + frame = owner if isinstance(owner, nodes.ClassDef) else owner._proxied + context.constraints[self.attrname] = constraint.get_constraints( + self, frame=frame + ) yield from owner.igetattr(self.attrname, context) except ( AttributeInferenceError, diff --git a/tests/unittest_constraint.py b/tests/unittest_constraint.py new file mode 100644 index 0000000000..cba506fdcb --- /dev/null +++ b/tests/unittest_constraint.py @@ -0,0 +1,590 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt + +"""Tests for inference involving constraints""" +from __future__ import annotations + +import pytest + +from astroid import builder, nodes +from astroid.util import Uninferable + + +def common_params(node: str) -> pytest.MarkDecorator: + return pytest.mark.parametrize( + ("condition", "satisfy_val", "fail_val"), + ( + (f"{node} is None", None, 3), + (f"{node} is not None", 3, None), + ), + ) + + +@common_params(node="x") +def test_if_single_statement( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for a variable that is used in the first statement of an if body.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: # Filters out default value + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: # Does not filter out default value + return ( + x #@ + ) + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_multiple_statements( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for a variable that is used in an if body with multiple statements.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: # Filters out default value + print(x) + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: # Does not filter out default value + print(x) + return ( + x #@ + ) + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_irrelevant_condition( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint for a different variable doesn't apply.""" + nodes_ = builder.extract_node( + f""" + def f1(x, y = {fail_val}): + if {condition}: # Does not filter out fail_val + return ( + y #@ + ) + + def f2(x, y = {satisfy_val}): + if {condition}: + return ( + y #@ + ) + """ + ) + for node, val in zip(nodes_, (fail_val, satisfy_val)): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_outside_if( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply outside of the if.""" + nodes_ = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: + pass + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: + pass + + return ( + x #@ + ) + """ + ) + for node, val in zip(nodes_, (fail_val, satisfy_val)): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_nested_if( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition applies within inner if statements.""" + node1, node2 = builder.extract_node( + f""" + def f1(y, x = {fail_val}): + if {condition}: + if y is not None: + return ( + x #@ + ) + + def f2(y, x = {satisfy_val}): + if {condition}: + if y is not None: + return ( + x #@ + ) + """ + ) + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +def test_if_uninferable() -> None: + """Test that when no inferred values satisfy all constraints, Uninferable is inferred.""" + node1, node2 = builder.extract_node( + """ + def f1(): + x = None + if x is not None: + x #@ + + def f2(): + x = 1 + if x is not None: + pass + else: + x #@ + """ + ) + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="x") +def test_if_reassignment_in_body( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is assigned to a failing value inside the if body. + """ + node = builder.extract_node( + f""" + def f(x, y): + if {condition}: + if y: + x = {fail_val} + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == fail_val + + +@common_params(node="x") +def test_if_elif_else_negates( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition is negated when the variable + is used in the elif and else branches. + """ + node1, node2, node3, node4 = builder.extract_node( + f""" + def f1(y, x = {fail_val}): + if {condition}: + pass + elif y: # Does not filter out default value + return ( + x #@ + ) + else: # Does not filter out default value + return ( + x #@ + ) + + def f2(y, x = {satisfy_val}): + if {condition}: + pass + elif y: # Filters out default value + return ( + x #@ + ) + else: # Filters out default value + return ( + x #@ + ) + """ + ) + for node in (node1, node2): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + for node in (node3, node4): + inferred = node.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="x") +def test_if_reassignment_in_else( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is assigned to a failing value inside the else branch. + """ + node = builder.extract_node( + f""" + def f(x, y): + if {condition}: + return x + else: + if y: + x = {satisfy_val} + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + +@common_params(node="x") +def test_if_comprehension_shadow( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is shadowed by an inner comprehension scope. + """ + node = builder.extract_node( + f""" + def f(x): + if {condition}: + return [ + x #@ + for x in [{satisfy_val}, {fail_val}] + ] + """ + ) + # Hack for Python 3.7 where the ListComp starts on L5 instead of L4 + # Extract_node doesn't handle this correctly + if isinstance(node, nodes.ListComp): + node = node.elt + inferred = node.inferred() + assert len(inferred) == 2 + + for actual, expected in zip(inferred, (satisfy_val, fail_val)): + assert isinstance(actual, nodes.Const) + assert actual.value == expected + + +@common_params(node="x") +def test_if_function_shadow( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is shadowed by an inner function scope. + """ + node = builder.extract_node( + f""" + x = {satisfy_val} + if {condition}: + def f(x = {fail_val}): + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_function_call( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply for a parameter + a different function call, but with the same name. + """ + node = builder.extract_node( + f""" + def f(x = {satisfy_val}): + if {condition}: + g({fail_val}) #@ + + def g(x): + return x + """ + ) + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + +@common_params(node="self.x") +def test_if_instance_attr( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test constraint for an instance attribute in an if statement.""" + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self): + if {condition}: + self.x #@ + + class A2: + def __init__(self, x = {satisfy_val}): + self.x = x + + def method(self): + if {condition}: + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_reassignment_in_body( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + when it is assigned inside the if body. + """ + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x): + self.x = x + + def method1(self): + if {condition}: + self.x = {satisfy_val} + self.x #@ + + def method2(self): + if {condition}: + self.x = {fail_val} + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + inferred = node2.inferred() + assert len(inferred) == 3 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + assert isinstance(inferred[2], nodes.Const) + assert inferred[2].value == fail_val + + +@common_params(node="x") +def test_if_instance_attr_varname_collision1( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + when the constraint refers to a variable with the same name. + """ + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self, x = {fail_val}): + if {condition}: + x #@ + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision2( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply to a variable with the same name.""" + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self, x = {fail_val}): + if {condition}: + x #@ + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision3( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + for an object of a different class. + """ + node = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self): + obj = A2() + if {condition}: + obj.x #@ + + class A2: + def __init__(self): + self.x = {fail_val} + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision4( + condition: str, satisfy_val: int | None, fail_val: int | None +) -> None: + """Test that constraint in an if condition doesn't apply to a variable of the same name, + when that variable is used to infer the value of the instance attribute. + """ + node = builder.extract_node( + f""" + class A1: + def __init__(self, x): + self.x = x + + def method(self): + x = {fail_val} + if {condition}: + self.x = x + self.x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == fail_val