Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle inference of ComprehensionScope better #1475

Closed
wants to merge 11 commits into from
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ What's New in astroid 2.12.0?
=============================
Release date: TBA

* Infer comprehensions and generators as their respective nodes instead of
as ``Uninferable``.

Closes #135, #1404


What's New in astroid 2.11.2?
Expand Down
7 changes: 6 additions & 1 deletion astroid/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""


from typing import Sequence

from astroid import bases, manager, nodes, raw_building, util
from astroid.context import CallContext, InferenceContext
from astroid.exceptions import (
Expand Down Expand Up @@ -55,6 +57,9 @@ def _object_type(node, context=None):
yield _function_type(inferred, builtins)
elif isinstance(inferred, scoped_nodes.Module):
yield _build_proxy_class("module", builtins)
# TODO: Implement type() lookup for ComprehenscopScope, perhaps through _proxied
elif isinstance(inferred, nodes.ComprehensionScope):
raise InferenceError
else:
yield inferred._proxied

Expand All @@ -80,7 +85,7 @@ def object_type(node, context=None):

def _object_type_is_subclass(obj_type, class_or_seq, context=None):
if not isinstance(class_or_seq, (tuple, list)):
class_seq = (class_or_seq,)
class_seq: Sequence = (class_or_seq,)
else:
class_seq = class_or_seq

Expand Down
21 changes: 20 additions & 1 deletion astroid/nodes/scoped_nodes/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@

"""This module contains mixin classes for scoped nodes."""

from typing import TYPE_CHECKING, Dict, List, TypeVar
import abc
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, TypeVar

from astroid.filter_statements import _filter_stmts
from astroid.nodes import node_classes, scoped_nodes
from astroid.nodes.scoped_nodes.utils import builtin_lookup

if TYPE_CHECKING:
from astroid import nodes
from astroid.context import InferenceContext


_T = TypeVar("_T")

Expand Down Expand Up @@ -170,3 +173,19 @@ class ComprehensionScope(LocalsDictNodeNG):
"""Scoping for different types of comprehensions."""

scope_lookup = LocalsDictNodeNG._scope_lookup

generators: List["nodes.Comprehension"]
"""The generators that are looped through."""

def qname(self) -> str:
"""Get the 'qualified' name of the node."""
return self.pytype()
jacobtylerwalls marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines +179 to +181
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cdce8p do you have thoughts about this? In the past I believe you were a vote against, see #1284 (review).


def infer(
self: _T, context: Optional["InferenceContext"] = None, **kwargs: Any
) -> Iterator[_T]:
yield self

@abc.abstractmethod
def pytype(self) -> str:
pass
56 changes: 26 additions & 30 deletions astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import typing
import warnings
from typing import Dict, List, Optional, Set, TypeVar, Union, overload
from typing import TYPE_CHECKING, Dict, List, Optional, Set, TypeVar, Union, overload

from astroid import bases
from astroid import decorators as decorators_mod
Expand Down Expand Up @@ -58,6 +58,9 @@

from astroid.decorators import cachedproperty as cached_property

if TYPE_CHECKING:
from astroid import nodes


ITER_METHODS = ("__iter__", "__getitem__")
EXCEPTION_BASE_CLASSES = frozenset({"Exception", "BaseException"})
Expand Down Expand Up @@ -674,11 +677,6 @@ class GeneratorExp(ComprehensionScope):

:type: NodeNG or None
"""
generators = None
"""The generators that are looped through.

:type: list(Comprehension) or None
"""

def __init__(
self,
Expand Down Expand Up @@ -721,14 +719,13 @@ def __init__(
parent=parent,
)

def postinit(self, elt=None, generators=None):
def postinit(self, elt=None, generators: Optional["nodes.Comprehension"] = None):
jacobtylerwalls marked this conversation as resolved.
Show resolved Hide resolved
"""Do some setup after initialisation.

:param elt: The element that forms the output of the expression.
:type elt: NodeNG or None

:param generators: The generators that are looped through.
:type generators: list(Comprehension) or None
"""
self.elt = elt
if generators is None:
Expand All @@ -750,6 +747,9 @@ def get_children(self):

yield from self.generators

def pytype(self) -> str:
return "builtins.generator"


class DictComp(ComprehensionScope):
"""Class representing an :class:`ast.DictComp` node.
Expand All @@ -772,11 +772,6 @@ class DictComp(ComprehensionScope):

:type: NodeNG or None
"""
generators = None
"""The generators that are looped through.

:type: list(Comprehension) or None
"""

def __init__(
self,
Expand Down Expand Up @@ -819,7 +814,9 @@ def __init__(
parent=parent,
)

def postinit(self, key=None, value=None, generators=None):
def postinit(
self, key=None, value=None, generators: Optional["nodes.Comprehension"] = None
):
"""Do some setup after initialisation.

:param key: What produces the keys.
Expand All @@ -829,7 +826,6 @@ def postinit(self, key=None, value=None, generators=None):
:type value: NodeNG or None

:param generators: The generators that are looped through.
:type generators: list(Comprehension) or None
"""
self.key = key
self.value = value
Expand All @@ -853,6 +849,9 @@ def get_children(self):

yield from self.generators

def pytype(self) -> str:
return "builtins.dict"


class SetComp(ComprehensionScope):
"""Class representing an :class:`ast.SetComp` node.
Expand All @@ -870,11 +869,6 @@ class SetComp(ComprehensionScope):

:type: NodeNG or None
"""
generators = None
"""The generators that are looped through.

:type: list(Comprehension) or None
"""

def __init__(
self,
Expand Down Expand Up @@ -917,14 +911,13 @@ def __init__(
parent=parent,
)

def postinit(self, elt=None, generators=None):
def postinit(self, elt=None, generators: Optional["nodes.Comprehension"] = None):
"""Do some setup after initialisation.

:param elt: The element that forms the output of the expression.
:type elt: NodeNG or None

:param generators: The generators that are looped through.
:type generators: list(Comprehension) or None
"""
self.elt = elt
if generators is None:
Expand All @@ -946,6 +939,9 @@ def get_children(self):

yield from self.generators

def pytype(self) -> str:
return "builtins.set"


class ListComp(ComprehensionScope):
"""Class representing an :class:`ast.ListComp` node.
Expand All @@ -965,12 +961,6 @@ class ListComp(ComprehensionScope):
:type: NodeNG or None
"""

generators = None
"""The generators that are looped through.

:type: list(Comprehension) or None
"""

def __init__(
self,
lineno=None,
Expand All @@ -994,7 +984,7 @@ def __init__(
parent=parent,
)

def postinit(self, elt=None, generators=None):
def postinit(self, elt=None, generators: Optional["nodes.Comprehension"] = None):
"""Do some setup after initialisation.

:param elt: The element that forms the output of the expression.
Expand All @@ -1004,7 +994,10 @@ def postinit(self, elt=None, generators=None):
:type generators: list(Comprehension) or None
"""
self.elt = elt
self.generators = generators
if generators is None:
DanielNoord marked this conversation as resolved.
Show resolved Hide resolved
self.generators = []
else:
self.generators = generators

def bool_value(self, context=None):
"""Determine the boolean value of this node.
Expand All @@ -1020,6 +1013,9 @@ def get_children(self):

yield from self.generators

def pytype(self) -> str:
return "builtins.list"


def _infer_decorator_callchain(node):
"""Detect decorator call chaining and see if the end result is a
Expand Down
9 changes: 4 additions & 5 deletions tests/unittest_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2645,11 +2645,11 @@ def true_value():
klass = module["Class"]
self.assertTrue(klass.bool_value())
dict_comp = next(module["dict_comp"].infer())
self.assertEqual(dict_comp, util.Uninferable)
assert isinstance(dict_comp, nodes.DictComp)
set_comp = next(module["set_comp"].infer())
self.assertEqual(set_comp, util.Uninferable)
assert isinstance(set_comp, nodes.SetComp)
list_comp = next(module["list_comp"].infer())
self.assertEqual(list_comp, util.Uninferable)
assert isinstance(list_comp, nodes.ListComp)
Comment on lines 2644 to +2649
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm misunderstanding the concept of inference here. If so, please correct me.

However, the way I see it inference means we look at the AST node and tell the caller what the Python result will be without actually executing the code. To give an example

lst = [x for x in range(3)]

The inference result for lst shouldn't be the list comprehension itself. Instead it should be the AST node for a list with

[0, 1, 2]

Same for generator expressions, dict and set comprehensions.

That is much much harder to do which is likely we it wasn't done already.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, yeah I didn't think of it like that. That would decouple the value of lst from the actual ast tree though.

For example, we would then have a ListComp node in nodes.Module.body but a List node for lst? If we create that List when calling _infer on ListComp what would then be the lineno etc.?

Wouldn't it be better to say that lst is a ListComp and let downstream users decide if they want to handle ListComp and List similarly? Like we often do with FunctionDef and AsyncFunctionDef.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would decouple the value of lst from the actual ast tree though.

Yeah, it's an inference result after all.

If we create that List when calling _infer on ListComp what would then be the lineno etc.?

None? Both lineno and col_offset attributes are typed as optional.
Or the same one as the original ListComp.

Wouldn't it be better to say that lst is a ListComp and let downstream users decide if they want to handle ListComp and List similarly? Like we often do with FunctionDef and AsyncFunctionDef.

The downstream user still has full control over it. If he once the "raw" data, he can just use the node itself. No need to call infer for that. Regarding FunctionDef, the infer method can actually return a Property if the function is decorated. Similarly, inferring a DictUnpack will return a single dict or a comparison just a bool constant

{1: 'A', **{2: 'B'}}
True == True

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should probably handle this comprehension per comprehension then. Instead of one larger PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should probably handle this comprehension per comprehension then. Instead of one larger PR.

Yeah. There are multiple parts to be aware of: evaluating the iterator, matching the loop var(s) with optional sequence unpacking, checking any if statements and doing late assignments via :=, evaluating the loop expression. Likely I'm forgetting something.

[a + c for a, b in func() if (c := b * 2)]

If you, or someone else wants to go down that route, it's probably best to start with a fairly limited set constraints and only one comprehension type. Once that is down extending it should be fairly easy in comparison.

lambda_func = next(module["lambda_func"].infer())
self.assertTrue(lambda_func)
unbound_method = next(module["unbound_method"].infer())
Expand Down Expand Up @@ -4195,8 +4195,7 @@ class Test(Parent):

def test_uninferable_type_subscript(self) -> None:
node = extract_node("[type for type in [] if type['id']]")
with self.assertRaises(InferenceError):
_ = next(node.infer())
assert isinstance(node.inferred()[0], nodes.ListComp)


class GetattrTest(unittest.TestCase):
Expand Down