Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-119180: Improvements to ForwardRef.evaluate #122210

Merged
merged 3 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init_subclass__(cls, /, *args, **kwds):
def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
"""Evaluate the forward reference and return the value.

If the forward reference is not evaluatable, raise an exception.
If the forward reference cannot be evaluated, raise an exception.
"""
if self.__forward_evaluated__:
return self.__forward_value__
Expand All @@ -89,12 +89,10 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
return value
if owner is None:
owner = self.__owner__
if type_params is None and owner is None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check felt unnecessarily strict; many ForwardRefs can be evaluated correctly without worrying about type params.

raise TypeError("Either 'type_params' or 'owner' must be provided")

if self.__forward_module__ is not None:
if globals is None and self.__forward_module__ is not None:
globals = getattr(
sys.modules.get(self.__forward_module__, None), "__dict__", globals
sys.modules.get(self.__forward_module__, None), "__dict__", None
)
if globals is None:
globals = self.__globals__
Expand All @@ -112,14 +110,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):

if locals is None:
locals = {}
if isinstance(self.__owner__, type):
locals.update(vars(self.__owner__))
if isinstance(owner, type):
locals.update(vars(owner))

if type_params is None and self.__owner__ is not None:
if type_params is None and owner is not None:
# "Inject" type parameters into the local namespace
# (unless they are shadowed by assignments *in* the local namespace),
# as a way of emulating annotation scopes when calling `eval()`
type_params = getattr(self.__owner__, "__type_params__", None)
type_params = getattr(owner, "__type_params__", None)

# type parameters require some special handling,
# as they exist in their own scope
Expand All @@ -129,7 +127,14 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
# but should in turn be overridden by names in the class scope
# (which here are called `globalns`!)
if type_params is not None:
globals, locals = dict(globals), dict(locals)
if globals is None:
globals = {}
else:
globals = dict(globals)
if locals is None:
locals = {}
else:
locals = dict(locals)
for param in type_params:
param_name = param.__name__
if not self.__forward_is_class__ or param_name not in globals:
Expand Down Expand Up @@ -413,7 +418,7 @@ def __missing__(self, key):
return fwdref


def call_annotate_function(annotate, format, owner=None):
def call_annotate_function(annotate, format, *, owner=None):
"""Call an __annotate__ function. __annotate__ functions are normally
generated by the compiler to defer the evaluation of annotations. They
can be called with any of the format arguments in the Format enum, but
Expand Down
41 changes: 41 additions & 0 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import pickle
import unittest
from annotationlib import Format, ForwardRef
from typing import Unpack

from test.test_inspect import inspect_stock_annotations
Expand Down Expand Up @@ -248,6 +249,46 @@ def test_special_attrs(self):
with self.assertRaises(TypeError):
pickle.dumps(fr, proto)

def test_evaluate_with_type_params(self):
class Gen[T]:
alias = int

with self.assertRaises(NameError):
ForwardRef("T").evaluate()
with self.assertRaises(NameError):
ForwardRef("T").evaluate(type_params=())
with self.assertRaises(NameError):
ForwardRef("T").evaluate(owner=int)

T, = Gen.__type_params__
self.assertIs(ForwardRef("T").evaluate(type_params=Gen.__type_params__), T)
self.assertIs(ForwardRef("T").evaluate(owner=Gen), T)

with self.assertRaises(NameError):
ForwardRef("alias").evaluate(type_params=Gen.__type_params__)
self.assertIs(ForwardRef("alias").evaluate(owner=Gen), int)
# If you pass custom locals, we don't look at the owner's locals
with self.assertRaises(NameError):
ForwardRef("alias").evaluate(owner=Gen, locals={})
# But if the name exists in the locals, it works
self.assertIs(
ForwardRef("alias").evaluate(owner=Gen, locals={"alias": str}), str
)

def test_fwdref_with_module(self):
self.assertIs(ForwardRef("Format", module=annotationlib).evaluate(), Format)

with self.assertRaises(NameError):
# If globals are passed explicitly, we don't look at the module dict
ForwardRef("Format", module=annotationlib).evaluate(globals={})

def test_fwdref_value_is_cached(self):
fr = ForwardRef("hello")
with self.assertRaises(NameError):
fr.evaluate()
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)


class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):
Expand Down
Loading