From 159af87defeb6224fd4da0686b050e0c5de60bb3 Mon Sep 17 00:00:00 2001 From: David Liu Date: Sat, 28 Aug 2021 16:39:47 -0400 Subject: [PATCH 01/21] Support "is None" constraints from if statements during inference --- ChangeLog | 8 + astroid/bases.py | 20 +- astroid/constraint.py | 137 ++++++ astroid/context.py | 9 + astroid/inference.py | 11 +- tests/unittest_inference_constraints.py | 569 ++++++++++++++++++++++++ 6 files changed, 750 insertions(+), 4 deletions(-) create mode 100644 astroid/constraint.py create mode 100644 tests/unittest_inference_constraints.py diff --git a/ChangeLog b/ChangeLog index c4d155b28b..8955507940 100644 --- a/ChangeLog +++ b/ChangeLog @@ -6,6 +6,14 @@ What's New in astroid 2.9.0? ============================ Release date: TBA +* Support "is None" constraints from if statements during inference. + + Closes #791 + Closes PyCQA/pylint#157 + Closes PyCQA/pylint#1472 + Closes PyCQA/pylint#2016 + Closes PyCQA/pylint#2631 + Closes PyCQA/pylint#2880 What's New in astroid 2.8.1? diff --git a/astroid/bases.py b/astroid/bases.py index da4831b486..f5e72b7003 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -137,12 +137,15 @@ def infer(self, context=None): def _infer_stmts(stmts, context, frame=None): """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False + constraint_failed = False if context is not None: name = context.lookupname context = context.clone() + constraints = context.constraints.get(name, {}) else: name = None context = InferenceContext() + constraints = {} for stmt in stmts: if stmt is Uninferable: @@ -151,15 +154,26 @@ def _infer_stmts(stmts, context, frame=None): continue context.lookupname = stmt._infer_name(frame, name) try: + stmt_constraints = { + constraint + for constraint_stmt, constraint in constraints.items() + if not constraint_stmt.parent_of(stmt) + } for inf in stmt.infer(context=context): - yield inf - inferred = True + if all(constraint.satisfied_by(inf) for constraint in stmt_constraints): + yield inf + inferred = True + else: + constraint_failed = True except NameInferenceError: continue except InferenceError: yield Uninferable inferred = True - if not inferred: + + if not inferred and constraint_failed: + yield Uninferable + elif not inferred: raise InferenceError( "Inference failed for all members of {stmts!r}.", stmts=stmts, diff --git a/astroid/constraint.py b/astroid/constraint.py new file mode 100644 index 0000000000..dc4216c557 --- /dev/null +++ b/astroid/constraint.py @@ -0,0 +1,137 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +"""Classes representing different types of constraints on inference values.""" + +from typing import Any, Dict, Optional + +from astroid import nodes, util + + +class Constraint: + """Represents a single constraint on a variable.""" + + node: nodes.NodeNG + negate: bool + + def __init__(self, node: nodes.NodeNG, negate: bool) -> None: + self.node = node + """The node that this constraint applies to.""" + self.negate = negate + """True if this constraint is negated. E.g., "is not" instead of "is".""" + + def invert(self) -> None: + """Invert this constraint.""" + self.negate = not self.negate + + @classmethod + def match( + cls, node: nodes.NodeNG, expr: nodes.NodeNG, negate: bool = False + ) -> Optional["Constraint"]: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + If negate is True, negate the constraint. + """ + + def satisfied_by(self, inferred: Any) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + return True + + +class NoneConstraint(Constraint): + """Represents an "is None" or "is not None" constraint.""" + + @classmethod + def match( + cls, node: nodes.NodeNG, expr: nodes.NodeNG, negate: bool = False + ) -> Optional[Constraint]: + """Return a new constraint for node matched from expr, if expr matches + the constraint pattern. + + Negate the constraint based on the value of negate. + """ + if isinstance(expr, nodes.Compare) and len(expr.ops) == 1: + left = expr.left + op, right = expr.ops[0] + const_none = nodes.Const(None) + if op in {"is", "is not"} and ( + matches(left, node) + and matches(right, const_none) + or matches(left, const_none) + and matches(right, node) + ): + negate = (op == "is" and negate) or (op == "is not" and not negate) + return cls(node=node, negate=negate) + + return None + + def satisfied_by(self, inferred: Any) -> bool: + """Return True if this constraint is satisfied by the given inferred value.""" + if inferred is util.Uninferable: + return True + + if self.negate and matches(inferred, nodes.Const(None)): + return False + if not self.negate and not matches(inferred, nodes.Const(None)): + return False + + return True + + +def matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: + """Returns True if the two nodes match.""" + if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): + return node1.name == node2.name + if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute): + return node1.attrname == node2.attrname and matches(node1.expr, node2.expr) + if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const): + return node1.value == node2.value + + return False + + +def get_constraints( + expr: nodes.NodeNG, frame: nodes.NodeNG +) -> Dict[nodes.NodeNG, Constraint]: + """Returns the constraints for the given expression. + + The returned dictionary maps the node where the constraint was generated to the + corresponding constraint. + + Constraints are computed statically by analysing the code surrounding expr. + Currently this only supports constraints generated from if conditions. + """ + current_node = expr + constraints = {} + while current_node is not None and current_node is not frame: + parent = current_node.parent + if isinstance(parent, nodes.If): + branch, _ = parent.locate_child(current_node) + if branch == "body": + constraint = match_constraint(expr, parent.test) + elif branch == "orelse": + constraint = match_constraint(expr, parent.test, invert=True) + else: + constraint = None + + if constraint: + constraints[parent] = constraint + current_node = parent + + return constraints + + +ALL_CONSTRAINTS = (NoneConstraint,) +"""All supported constraint types.""" + + +def match_constraint( + node: nodes.NodeNG, expr: nodes.NodeNG, invert: bool = False +) -> Optional[Constraint]: + """Returns a constraint pattern for node, if one matches.""" + for constraint_cls in ALL_CONSTRAINTS: + constraint = constraint_cls.match(node, expr, invert) + if constraint: + return constraint + + return None diff --git a/astroid/context.py b/astroid/context.py index 9424813869..29b56490ff 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -41,6 +41,7 @@ class InferenceContext: "callcontext", "boundnode", "extra_context", + "constraints", "_nodes_inferred", ) @@ -91,6 +92,13 @@ def __init__(self, path=None, nodes_inferred=None): for call arguments """ + self.constraints = {} + """ + :type: dict + + The constraints on nodes + """ + @property def nodes_inferred(self): """ @@ -145,6 +153,7 @@ def clone(self): clone.callcontext = self.callcontext clone.boundnode = self.boundnode clone.extra_context = self.extra_context + clone.constraints = self.constraints.copy() return clone @contextlib.contextmanager diff --git a/astroid/inference.py b/astroid/inference.py index df8eff6fc5..18658ffe84 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -35,7 +35,9 @@ import wrapt -from astroid import bases, decorators, helpers, nodes, protocols, util +from astroid import bases +from astroid import constraint as constraintmod +from astroid import decorators, helpers, nodes, protocols, util from astroid.context import ( CallContext, InferenceContext, @@ -215,6 +217,8 @@ def infer_name(self, context=None): ) context = copy_context(context) context.lookupname = self.name + context.constraints[self.name] = constraintmod.get_constraints(self, frame) + return bases._infer_stmts(stmts, context, frame) @@ -317,6 +321,11 @@ def infer_attribute(self, context=None): old_boundnode = context.boundnode try: context.boundnode = owner + if isinstance(owner, (nodes.ClassDef, bases.Instance)): + frame = owner if isinstance(owner, nodes.ClassDef) else owner._proxied + context.constraints[self.attrname] = constraintmod.get_constraints( + self, frame=frame + ) yield from owner.igetattr(self.attrname, context) except ( AttributeInferenceError, diff --git a/tests/unittest_inference_constraints.py b/tests/unittest_inference_constraints.py new file mode 100644 index 0000000000..9c78defb8c --- /dev/null +++ b/tests/unittest_inference_constraints.py @@ -0,0 +1,569 @@ +"""Tests for inference involving constraints""" + +from typing import Any + +import pytest + +from astroid import builder, nodes +from astroid.util import Uninferable + + +def common_params(node): + return pytest.mark.parametrize( + ("condition", "satisfy_val", "fail_val"), + ( + (f"{node} is None", None, 3), + (f"{node} is not None", 3, None), + ), + ) + + +@common_params(node="x") +def test_if_single_statement(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test constraint for a variable that is used in the first statement of an if body.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: # Filters out default value + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: # Does not filter out default value + return ( + x #@ + ) + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_multiple_statements( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test constraint for a variable that is used in an if body with multiple statements.""" + node1, node2 = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: # Filters out default value + print(x) + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: # Does not filter out default value + print(x) + return ( + x #@ + ) + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_irrelevant_condition( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint for a different variable doesn't apply.""" + nodes_ = builder.extract_node( + f""" + def f1(x, y = {fail_val}): + if {condition}: # Does not filter out fail_val + return ( + y #@ + ) + + def f2(x, y = {satisfy_val}): + if {condition}: + return ( + y #@ + ) + """ + ) + for node, val in zip(nodes_, (fail_val, satisfy_val)): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_outside_if(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test that constraint in an if condition doesn't apply outside of the if.""" + nodes_ = builder.extract_node( + f""" + def f1(x = {fail_val}): + if {condition}: + pass + return ( + x #@ + ) + + def f2(x = {satisfy_val}): + if {condition}: + pass + + return ( + x #@ + ) + """ + ) + for node, val in zip(nodes_, (fail_val, satisfy_val)): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_nested_if(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test that constraint in an if condition applies within inner if statements.""" + node1, node2 = builder.extract_node( + f""" + def f1(y, x = {fail_val}): + if {condition}: + if y is not None: + return ( + x #@ + ) + + def f2(y, x = {satisfy_val}): + if {condition}: + if y is not None: + return ( + x #@ + ) + """ + ) + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +def test_if_uninferable() -> None: + """Test that when no inferred values satisfy all constraints, Uninferable is inferred.""" + node1, node2 = builder.extract_node( + """ + def f1(): + x = None + if x is not None: + x #@ + + def f2(): + x = 1 + if x is not None: + pass + else: + x #@ + """ + ) + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="x") +def test_if_reassignment_in_body( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is assigned to a failing value inside the if body. + """ + node = builder.extract_node( + f""" + def f(x, y): + if {condition}: + if y: + x = {fail_val} + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == fail_val + + +@common_params(node="x") +def test_if_elif_else_negates(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test that constraint in an if condition is negated when the variable + is used in the elif and else branches. + """ + node1, node2, node3, node4 = builder.extract_node( + f""" + def f1(y, x = {fail_val}): + if {condition}: + pass + elif y: # Does not filter out default value + return ( + x #@ + ) + else: # Does not filter out default value + return ( + x #@ + ) + + def f2(y, x = {satisfy_val}): + if {condition}: + pass + elif y: # Filters out default value + return ( + x #@ + ) + else: # Filters out default value + return ( + x #@ + ) + """ + ) + for node in (node1, node2): + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + for node in (node3, node4): + inferred = node.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="x") +def test_if_reassignment_in_else( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is assigned to a failing value inside the else branch. + """ + node = builder.extract_node( + f""" + def f(x, y): + if {condition}: + return x + else: + if y: + x = {satisfy_val} + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + +@common_params(node="x") +def test_if_comprehension_shadow( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is shadowed by an inner comprehension scope. + """ + node = builder.extract_node( + f""" + def f(x): + if {condition}: + return [ + x #@ + for x in [{satisfy_val}, {fail_val}] + ] + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + + for actual, expected in zip(inferred, (satisfy_val, fail_val)): + assert isinstance(actual, nodes.Const) + assert actual.value == expected + + +@common_params(node="x") +def test_if_function_shadow(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test that constraint in an if condition doesn't apply when the variable + is shadowed by an inner function scope. + """ + node = builder.extract_node( + f""" + x = {satisfy_val} + if {condition}: + def f(x = {fail_val}): + return ( + x #@ + ) + """ + ) + inferred = node.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + +@common_params(node="x") +def test_if_function_call(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test that constraint in an if condition doesn't apply for a parameter + a different function call, but with the same name. + """ + node = builder.extract_node( + f""" + def f(x = {satisfy_val}): + if {condition}: + g({fail_val}) #@ + + def g(x): + return x + """ + ) + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + +@common_params(node="self.x") +def test_if_instance_attr(condition: str, satisfy_val: Any, fail_val: Any) -> None: + """Test constraint for an instance attribute in an if statement.""" + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self): + if {condition}: + self.x #@ + + class A2: + def __init__(self, x = {satisfy_val}): + self.x = x + + def method(self): + if {condition}: + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == satisfy_val + + assert inferred[1] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_reassignment_in_body( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + when it is assigned inside the if body. + """ + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x): + self.x = x + + def method1(self): + if {condition}: + self.x = {satisfy_val} + self.x #@ + + def method2(self): + if {condition}: + self.x = {fail_val} + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + inferred = node2.inferred() + assert len(inferred) == 3 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == satisfy_val + + assert isinstance(inferred[2], nodes.Const) + assert inferred[2].value == fail_val + + +@common_params(node="x") +def test_if_instance_attr_varname_collision1( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + when the constraint refers to a variable with the same name. + """ + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self, x = {fail_val}): + if {condition}: + x #@ + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision2( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply to a variable with the same name.""" + node1, node2 = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self, x = {fail_val}): + if {condition}: + x #@ + self.x #@ + """ + ) + + inferred = node1.inferred() + assert len(inferred) == 2 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + assert inferred[1] is Uninferable + + inferred = node2.inferred() + assert len(inferred) == 1 + assert inferred[0] is Uninferable + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision3( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply to an instance attribute + for an object of a different class. + """ + node = builder.extract_node( + f""" + class A1: + def __init__(self, x = {fail_val}): + self.x = x + + def method(self): + obj = A2() + if {condition}: + obj.x #@ + + class A2: + def __init__(self): + self.x = {fail_val} + """ + ) + + inferred = node.inferred() + assert len(inferred) == 1 + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == fail_val + + +@common_params(node="self.x") +def test_if_instance_attr_varname_collision4( + condition: str, satisfy_val: Any, fail_val: Any +) -> None: + """Test that constraint in an if condition doesn't apply to a variable of the same name, + when that variable is used to infer the value of the instance attribute. + """ + node = builder.extract_node( + f""" + class A1: + def __init__(self, x): + self.x = x + + def method(self): + x = {fail_val} + if {condition}: + self.x = x + self.x #@ + """ + ) + + inferred = node.inferred() + assert len(inferred) == 2 + assert inferred[0] is Uninferable + + assert isinstance(inferred[1], nodes.Const) + assert inferred[1].value == fail_val From 48287973db79cd1205eea66e63d448b621460da2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 20:26:21 +0100 Subject: [PATCH 02/21] Changes --- astroid/context.py | 9 ++-- tests/unittest_inference_constraints.py | 56 ++++++++++++++++--------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/astroid/context.py b/astroid/context.py index 70212e0244..a7d4879216 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple if TYPE_CHECKING: + from astroid import constraint from astroid.nodes.node_classes import Keyword, NodeNG _InferenceCache = Dict[ @@ -85,12 +86,8 @@ def __init__(self, path=None, nodes_inferred=None): for call arguments """ - self.constraints = {} - """ - :type: dict - - The constraints on nodes - """ + self.constraints: Dict[str, Dict["NodeNG", "constraint.Constraint"]] = {} + """The constraints on nodes.""" @property def nodes_inferred(self): diff --git a/tests/unittest_inference_constraints.py b/tests/unittest_inference_constraints.py index 9c78defb8c..5bf5ae1888 100644 --- a/tests/unittest_inference_constraints.py +++ b/tests/unittest_inference_constraints.py @@ -1,6 +1,10 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt + """Tests for inference involving constraints""" -from typing import Any +from typing import Optional import pytest @@ -8,7 +12,7 @@ from astroid.util import Uninferable -def common_params(node): +def common_params(node: str) -> pytest.MarkDecorator: return pytest.mark.parametrize( ("condition", "satisfy_val", "fail_val"), ( @@ -19,7 +23,9 @@ def common_params(node): @common_params(node="x") -def test_if_single_statement(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_if_single_statement( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test constraint for a variable that is used in the first statement of an if body.""" node1, node2 = builder.extract_node( f""" @@ -51,7 +57,7 @@ def f2(x = {satisfy_val}): @common_params(node="x") def test_if_multiple_statements( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test constraint for a variable that is used in an if body with multiple statements.""" node1, node2 = builder.extract_node( @@ -86,7 +92,7 @@ def f2(x = {satisfy_val}): @common_params(node="x") def test_if_irrelevant_condition( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint for a different variable doesn't apply.""" nodes_ = builder.extract_node( @@ -114,7 +120,9 @@ def f2(x, y = {satisfy_val}): @common_params(node="x") -def test_outside_if(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_outside_if( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test that constraint in an if condition doesn't apply outside of the if.""" nodes_ = builder.extract_node( f""" @@ -144,7 +152,9 @@ def f2(x = {satisfy_val}): @common_params(node="x") -def test_nested_if(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_nested_if( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test that constraint in an if condition applies within inner if statements.""" node1, node2 = builder.extract_node( f""" @@ -203,7 +213,7 @@ def f2(): @common_params(node="x") def test_if_reassignment_in_body( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply when the variable is assigned to a failing value inside the if body. @@ -228,7 +238,9 @@ def f(x, y): @common_params(node="x") -def test_if_elif_else_negates(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_if_elif_else_negates( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test that constraint in an if condition is negated when the variable is used in the elif and else branches. """ @@ -275,7 +287,7 @@ def f2(y, x = {satisfy_val}): @common_params(node="x") def test_if_reassignment_in_else( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply when the variable is assigned to a failing value inside the else branch. @@ -303,7 +315,7 @@ def f(x, y): @common_params(node="x") def test_if_comprehension_shadow( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply when the variable is shadowed by an inner comprehension scope. @@ -327,7 +339,9 @@ def f(x): @common_params(node="x") -def test_if_function_shadow(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_if_function_shadow( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test that constraint in an if condition doesn't apply when the variable is shadowed by an inner function scope. """ @@ -350,7 +364,9 @@ def f(x = {fail_val}): @common_params(node="x") -def test_if_function_call(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_if_function_call( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test that constraint in an if condition doesn't apply for a parameter a different function call, but with the same name. """ @@ -371,7 +387,9 @@ def g(x): @common_params(node="self.x") -def test_if_instance_attr(condition: str, satisfy_val: Any, fail_val: Any) -> None: +def test_if_instance_attr( + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] +) -> None: """Test constraint for an instance attribute in an if statement.""" node1, node2 = builder.extract_node( f""" @@ -407,7 +425,7 @@ def method(self): @common_params(node="self.x") def test_if_instance_attr_reassignment_in_body( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute when it is assigned inside the if body. @@ -450,7 +468,7 @@ def method2(self): @common_params(node="x") def test_if_instance_attr_varname_collision1( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute when the constraint refers to a variable with the same name. @@ -482,7 +500,7 @@ def method(self, x = {fail_val}): @common_params(node="self.x") def test_if_instance_attr_varname_collision2( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply to a variable with the same name.""" node1, node2 = builder.extract_node( @@ -512,7 +530,7 @@ def method(self, x = {fail_val}): @common_params(node="self.x") def test_if_instance_attr_varname_collision3( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute for an object of a different class. @@ -542,7 +560,7 @@ def __init__(self): @common_params(node="self.x") def test_if_instance_attr_varname_collision4( - condition: str, satisfy_val: Any, fail_val: Any + condition: str, satisfy_val: Optional[int], fail_val: Optional[int] ) -> None: """Test that constraint in an if condition doesn't apply to a variable of the same name, when that variable is used to infer the value of the instance attribute. From c39bafffb4efc4f396863e4e5c42e08dbb5c57c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 20:29:27 +0100 Subject: [PATCH 03/21] Some typing --- astroid/bases.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/astroid/bases.py b/astroid/bases.py index 40b2ad6bbb..5f2e5fb13b 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -7,6 +7,7 @@ """ import collections +from typing import Optional from astroid import decorators from astroid.const import PY310_PLUS @@ -114,7 +115,7 @@ def infer(self, context=None): yield self -def _infer_stmts(stmts, context, frame=None): +def _infer_stmts(stmts, context: Optional[InferenceContext], frame=None): """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False constraint_failed = False From b54d841329278f368507aeddb9928b4ab2938f2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Mar 2022 19:31:29 +0000 Subject: [PATCH 04/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- astroid/constraint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/astroid/constraint.py b/astroid/constraint.py index dc4216c557..0537faec9e 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -1,5 +1,8 @@ # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE """Classes representing different types of constraints on inference values.""" from typing import Any, Dict, Optional From de4c31836cd98fc4e252e51caf868eadbb1731cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 20:54:49 +0100 Subject: [PATCH 05/21] Some changes to typing and removal of dead code --- astroid/constraint.py | 40 ++++++++++++++++++++-------------------- astroid/context.py | 4 ++-- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index 0537faec9e..1dbe5fd13a 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -1,42 +1,39 @@ # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt -# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html -# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE + """Classes representing different types of constraints on inference values.""" -from typing import Any, Dict, Optional +from typing import Dict, Optional, Type, TypeVar, Union from astroid import nodes, util +NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] +ConstraintT = TypeVar("ConstraintT", bound="Constraint") + class Constraint: """Represents a single constraint on a variable.""" - node: nodes.NodeNG - negate: bool - def __init__(self, node: nodes.NodeNG, negate: bool) -> None: self.node = node """The node that this constraint applies to.""" self.negate = negate """True if this constraint is negated. E.g., "is not" instead of "is".""" - def invert(self) -> None: - """Invert this constraint.""" - self.negate = not self.negate - @classmethod def match( - cls, node: nodes.NodeNG, expr: nodes.NodeNG, negate: bool = False - ) -> Optional["Constraint"]: + cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Optional[ConstraintT]: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. If negate is True, negate the constraint. """ - def satisfied_by(self, inferred: Any) -> bool: + def satisfied_by( + self, inferred: Union[nodes.NodeNG, Type[util.Uninferable]] + ) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" return True @@ -46,8 +43,8 @@ class NoneConstraint(Constraint): @classmethod def match( - cls, node: nodes.NodeNG, expr: nodes.NodeNG, negate: bool = False - ) -> Optional[Constraint]: + cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Optional[ConstraintT]: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -68,8 +65,11 @@ def match( return None - def satisfied_by(self, inferred: Any) -> bool: + def satisfied_by( + self, inferred: Union[nodes.NodeNG, Type[util.Uninferable]] + ) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" + # Assume true if uninferable if inferred is util.Uninferable: return True @@ -94,8 +94,8 @@ def matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: def get_constraints( - expr: nodes.NodeNG, frame: nodes.NodeNG -) -> Dict[nodes.NodeNG, Constraint]: + expr: NameNodes, frame: nodes.LocalsDictNodeNG +) -> Dict[nodes.If, Constraint]: """Returns the constraints for the given expression. The returned dictionary maps the node where the constraint was generated to the @@ -104,7 +104,7 @@ def get_constraints( Constraints are computed statically by analysing the code surrounding expr. Currently this only supports constraints generated from if conditions. """ - current_node = expr + current_node: nodes.NodeNG = expr constraints = {} while current_node is not None and current_node is not frame: parent = current_node.parent @@ -129,7 +129,7 @@ def get_constraints( def match_constraint( - node: nodes.NodeNG, expr: nodes.NodeNG, invert: bool = False + node: NameNodes, expr: nodes.NodeNG, invert: bool = False ) -> Optional[Constraint]: """Returns a constraint pattern for node, if one matches.""" for constraint_cls in ALL_CONSTRAINTS: diff --git a/astroid/context.py b/astroid/context.py index a7d4879216..6b98fac832 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple if TYPE_CHECKING: - from astroid import constraint + from astroid import constraint, nodes from astroid.nodes.node_classes import Keyword, NodeNG _InferenceCache = Dict[ @@ -86,7 +86,7 @@ def __init__(self, path=None, nodes_inferred=None): for call arguments """ - self.constraints: Dict[str, Dict["NodeNG", "constraint.Constraint"]] = {} + self.constraints: Dict[str, Dict["nodes.If", "constraint.Constraint"]] = {} """The constraints on nodes.""" @property From 3da1c83bf6d3b3d2569915c1ffaf96cfc8790f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 21:07:31 +0100 Subject: [PATCH 06/21] Fix on python 3.7 --- tests/unittest_inference_constraints.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unittest_inference_constraints.py b/tests/unittest_inference_constraints.py index 5bf5ae1888..1939a1b7a8 100644 --- a/tests/unittest_inference_constraints.py +++ b/tests/unittest_inference_constraints.py @@ -12,7 +12,7 @@ from astroid.util import Uninferable -def common_params(node: str) -> pytest.MarkDecorator: +def common_params(node: str) -> "pytest.MarkDecorator": return pytest.mark.parametrize( ("condition", "satisfy_val", "fail_val"), ( @@ -330,6 +330,10 @@ def f(x): ] """ ) + # Hack for Python 3.7 where the ListComp starts on L5 instead of L4 + # Extract_node doesn't handle this correctly + if isinstance(node, nodes.ListComp): + node = node.elt inferred = node.inferred() assert len(inferred) == 2 From 7fe82e6b42c0b8298adddaa08b47d812c0b8a18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 21:10:54 +0100 Subject: [PATCH 07/21] Fix import --- astroid/inference.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/astroid/inference.py b/astroid/inference.py index 683ab30003..21705038a9 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -25,9 +25,7 @@ import wrapt -from astroid import bases -from astroid import constraint as constraintmod -from astroid import decorators, helpers, nodes, protocols, util +from astroid import bases, constraint, decorators, helpers, nodes, protocols, util from astroid.context import ( CallContext, InferenceContext, @@ -216,7 +214,7 @@ def infer_name(self, context=None): ) context = copy_context(context) context.lookupname = self.name - context.constraints[self.name] = constraintmod.get_constraints(self, frame) + context.constraints[self.name] = constraint.get_constraints(self, frame) return bases._infer_stmts(stmts, context, frame) @@ -324,7 +322,7 @@ def infer_attribute(self, context=None): context.boundnode = owner if isinstance(owner, (nodes.ClassDef, bases.Instance)): frame = owner if isinstance(owner, nodes.ClassDef) else owner._proxied - context.constraints[self.attrname] = constraintmod.get_constraints( + context.constraints[self.attrname] = constraint.get_constraints( self, frame=frame ) yield from owner.igetattr(self.attrname, context) From fc4156e14b71f34494879be674f2e63461850c43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Mar 2022 21:27:29 +0100 Subject: [PATCH 08/21] Add abstractness --- astroid/constraint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index 1dbe5fd13a..00e4d7690e 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -4,6 +4,7 @@ """Classes representing different types of constraints on inference values.""" +from abc import abstractmethod from typing import Dict, Optional, Type, TypeVar, Union from astroid import nodes, util @@ -22,6 +23,7 @@ def __init__(self, node: nodes.NodeNG, negate: bool) -> None: """True if this constraint is negated. E.g., "is not" instead of "is".""" @classmethod + @abstractmethod def match( cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False ) -> Optional[ConstraintT]: @@ -31,11 +33,11 @@ def match( If negate is True, negate the constraint. """ + @abstractmethod def satisfied_by( self, inferred: Union[nodes.NodeNG, Type[util.Uninferable]] ) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" - return True class NoneConstraint(Constraint): From 55bb33f70def804e73e3157ebb3b1fa04c3f44c4 Mon Sep 17 00:00:00 2001 From: Pierre Sassoulas Date: Tue, 22 Mar 2022 06:46:10 +0100 Subject: [PATCH 09/21] Update ChangeLog --- ChangeLog | 1 - 1 file changed, 1 deletion(-) diff --git a/ChangeLog b/ChangeLog index 84ce9fd246..44fe0aa7f6 100644 --- a/ChangeLog +++ b/ChangeLog @@ -19,7 +19,6 @@ What's New in astroid 2.11.1? ============================= Release date: TBA - * Promoted ``getattr()`` from ``astroid.scoped_nodes.FunctionDef`` to its parent ``astroid.scoped_nodes.Lambda``. From 8a3b44322aa5874aa78fa98f5375a97a937d6267 Mon Sep 17 00:00:00 2001 From: David Liu Date: Wed, 23 Mar 2022 12:36:56 -0400 Subject: [PATCH 10/21] Simplify expression NoneConstraint.match --- astroid/constraint.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index 00e4d7690e..c84fd313c9 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -75,12 +75,7 @@ def satisfied_by( if inferred is util.Uninferable: return True - if self.negate and matches(inferred, nodes.Const(None)): - return False - if not self.negate and not matches(inferred, nodes.Const(None)): - return False - - return True + return self.negate ^ matches(inferred, nodes.Const(None)) def matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: From e032c0b08ab66563d0417186d6844b002f60e439 Mon Sep 17 00:00:00 2001 From: David Liu Date: Fri, 17 Jun 2022 16:18:46 -0400 Subject: [PATCH 11/21] move constraint.py into submodule, and small fixes --- astroid/constraint/__init__.py | 6 ++++ astroid/{ => constraint}/constraint.py | 35 +++++++++--------- ..._constraints.py => unittest_constraint.py} | 36 +++++++++---------- 3 files changed, 41 insertions(+), 36 deletions(-) create mode 100644 astroid/constraint/__init__.py rename astroid/{ => constraint}/constraint.py (86%) rename tests/{unittest_inference_constraints.py => unittest_constraint.py} (92%) diff --git a/astroid/constraint/__init__.py b/astroid/constraint/__init__.py new file mode 100644 index 0000000000..0cfba6a2f7 --- /dev/null +++ b/astroid/constraint/__init__.py @@ -0,0 +1,6 @@ +# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html +# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE +# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt + +"""Representations of logical constraints on values used during inference. +""" diff --git a/astroid/constraint.py b/astroid/constraint/constraint.py similarity index 86% rename from astroid/constraint.py rename to astroid/constraint/constraint.py index c84fd313c9..2d09121625 100644 --- a/astroid/constraint.py +++ b/astroid/constraint/constraint.py @@ -3,13 +3,17 @@ # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt """Classes representing different types of constraints on inference values.""" +from __future__ import annotations from abc import abstractmethod -from typing import Dict, Optional, Type, TypeVar, Union +from typing import TypeVar from astroid import nodes, util -NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] +__all__ = ["get_constraints"] + + +NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name ConstraintT = TypeVar("ConstraintT", bound="Constraint") @@ -26,7 +30,7 @@ def __init__(self, node: nodes.NodeNG, negate: bool) -> None: @abstractmethod def match( cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> Optional[ConstraintT]: + ) -> ConstraintT | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -34,19 +38,19 @@ def match( """ @abstractmethod - def satisfied_by( - self, inferred: Union[nodes.NodeNG, Type[util.Uninferable]] - ) -> bool: + def satisfied_by(self, inferred: nodes.NodeNG | type[util.Uninferable]) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" class NoneConstraint(Constraint): """Represents an "is None" or "is not None" constraint.""" + CONST_NONE: nodes.Const = nodes.Const(None) + @classmethod def match( cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> Optional[ConstraintT]: + ) -> ConstraintT | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -55,11 +59,10 @@ def match( if isinstance(expr, nodes.Compare) and len(expr.ops) == 1: left = expr.left op, right = expr.ops[0] - const_none = nodes.Const(None) if op in {"is", "is not"} and ( matches(left, node) - and matches(right, const_none) - or matches(left, const_none) + and matches(right, cls.CONST_NONE) + or matches(left, cls.CONST_NONE) and matches(right, node) ): negate = (op == "is" and negate) or (op == "is not" and not negate) @@ -67,9 +70,7 @@ def match( return None - def satisfied_by( - self, inferred: Union[nodes.NodeNG, Type[util.Uninferable]] - ) -> bool: + def satisfied_by(self, inferred: nodes.NodeNG | type[util.Uninferable]) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" # Assume true if uninferable if inferred is util.Uninferable: @@ -92,7 +93,7 @@ def matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: def get_constraints( expr: NameNodes, frame: nodes.LocalsDictNodeNG -) -> Dict[nodes.If, Constraint]: +) -> dict[nodes.If, Constraint]: """Returns the constraints for the given expression. The returned dictionary maps the node where the constraint was generated to the @@ -101,8 +102,8 @@ def get_constraints( Constraints are computed statically by analysing the code surrounding expr. Currently this only supports constraints generated from if conditions. """ - current_node: nodes.NodeNG = expr - constraints = {} + current_node = expr + constraints: dict[nodes.If, Constraint] = {} while current_node is not None and current_node is not frame: parent = current_node.parent if isinstance(parent, nodes.If): @@ -127,7 +128,7 @@ def get_constraints( def match_constraint( node: NameNodes, expr: nodes.NodeNG, invert: bool = False -) -> Optional[Constraint]: +) -> Constraint | None: """Returns a constraint pattern for node, if one matches.""" for constraint_cls in ALL_CONSTRAINTS: constraint = constraint_cls.match(node, expr, invert) diff --git a/tests/unittest_inference_constraints.py b/tests/unittest_constraint.py similarity index 92% rename from tests/unittest_inference_constraints.py rename to tests/unittest_constraint.py index 1939a1b7a8..4a1e0feb5f 100644 --- a/tests/unittest_inference_constraints.py +++ b/tests/unittest_constraint.py @@ -4,8 +4,6 @@ """Tests for inference involving constraints""" -from typing import Optional - import pytest from astroid import builder, nodes @@ -24,7 +22,7 @@ def common_params(node: str) -> "pytest.MarkDecorator": @common_params(node="x") def test_if_single_statement( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test constraint for a variable that is used in the first statement of an if body.""" node1, node2 = builder.extract_node( @@ -57,7 +55,7 @@ def f2(x = {satisfy_val}): @common_params(node="x") def test_if_multiple_statements( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test constraint for a variable that is used in an if body with multiple statements.""" node1, node2 = builder.extract_node( @@ -92,7 +90,7 @@ def f2(x = {satisfy_val}): @common_params(node="x") def test_if_irrelevant_condition( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint for a different variable doesn't apply.""" nodes_ = builder.extract_node( @@ -121,7 +119,7 @@ def f2(x, y = {satisfy_val}): @common_params(node="x") def test_outside_if( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply outside of the if.""" nodes_ = builder.extract_node( @@ -153,7 +151,7 @@ def f2(x = {satisfy_val}): @common_params(node="x") def test_nested_if( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition applies within inner if statements.""" node1, node2 = builder.extract_node( @@ -213,7 +211,7 @@ def f2(): @common_params(node="x") def test_if_reassignment_in_body( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply when the variable is assigned to a failing value inside the if body. @@ -239,7 +237,7 @@ def f(x, y): @common_params(node="x") def test_if_elif_else_negates( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition is negated when the variable is used in the elif and else branches. @@ -287,7 +285,7 @@ def f2(y, x = {satisfy_val}): @common_params(node="x") def test_if_reassignment_in_else( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply when the variable is assigned to a failing value inside the else branch. @@ -315,7 +313,7 @@ def f(x, y): @common_params(node="x") def test_if_comprehension_shadow( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply when the variable is shadowed by an inner comprehension scope. @@ -344,7 +342,7 @@ def f(x): @common_params(node="x") def test_if_function_shadow( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply when the variable is shadowed by an inner function scope. @@ -369,7 +367,7 @@ def f(x = {fail_val}): @common_params(node="x") def test_if_function_call( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply for a parameter a different function call, but with the same name. @@ -392,7 +390,7 @@ def g(x): @common_params(node="self.x") def test_if_instance_attr( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test constraint for an instance attribute in an if statement.""" node1, node2 = builder.extract_node( @@ -429,7 +427,7 @@ def method(self): @common_params(node="self.x") def test_if_instance_attr_reassignment_in_body( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute when it is assigned inside the if body. @@ -472,7 +470,7 @@ def method2(self): @common_params(node="x") def test_if_instance_attr_varname_collision1( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute when the constraint refers to a variable with the same name. @@ -504,7 +502,7 @@ def method(self, x = {fail_val}): @common_params(node="self.x") def test_if_instance_attr_varname_collision2( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply to a variable with the same name.""" node1, node2 = builder.extract_node( @@ -534,7 +532,7 @@ def method(self, x = {fail_val}): @common_params(node="self.x") def test_if_instance_attr_varname_collision3( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply to an instance attribute for an object of a different class. @@ -564,7 +562,7 @@ def __init__(self): @common_params(node="self.x") def test_if_instance_attr_varname_collision4( - condition: str, satisfy_val: Optional[int], fail_val: Optional[int] + condition: str, satisfy_val: int | None, fail_val: int | None ) -> None: """Test that constraint in an if condition doesn't apply to a variable of the same name, when that variable is used to infer the value of the instance attribute. From c1268a2228ee5e12967cca11316522fd179218df Mon Sep 17 00:00:00 2001 From: David Liu Date: Fri, 17 Jun 2022 16:22:43 -0400 Subject: [PATCH 12/21] Fix imports and __all__ --- astroid/constraint/__init__.py | 4 ++++ astroid/constraint/constraint.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/astroid/constraint/__init__.py b/astroid/constraint/__init__.py index 0cfba6a2f7..aad816b9f9 100644 --- a/astroid/constraint/__init__.py +++ b/astroid/constraint/__init__.py @@ -4,3 +4,7 @@ """Representations of logical constraints on values used during inference. """ + +from .constraint import get_constraints + +__all__ = ("get_constraints",) diff --git a/astroid/constraint/constraint.py b/astroid/constraint/constraint.py index 2d09121625..dfeb629a1e 100644 --- a/astroid/constraint/constraint.py +++ b/astroid/constraint/constraint.py @@ -10,7 +10,7 @@ from astroid import nodes, util -__all__ = ["get_constraints"] +__all__ = ("get_constraints",) NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name From 7d0b6b29541b2add2c813c67208cbf9bcb8327cc Mon Sep 17 00:00:00 2001 From: David Liu Date: Fri, 17 Jun 2022 16:26:15 -0400 Subject: [PATCH 13/21] Fix Union usage in type definition --- astroid/constraint/constraint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astroid/constraint/constraint.py b/astroid/constraint/constraint.py index dfeb629a1e..38695d5520 100644 --- a/astroid/constraint/constraint.py +++ b/astroid/constraint/constraint.py @@ -6,14 +6,14 @@ from __future__ import annotations from abc import abstractmethod -from typing import TypeVar +from typing import TypeVar, Union from astroid import nodes, util __all__ = ("get_constraints",) -NameNodes = nodes.AssignAttr | nodes.Attribute | nodes.AssignName | nodes.Name +NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] ConstraintT = TypeVar("ConstraintT", bound="Constraint") From f96372335ab8a31a87c788b464df0b94dbff60db Mon Sep 17 00:00:00 2001 From: David Liu Date: Fri, 17 Jun 2022 16:28:04 -0400 Subject: [PATCH 14/21] Add __future__ import to unittest_constraint.py --- tests/unittest_constraint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unittest_constraint.py b/tests/unittest_constraint.py index 4a1e0feb5f..cba506fdcb 100644 --- a/tests/unittest_constraint.py +++ b/tests/unittest_constraint.py @@ -3,6 +3,7 @@ # Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt """Tests for inference involving constraints""" +from __future__ import annotations import pytest @@ -10,7 +11,7 @@ from astroid.util import Uninferable -def common_params(node: str) -> "pytest.MarkDecorator": +def common_params(node: str) -> pytest.MarkDecorator: return pytest.mark.parametrize( ("condition", "satisfy_val", "fail_val"), ( From 65c3c558d89d9534b320e2ec97a60a4c01bfd82c Mon Sep 17 00:00:00 2001 From: David Liu Date: Mon, 25 Jul 2022 13:49:18 -0400 Subject: [PATCH 15/21] Updates based on review --- astroid/bases.py | 7 ++- astroid/{constraint => }/constraint.py | 71 ++++++++++++-------------- astroid/constraint/__init__.py | 10 ---- 3 files changed, 39 insertions(+), 49 deletions(-) rename astroid/{constraint => }/constraint.py (73%) delete mode 100644 astroid/constraint/__init__.py diff --git a/astroid/bases.py b/astroid/bases.py index 4326e2d6ca..03631ff067 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -11,7 +11,7 @@ import collections.abc import sys from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING, Any from astroid import decorators, nodes from astroid.const import PY310_PLUS @@ -35,6 +35,9 @@ else: from typing_extensions import Literal +if TYPE_CHECKING: + from astroid.constraint import Constraint + objectmodel = lazy_import("interpreter.objectmodel") helpers = lazy_import("helpers") manager = lazy_import("manager") @@ -147,6 +150,7 @@ def _infer_stmts( """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False constraint_failed = False + constraints: dict[str, dict[nodes.If, Constraint]] = {} if context is not None: name = context.lookupname context = context.clone() @@ -154,7 +158,6 @@ def _infer_stmts( else: name = None context = InferenceContext() - constraints = {} for stmt in stmts: if stmt is Uninferable: diff --git a/astroid/constraint/constraint.py b/astroid/constraint.py similarity index 73% rename from astroid/constraint/constraint.py rename to astroid/constraint.py index 38695d5520..9cf1cf5656 100644 --- a/astroid/constraint/constraint.py +++ b/astroid/constraint.py @@ -5,19 +5,16 @@ """Classes representing different types of constraints on inference values.""" from __future__ import annotations -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import TypeVar, Union from astroid import nodes, util -__all__ = ("get_constraints",) +_NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] +_ConstraintT = TypeVar("_ConstraintT", bound="Constraint") -NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] -ConstraintT = TypeVar("ConstraintT", bound="Constraint") - - -class Constraint: +class Constraint(ABC): """Represents a single constraint on a variable.""" def __init__(self, node: nodes.NodeNG, negate: bool) -> None: @@ -29,8 +26,8 @@ def __init__(self, node: nodes.NodeNG, negate: bool) -> None: @classmethod @abstractmethod def match( - cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> ConstraintT | None: + cls: _ConstraintT, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> _ConstraintT | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -49,8 +46,8 @@ class NoneConstraint(Constraint): @classmethod def match( - cls: ConstraintT, node: NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> ConstraintT | None: + cls: _ConstraintT, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> _ConstraintT | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -60,10 +57,10 @@ def match( left = expr.left op, right = expr.ops[0] if op in {"is", "is not"} and ( - matches(left, node) - and matches(right, cls.CONST_NONE) - or matches(left, cls.CONST_NONE) - and matches(right, node) + _matches(left, node) + and _matches(right, cls.CONST_NONE) + or _matches(left, cls.CONST_NONE) + and _matches(right, node) ): negate = (op == "is" and negate) or (op == "is not" and not negate) return cls(node=node, negate=negate) @@ -76,23 +73,12 @@ def satisfied_by(self, inferred: nodes.NodeNG | type[util.Uninferable]) -> bool: if inferred is util.Uninferable: return True - return self.negate ^ matches(inferred, nodes.Const(None)) - - -def matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: - """Returns True if the two nodes match.""" - if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): - return node1.name == node2.name - if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute): - return node1.attrname == node2.attrname and matches(node1.expr, node2.expr) - if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const): - return node1.value == node2.value - - return False + # Return the XOR of self.negate and matches(inferred, self.CONST_NONE) + return self.negate ^ _matches(inferred, self.CONST_NONE) def get_constraints( - expr: NameNodes, frame: nodes.LocalsDictNodeNG + expr: _NameNodes, frame: nodes.LocalsDictNodeNG ) -> dict[nodes.If, Constraint]: """Returns the constraints for the given expression. @@ -108,12 +94,11 @@ def get_constraints( parent = current_node.parent if isinstance(parent, nodes.If): branch, _ = parent.locate_child(current_node) + constraint: Constraint | None = None if branch == "body": - constraint = match_constraint(expr, parent.test) + constraint = _match_constraint(expr, parent.test) elif branch == "orelse": - constraint = match_constraint(expr, parent.test, invert=True) - else: - constraint = None + constraint = _match_constraint(expr, parent.test, invert=True) if constraint: constraints[parent] = constraint @@ -122,15 +107,27 @@ def get_constraints( return constraints -ALL_CONSTRAINTS = (NoneConstraint,) +ALL_CONSTRAINT_CLASSES = (NoneConstraint,) """All supported constraint types.""" -def match_constraint( - node: NameNodes, expr: nodes.NodeNG, invert: bool = False +def _matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: + """Returns True if the two nodes match.""" + if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): + return node1.name == node2.name + if isinstance(node1, nodes.Attribute) and isinstance(node2, nodes.Attribute): + return node1.attrname == node2.attrname and _matches(node1.expr, node2.expr) + if isinstance(node1, nodes.Const) and isinstance(node2, nodes.Const): + return node1.value == node2.value + + return False + + +def _match_constraint( + node: _NameNodes, expr: nodes.NodeNG, invert: bool = False ) -> Constraint | None: """Returns a constraint pattern for node, if one matches.""" - for constraint_cls in ALL_CONSTRAINTS: + for constraint_cls in ALL_CONSTRAINT_CLASSES: constraint = constraint_cls.match(node, expr, invert) if constraint: return constraint diff --git a/astroid/constraint/__init__.py b/astroid/constraint/__init__.py deleted file mode 100644 index aad816b9f9..0000000000 --- a/astroid/constraint/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html -# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE -# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt - -"""Representations of logical constraints on values used during inference. -""" - -from .constraint import get_constraints - -__all__ = ("get_constraints",) From f6e001bcb359660373d0953d4aa325b098ebf0fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:49:29 +0100 Subject: [PATCH 16/21] Use Self --- astroid/constraint.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index 9cf1cf5656..c7d98cb2aa 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -5,11 +5,17 @@ """Classes representing different types of constraints on inference values.""" from __future__ import annotations +import sys from abc import ABC, abstractmethod from typing import TypeVar, Union from astroid import nodes, util +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + _NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] _ConstraintT = TypeVar("_ConstraintT", bound="Constraint") @@ -26,8 +32,8 @@ def __init__(self, node: nodes.NodeNG, negate: bool) -> None: @classmethod @abstractmethod def match( - cls: _ConstraintT, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> _ConstraintT | None: + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. @@ -46,8 +52,8 @@ class NoneConstraint(Constraint): @classmethod def match( - cls: _ConstraintT, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False - ) -> _ConstraintT | None: + cls: type[Self], node: _NameNodes, expr: nodes.NodeNG, negate: bool = False + ) -> Self | None: """Return a new constraint for node matched from expr, if expr matches the constraint pattern. From 5da24fe5e3bdf86b4b54024d1e7ff302def608e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:50:35 +0100 Subject: [PATCH 17/21] Use InferenceResult --- astroid/constraint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index c7d98cb2aa..8b9dcc6b58 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -10,6 +10,7 @@ from typing import TypeVar, Union from astroid import nodes, util +from astroid.typing import InferenceResult if sys.version_info >= (3, 11): from typing import Self @@ -41,7 +42,7 @@ def match( """ @abstractmethod - def satisfied_by(self, inferred: nodes.NodeNG | type[util.Uninferable]) -> bool: + def satisfied_by(self, inferred: InferenceResult) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" @@ -73,7 +74,7 @@ def match( return None - def satisfied_by(self, inferred: nodes.NodeNG | type[util.Uninferable]) -> bool: + def satisfied_by(self, inferred: InferenceResult) -> bool: """Return True if this constraint is satisfied by the given inferred value.""" # Assume true if uninferable if inferred is util.Uninferable: From b46f00149cf41ba2b8dfc9607f4d7ce0fa5baebc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:53:19 +0100 Subject: [PATCH 18/21] Don't check None is --- astroid/constraint.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index 8b9dcc6b58..df8f7a093e 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -64,10 +64,7 @@ def match( left = expr.left op, right = expr.ops[0] if op in {"is", "is not"} and ( - _matches(left, node) - and _matches(right, cls.CONST_NONE) - or _matches(left, cls.CONST_NONE) - and _matches(right, node) + _matches(left, node) and _matches(right, cls.CONST_NONE) ): negate = (op == "is" and negate) or (op == "is not" and not negate) return cls(node=node, negate=negate) From 409543b3908fc2b62af5f6da2fcd9f3cc801aed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:53:57 +0100 Subject: [PATCH 19/21] Use frozenset Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com> --- astroid/constraint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index df8f7a093e..e91ba0f352 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -111,7 +111,7 @@ def get_constraints( return constraints -ALL_CONSTRAINT_CLASSES = (NoneConstraint,) +ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,)) """All supported constraint types.""" From 316ba467813a688e164b2ac59907f189ba895c4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Sun, 20 Nov 2022 23:55:19 +0100 Subject: [PATCH 20/21] Also accept Proxy --- astroid/constraint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/astroid/constraint.py b/astroid/constraint.py index e91ba0f352..a861069118 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from typing import TypeVar, Union -from astroid import nodes, util +from astroid import bases, nodes, util from astroid.typing import InferenceResult if sys.version_info >= (3, 11): @@ -115,7 +115,7 @@ def get_constraints( """All supported constraint types.""" -def _matches(node1: nodes.NodeNG, node2: nodes.NodeNG) -> bool: +def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool: """Returns True if the two nodes match.""" if isinstance(node1, nodes.Name) and isinstance(node2, nodes.Name): return node1.name == node2.name From dabf0e93badb13b0aab8b66226467597f5d348a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 21 Nov 2022 00:13:53 +0100 Subject: [PATCH 21/21] Use an Iterator --- astroid/bases.py | 11 +++++------ astroid/constraint.py | 32 +++++++++++++++----------------- astroid/context.py | 2 +- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/astroid/bases.py b/astroid/bases.py index 03631ff067..c4780c7ac2 100644 --- a/astroid/bases.py +++ b/astroid/bases.py @@ -150,13 +150,13 @@ def _infer_stmts( """Return an iterator on statements inferred by each statement in *stmts*.""" inferred = False constraint_failed = False - constraints: dict[str, dict[nodes.If, Constraint]] = {} if context is not None: name = context.lookupname context = context.clone() constraints = context.constraints.get(name, {}) else: name = None + constraints = {} context = InferenceContext() for stmt in stmts: @@ -167,11 +167,10 @@ def _infer_stmts( # 'context' is always InferenceContext and Instances get '_infer_name' from ClassDef context.lookupname = stmt._infer_name(frame, name) # type: ignore[union-attr] try: - stmt_constraints = { - constraint - for constraint_stmt, constraint in constraints.items() - if not constraint_stmt.parent_of(stmt) - } + stmt_constraints: set[Constraint] = set() + for constraint_stmt, potential_constraints in constraints.items(): + if not constraint_stmt.parent_of(stmt): + stmt_constraints.update(potential_constraints) # Mypy doesn't recognize that 'stmt' can't be Uninferable for inf in stmt.infer(context=context): # type: ignore[union-attr] if all(constraint.satisfied_by(inf) for constraint in stmt_constraints): diff --git a/astroid/constraint.py b/astroid/constraint.py index a861069118..deed9ac52b 100644 --- a/astroid/constraint.py +++ b/astroid/constraint.py @@ -7,7 +7,8 @@ import sys from abc import ABC, abstractmethod -from typing import TypeVar, Union +from collections.abc import Iterator +from typing import Union from astroid import bases, nodes, util from astroid.typing import InferenceResult @@ -18,7 +19,6 @@ from typing_extensions import Self _NameNodes = Union[nodes.AssignAttr, nodes.Attribute, nodes.AssignName, nodes.Name] -_ConstraintT = TypeVar("_ConstraintT", bound="Constraint") class Constraint(ABC): @@ -83,32 +83,32 @@ def satisfied_by(self, inferred: InferenceResult) -> bool: def get_constraints( expr: _NameNodes, frame: nodes.LocalsDictNodeNG -) -> dict[nodes.If, Constraint]: +) -> dict[nodes.If, set[Constraint]]: """Returns the constraints for the given expression. The returned dictionary maps the node where the constraint was generated to the - corresponding constraint. + corresponding constraint(s). Constraints are computed statically by analysing the code surrounding expr. Currently this only supports constraints generated from if conditions. """ - current_node = expr - constraints: dict[nodes.If, Constraint] = {} + current_node: nodes.NodeNG | None = expr + constraints_mapping: dict[nodes.If, set[Constraint]] = {} while current_node is not None and current_node is not frame: parent = current_node.parent if isinstance(parent, nodes.If): branch, _ = parent.locate_child(current_node) - constraint: Constraint | None = None + constraints: set[Constraint] | None = None if branch == "body": - constraint = _match_constraint(expr, parent.test) + constraints = set(_match_constraint(expr, parent.test)) elif branch == "orelse": - constraint = _match_constraint(expr, parent.test, invert=True) + constraints = set(_match_constraint(expr, parent.test, invert=True)) - if constraint: - constraints[parent] = constraint + if constraints: + constraints_mapping[parent] = constraints current_node = parent - return constraints + return constraints_mapping ALL_CONSTRAINT_CLASSES = frozenset((NoneConstraint,)) @@ -129,11 +129,9 @@ def _matches(node1: nodes.NodeNG | bases.Proxy, node2: nodes.NodeNG) -> bool: def _match_constraint( node: _NameNodes, expr: nodes.NodeNG, invert: bool = False -) -> Constraint | None: - """Returns a constraint pattern for node, if one matches.""" +) -> Iterator[Constraint]: + """Yields all constraint patterns for node that match.""" for constraint_cls in ALL_CONSTRAINT_CLASSES: constraint = constraint_cls.match(node, expr, invert) if constraint: - return constraint - - return None + yield constraint diff --git a/astroid/context.py b/astroid/context.py index 9357c45bc9..7291e95140 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -82,7 +82,7 @@ def __init__(self, path=None, nodes_inferred=None): for call arguments """ - self.constraints: dict[str, dict[nodes.If, constraint.Constraint]] = {} + self.constraints: dict[str, dict[nodes.If, set[constraint.Constraint]]] = {} """The constraints on nodes.""" @property