Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve OneOf comparisons and its interaction with Union #22

Merged
merged 2 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ def validate_instance(self, obj, sampler=None):


class SumType(base_types.SumType, PythonType):
def __init__(self, types):
# Here we merge all the instances of OneOf into a single one (if necessary).
# The alternative is to turn all OneOf instances into SumTypes of single values.
# I chose this method due to intuition that it's faster for the common use-case.
one_ofs: List[OneOf] = [t for t in types if isinstance(t, OneOf)]
if len(one_ofs) > 1:
rest = [t for t in types if not isinstance(t, OneOf)]
types = rest + [OneOf([v for t in one_ofs for v in t.values])]
super().__init__(types)

def validate_instance(self, obj, sampler=None):
if not any(t.test_instance(obj) for t in self.types):
raise TypeMismatchError(obj, self)
Expand Down Expand Up @@ -167,12 +177,28 @@ def validate_instance(self, obj, sampler=None):


class OneOf(PythonType):
values: typing.Sequence

def __init__(self, values):
self.values = values

def __le__(self, other):
if isinstance(other, OneOf):
return set(self.values) <= set(other.values)
elif isinstance(other, PythonDataType):
try:
for v in self.values:
other.validate_instance(v)
except TypeMismatchError as e:
return False
return True
return NotImplemented

def __ge__(self, other):
if isinstance(other, OneOf):
return set(self.values) >= set(other.values)
elif isinstance(other, PythonDataType):
return False
return NotImplemented

def validate_instance(self, obj, sampler=None):
Expand All @@ -182,6 +208,10 @@ def validate_instance(self, obj, sampler=None):
def __repr__(self):
return 'Literal[%s]' % ', '.join(map(repr, self.values))

def cast_from(self, obj):
if obj not in self.values:
raise TypeMismatchError(obj, self)


class GenericType(base_types.GenericType, PythonType):
def __init__(self, base, item=Any):
Expand Down Expand Up @@ -264,11 +294,6 @@ def cast_from(self, obj):
Literal = OneOf


class _NoneType(PythonDataType):
def cast_from(self, obj):
if obj is not None:
raise TypeMismatchError(obj, self)

class _Number(PythonDataType):
def __call__(self, min=None, max=None):
predicates = []
Expand Down Expand Up @@ -312,10 +337,20 @@ def cast_from(self, obj):
return super().cast_from(obj)


class _NoneType(OneOf):
def __init__(self):
super().__init__([None])
def cast_from(self, obj):
assert self.values == [None]
if obj is not None:
raise TypeMismatchError(obj, self)
return None


String = _String(str)
Int = _Int(int)
Float = _Float(float)
NoneType = _NoneType(type(None))
NoneType = _NoneType()
DateTime = _DateTime(datetime)


Expand Down Expand Up @@ -421,7 +456,8 @@ def _to_canon(self, t):
return ProductType([to_canon(x) for x in t.__args__])

elif t.__origin__ is typing.Union:
return SumType([to_canon(x) for x in t.__args__])
res = [to_canon(x) for x in t.__args__]
return SumType(res)
elif t.__origin__ is abc.Callable or t is typing.Callable:
# return Callable[ProductType(to_canon(x) for x in t.__args__)]
return Callable # TODO
Expand Down
6 changes: 3 additions & 3 deletions runtype/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def is_subtype(t1, t2):
"""Test if t1 is a subtype of t2
"""

t1 = type_caster.to_canon(t1)
t2 = type_caster.to_canon(t2)
return t1 <= t2
ct1 = type_caster.to_canon(t1)
ct2 = type_caster.to_canon(t2)
return ct1 <= ct2


def isa(obj, t):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ def test_typing_extensions(self):
assert is_subtype(int, a)
assert isa(1, a)

@unittest.skipIf(not hasattr(typing, 'Literal'), "Literals not supported in this Python version")
def test_literal_comparison(self):
t1 = typing.Literal[1,2]
t2 = Union[typing.Literal[1], typing.Literal[2]]

assert is_subtype(t1, t2)
assert is_subtype(t2, t1)

assert is_subtype(Optional[typing.Literal[1]], typing.Literal[None, 1])
assert is_subtype(typing.Literal[None, 1], Optional[typing.Literal[1]])
assert is_subtype(typing.Literal[1,2,3], int)
assert is_subtype(typing.Literal["a","b"], str)
assert is_subtype(Tuple[typing.Literal[1,2,3], str], Tuple[int, str])

if sys.version_info >= (3, 9):
# the following fails for Python 3.8, because Literal[1] == Literal[True]
# and our caching swaps between them.
assert is_subtype(typing.Literal[True], bool)


class TestDispatch(TestCase):
Expand Down