Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for forward-references in dataclasses #10

Merged
merged 3 commits into the base branch from the source branch
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions runtype/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@
import dataclasses
from typing import Union
from abc import ABC, abstractmethod
import inspect

from .utils import ForwardRef
from .common import CHECK_TYPES
from .validation import TypeMismatchError, ensure_isa as default_ensure_isa
from .pytypes import cast_to_type, SumType, NoneType
from .pytypes import TypeCaster, type_caster, SumType, NoneType

Required = object()
MAX_SAMPLE_SIZE = 16

class NopTypeCaster:
    """A type caster that performs no canonization: types pass through as-is.

    Serves as the default return value of Configuration.make_type_caster(),
    exposing the same ``cache`` / ``to_canon`` interface as pytypes.TypeCaster.
    """
    # Shared memoization cache; callers key it themselves
    # (e.g. _post_init stores entries under id(field)).
    cache = {}

    def to_canon(self, t):
        """Return `t` unchanged -- it is treated as already canonical."""
        return t

class Configuration(ABC):
"""Generic configuration template for dataclass. Mainly for type-checking.

Expand Down Expand Up @@ -48,15 +55,15 @@ class Form:

"""

def canonize_type(self, t):
"""Given a type, return its canonical form.
def on_default(self, default):
"""Called whenever a dataclass member is assigned a default value.
"""
return t
return default

def on_default(self, t, default):
"""Called whenever a dataclass member is assigned a default value.
def make_type_caster(self, frame):
"""Return a type caster, as defined in pytypes.TypeCaster
"""
return t, default
return NopTypeCaster()

@abstractmethod
def ensure_isa(self, a, b, sampler=None):
Expand All @@ -79,40 +86,48 @@ class PythonConfiguration(Configuration):

This is the default class given to the ``dataclass()`` function.
"""
canonize_type = staticmethod(cast_to_type)
make_type_caster = TypeCaster
ensure_isa = staticmethod(default_ensure_isa)

def cast(self, obj, to_type):
    """Cast `obj` into `to_type` by delegating to the type's own cast_from()."""
    return to_type.cast_from(obj)

def on_default(self, type_, default):
if default is None:
type_ = SumType([type_, NoneType])
elif isinstance(default, (list, dict, set)):
def on_default(self, default):
if isinstance(default, (list, dict, set)):
def f(_=default):
return copy(_)
default = dataclasses.field(default_factory=f)
return type_, default

return dataclasses.field(default_factory=f)
return default


def _post_init(self, config, should_cast, sampler):
def _post_init(self, config, should_cast, sampler, type_caster):
for name, field in getattr(self, '__dataclass_fields__', {}).items():
value = getattr(self, name)

if value is Required:
raise TypeError(f"Field {name} requires a value")

try:
type_ = type_caster.cache[id(field)]
except KeyError:
type_ = field.type
if isinstance(type_, str):
type_ = ForwardRef(type_)
type_ = type_caster.to_canon(type_)
if field.default is None:
type_ = SumType([type_, NoneType])
type_caster.cache[id(field)] = type_

try:
if should_cast: # Basic cast
assert not sampler
value = config.cast(value, field.type)
value = config.cast(value, type_)
object.__setattr__(self, name, value)
else:
config.ensure_isa(value, field.type, sampler)
config.ensure_isa(value, type_, sampler)
except TypeMismatchError as e:
item_value, item_type = e.args
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type {field.type}."
msg = f"[{type(self).__name__}] Attribute '{name}' expected value of type '{type_}'."
msg += f" Instead got {value!r}"
if item_value is not value:
msg += f'\n\n Failed on item: {item_value!r}, expected type {item_type}'
Expand Down Expand Up @@ -197,9 +212,9 @@ def _sample(seq, max_sample_size=MAX_SAMPLE_SIZE):
return seq
return random.sample(seq, max_sample_size)

def _process_class(cls, config, check_types, **kw):
def _process_class(cls, config, check_types, context_frame, **kw):
for name, type_ in getattr(cls, '__annotations__', {}).items():
type_ = config.canonize_type(type_)
# type_ = config.type_to_canon(type_) if not isinstance(type_, str) else type_

# If default not specified, assign Required, for a later check
# We don't assign MISSING; we want to bypass dataclass which is too strict for this
Expand All @@ -211,7 +226,7 @@ def _process_class(cls, config, check_types, **kw):
if default.default is dataclasses.MISSING and default.default_factory is dataclasses.MISSING:
default.default = Required

type_, new_default = config.on_default(type_, default)
new_default = config.on_default(default)
if new_default is not default:
setattr(cls, name, new_default)

Expand All @@ -222,9 +237,12 @@ def _process_class(cls, config, check_types, **kw):

