Skip to content

Commit

Permalink
Merge pull request #22 from erezsh/literal_cmp
Browse files Browse the repository at this point in the history
Improve OneOf comparisons and its interaction with Union
  • Loading branch information
erezsh authored Dec 29, 2022
2 parents 532bb2e + ffdde1f commit 349db61
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 10 deletions.
50 changes: 43 additions & 7 deletions runtype/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ def validate_instance(self, obj, sampler=None):


class SumType(base_types.SumType, PythonType):
def __init__(self, types):
# Here we merge all the instances of OneOf into a single one (if necessary).
# The alternative is to turn all OneOf instances into SumTypes of single values.
# I chose this method due to intuition that it's faster for the common use-case.
one_ofs: List[OneOf] = [t for t in types if isinstance(t, OneOf)]
if len(one_ofs) > 1:
rest = [t for t in types if not isinstance(t, OneOf)]
types = rest + [OneOf([v for t in one_ofs for v in t.values])]
super().__init__(types)

def validate_instance(self, obj, sampler=None):
if not any(t.test_instance(obj) for t in self.types):
raise TypeMismatchError(obj, self)
Expand Down Expand Up @@ -167,12 +177,28 @@ def validate_instance(self, obj, sampler=None):


class OneOf(PythonType):
values: typing.Sequence

def __init__(self, values):
self.values = values

def __le__(self, other):
if isinstance(other, OneOf):
return set(self.values) <= set(other.values)
elif isinstance(other, PythonDataType):
try:
for v in self.values:
other.validate_instance(v)
except TypeMismatchError as e:
return False
return True
return NotImplemented

def __ge__(self, other):
if isinstance(other, OneOf):
return set(self.values) >= set(other.values)
elif isinstance(other, PythonDataType):
return False
return NotImplemented

def validate_instance(self, obj, sampler=None):
Expand All @@ -182,6 +208,10 @@ def validate_instance(self, obj, sampler=None):
def __repr__(self):
return 'Literal[%s]' % ', '.join(map(repr, self.values))

def cast_from(self, obj):
if obj not in self.values:
raise TypeMismatchError(obj, self)


class GenericType(base_types.GenericType, PythonType):
def __init__(self, base, item=Any):
Expand Down Expand Up @@ -264,11 +294,6 @@ def cast_from(self, obj):
Literal = OneOf


class _NoneType(PythonDataType):
def cast_from(self, obj):
if obj is not None:
raise TypeMismatchError(obj, self)

class _Number(PythonDataType):
def __call__(self, min=None, max=None):
predicates = []
Expand Down Expand Up @@ -312,10 +337,20 @@ def cast_from(self, obj):
return super().cast_from(obj)


class _NoneType(OneOf):
def __init__(self):
super().__init__([None])
def cast_from(self, obj):
assert self.values == [None]
if obj is not None:
raise TypeMismatchError(obj, self)
return None


String = _String(str)
Int = _Int(int)
Float = _Float(float)
NoneType = _NoneType(type(None))
NoneType = _NoneType()
DateTime = _DateTime(datetime)


Expand Down Expand Up @@ -421,7 +456,8 @@ def _to_canon(self, t):
return ProductType([to_canon(x) for x in t.__args__])

elif t.__origin__ is typing.Union:
return SumType([to_canon(x) for x in t.__args__])
res = [to_canon(x) for x in t.__args__]
return SumType(res)
elif t.__origin__ is abc.Callable or t is typing.Callable:
# return Callable[ProductType(to_canon(x) for x in t.__args__)]
return Callable # TODO
Expand Down
6 changes: 3 additions & 3 deletions runtype/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def is_subtype(t1, t2):
"""Test if t1 is a subtype of t2
"""

t1 = type_caster.to_canon(t1)
t2 = type_caster.to_canon(t2)
return t1 <= t2
ct1 = type_caster.to_canon(t1)
ct2 = type_caster.to_canon(t2)
return ct1 <= ct2


def isa(obj, t):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ def test_typing_extensions(self):
assert is_subtype(int, a)
assert isa(1, a)

@unittest.skipIf(not hasattr(typing, 'Literal'), "Literals not supported in this Python version")
def test_literal_comparison(self):
t1 = typing.Literal[1,2]
t2 = Union[typing.Literal[1], typing.Literal[2]]

assert is_subtype(t1, t2)
assert is_subtype(t2, t1)

assert is_subtype(Optional[typing.Literal[1]], typing.Literal[None, 1])
assert is_subtype(typing.Literal[None, 1], Optional[typing.Literal[1]])
assert is_subtype(typing.Literal[1,2,3], int)
assert is_subtype(typing.Literal["a","b"], str)
assert is_subtype(Tuple[typing.Literal[1,2,3], str], Tuple[int, str])

if sys.version_info >= (3, 9):
# the following fails for Python 3.8, because Literal[1] == Literal[True]
# and our caching swaps between them.
assert is_subtype(typing.Literal[True], bool)


class TestDispatch(TestCase):
Expand Down

0 comments on commit 349db61

Please sign in to comment.