orig_post_init = getattr(cls, '__post_init__', None)
sampler = _sample if check_types=='sample' else None
# eval_type_string = EvalInContext(context_frame)
type_caster = config.make_type_caster(context_frame)

def __post_init__(self):
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler)
# Only now context_frame has complete information
_post_init(self, config=config, should_cast=check_types == 'cast', sampler=sampler, type_caster=type_caster)
if orig_post_init is not None:
orig_post_init(self)

Expand Down Expand Up @@ -340,8 +358,9 @@ def dataclass(cls=None, *, check_types: Union[bool, str] = CHECK_TYPES,
"""
assert isinstance(config, Configuration)

context_frame = inspect.currentframe().f_back # Get parent frame, to resolve forward-references
def wrap(cls):
return _process_class(cls, config, check_types,
return _process_class(cls, config, check_types, context_frame,
init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen, slots=slots)

# See if we're being called as @dataclass or @dataclass().
Expand Down
177 changes: 96 additions & 81 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import typing
from datetime import datetime

from .utils import ForwardRef
from .base_types import DataType, Validator, TypeMismatchError
from . import base_types
from . import datetime_parse
Expand All @@ -27,7 +28,7 @@ class PythonType(base_types.Type, Validator):

class Constraint(base_types.Constraint):
def __init__(self, for_type, predicates):
super().__init__(cast_to_type(for_type), predicates)
super().__init__(type_caster.to_canon(for_type), predicates)

def cast_from(self, obj):
obj = self.type.cast_from(obj)
Expand Down Expand Up @@ -194,7 +195,7 @@ def __init__(self, base, item=Any*Any):
super().__init__(base)
if isinstance(item, tuple):
assert len(item) == 2
item = ProductType([cast_to_type(x) for x in item])
item = ProductType([type_caster.to_canon(x) for x in item])
self.item = item

def validate_instance(self, obj, sampler=None):
Expand Down Expand Up @@ -329,85 +330,99 @@ def cast_from(self, obj):
origin_frozenset = typing.FrozenSet


def _cast_to_type(t):
if isinstance(t, Validator):
return t
class TypeCaster:
def __init__(self, frame=None):
    # Memoization cache: maps a typing/runtime type to its canonical form.
    self.cache = {}
    # Stack frame whose f_globals/f_locals are used to resolve forward
    # references; may be None when no forward references are expected.
    self.frame = frame

if isinstance(t, tuple):
return SumType([cast_to_type(x) for x in t])
def _to_canon(self, t):
    """Translate a typing/runtime type `t` into its canonical runtype
    validator form (List, Dict, SumType, PythonDataType, ...).

    Internal, uncached worker -- use to_canon(), which memoizes results.
    """
    to_canon = self.to_canon  # local alias for the cached, recursive entry point

    # Already a runtype type/validator: nothing to do.
    if isinstance(t, (base_types.Type, Validator)):
        return t

    # Forward reference (e.g. a string annotation): resolve it against the
    # captured frame's globals/locals before canonizing the result.
    if isinstance(t, ForwardRef):
        t = t._evaluate(self.frame.f_globals, self.frame.f_locals, set())

    # A plain tuple of types is shorthand for a union of those types.
    if isinstance(t, tuple):
        return SumType([to_canon(x) for x in t])

    try:
        t.__origin__
    except AttributeError:
        pass
    else:
        # Bare generics (typing.List, typing.Dict, ...) carry no type args.
        if getattr(t, '__args__', None) is None:
            if t is typing.List:
                return List
            elif t is typing.Dict:
                return Dict
            elif t is typing.Set:
                return Set
            elif t is typing.FrozenSet:
                return FrozenSet
            elif t is typing.Tuple:
                return Tuple
            elif t is typing.Mapping: # 3.6
                return Mapping
            elif t is typing.Sequence:
                return Sequence

        # Parameterized generics: recurse into the type arguments.
        if t.__origin__ is origin_list:
            x ,= t.__args__
            return List[to_canon(x)]
        elif t.__origin__ is origin_set:
            x ,= t.__args__
            return Set[to_canon(x)]
        elif t.__origin__ is origin_frozenset:
            x ,= t.__args__
            return FrozenSet[to_canon(x)]
        elif t.__origin__ is origin_dict:
            k, v = t.__args__
            return Dict[to_canon(k), to_canon(v)]
        elif t.__origin__ is origin_tuple:
            # tuple[t, ...] (homogeneous, any length) vs tuple[a, b, c] (product).
            if Ellipsis in t.__args__:
                if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
                    raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
                return TupleEllipsis[to_canon(t.__args__[0])]

            return ProductType([to_canon(x) for x in t.__args__])

        elif t.__origin__ is typing.Union:
            return SumType([to_canon(x) for x in t.__args__])
        elif t.__origin__ is abc.Callable or t is typing.Callable:
            # return Callable[ProductType(to_canon(x) for x in t.__args__)]
            return Callable  # TODO
        elif py38 and t.__origin__ is typing.Literal:
            return OneOf(t.__args__)
        elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
            k, v = t.__args__
            return Mapping[to_canon(k), to_canon(v)]
        elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
            x ,= t.__args__
            return Sequence[to_canon(x)]

        elif t.__origin__ is type or t.__origin__ is typing.Type:
            # TODO test issubclass on t.__args__
            return PythonDataType(type)

        raise NotImplementedError("No support for type:", t)

    # Unconstrained type variables are accepted as Any.
    if isinstance(t, typing.TypeVar):
        return Any  # XXX is this correct?

    # Fallback: a concrete Python class, validated by isinstance.
    return PythonDataType(t)

def to_canon(self, t):
    """Return the canonical runtype form of `t`, memoizing the result.

    Lookup order: this caster's own cache, then the module-level
    _type_cast_mapping of pre-canonized types, then _to_canon().
    """
    try:
        return self.cache[t]
    except KeyError:
        try:
            res = _type_cast_mapping[t]
        except KeyError:
            res = self._to_canon(t)
        self.cache[t] = res  # memoize
        return res

try:
t.__origin__
except AttributeError:
pass
else:
if getattr(t, '__args__', None) is None:
if t is typing.List:
return List
elif t is typing.Dict:
return Dict
elif t is typing.Set:
return Set
elif t is typing.FrozenSet:
return FrozenSet
elif t is typing.Tuple:
return Tuple
elif t is typing.Mapping: # 3.6
return Mapping
elif t is typing.Sequence:
return Sequence

if t.__origin__ is origin_list:
x ,= t.__args__
return List[cast_to_type(x)]
elif t.__origin__ is origin_set:
x ,= t.__args__
return Set[cast_to_type(x)]
elif t.__origin__ is origin_frozenset:
x ,= t.__args__
return FrozenSet[cast_to_type(x)]
elif t.__origin__ is origin_dict:
k, v = t.__args__
return Dict[cast_to_type(k), cast_to_type(v)]
elif t.__origin__ is origin_tuple:
if Ellipsis in t.__args__:
if len(t.__args__) != 2 or t.__args__[0] == Ellipsis:
raise ValueError("Tuple with '...'' expected to be of the exact form: tuple[t, ...].")
return TupleEllipsis[cast_to_type(t.__args__[0])]

return ProductType([cast_to_type(x) for x in t.__args__])

elif t.__origin__ is typing.Union:
return SumType([cast_to_type(x) for x in t.__args__])
elif t.__origin__ is abc.Callable or t is typing.Callable:
# return Callable[ProductType(cast_to_type(x) for x in t.__args__)]
return Callable # TODO
elif py38 and t.__origin__ is typing.Literal:
return OneOf(t.__args__)
elif t.__origin__ is abc.Mapping or t.__origin__ is typing.Mapping:
k, v = t.__args__
return Mapping[cast_to_type(k), cast_to_type(v)]
elif t.__origin__ is abc.Sequence or t.__origin__ is typing.Sequence:
x ,= t.__args__
return Sequence[_cast_to_type(x)]

elif t.__origin__ is type or t.__origin__ is typing.Type:
# TODO test issubclass on t.__args__
return PythonDataType(type)

raise NotImplementedError("No support for type:", t)

if isinstance(t, typing.TypeVar):
return Any # XXX is this correct?

return PythonDataType(t)


def cast_to_type(t):
try:
return _type_cast_mapping[t]
except KeyError:
res = _cast_to_type(t)
_type_cast_mapping[t] = res # memoize
return res

type_caster = TypeCaster()
18 changes: 18 additions & 0 deletions runtype/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
import inspect
import sys

# Import ForwardRef from wherever the running Python version keeps it, and
# remember the version-specific evaluation method so it can be normalized below.
if sys.version_info < (3, 7):
    # python 3.6: ForwardRef was private (_ForwardRef) and evaluation
    # went through _eval_type rather than _evaluate.
    from typing import _ForwardRef as ForwardRef
    _orig_eval = ForwardRef._eval_type
elif sys.version_info < (3, 9):
    from typing import ForwardRef
    _orig_eval = ForwardRef._evaluate
else:
    from typing import ForwardRef

if sys.version_info < (3, 9):
    # Shim: callers invoke ForwardRef._evaluate with a third argument (the
    # recursive-guard set that 3.9+ accepts); pad older versions' signature
    # with an ignored parameter so one call form works everywhere.
    def _evaluate(self, g, l, _):
        return _orig_eval(self, g, l)
    ForwardRef._evaluate = _evaluate



def get_func_signatures(typesystem, f):
sig = inspect.signature(f)
Expand Down
Loading