From 2b3f5936f145e1b63a4deb17175beb7a2bd68997 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 29 Dec 2022 11:10:27 -0300 Subject: [PATCH 1/2] Improve OneOf comparisons and its interaction with Union --- runtype/pytypes.py | 50 +++++++++++++++++++++++++++++++++++++------ runtype/validation.py | 6 +++--- tests/test_basic.py | 13 +++++++++++ 3 files changed, 59 insertions(+), 10 deletions(-) diff --git a/runtype/pytypes.py b/runtype/pytypes.py index c45c0fa..2a39404 100644 --- a/runtype/pytypes.py +++ b/runtype/pytypes.py @@ -85,6 +85,16 @@ def validate_instance(self, obj, sampler=None): class SumType(base_types.SumType, PythonType): + def __init__(self, types): + # Here we merge all the instances of OneOf into a single one (if necessary). + # The alternative is to turn all OneOf instances into SumTypes of single values. + # I chose this method due to intuition that it's faster for the common use-case. + one_ofs: List[OneOf] = [t for t in types if isinstance(t, OneOf)] + if len(one_ofs) > 1: + rest = [t for t in types if not isinstance(t, OneOf)] + types = rest + [OneOf([v for t in one_ofs for v in t.values])] + super().__init__(types) + def validate_instance(self, obj, sampler=None): if not any(t.test_instance(obj) for t in self.types): raise TypeMismatchError(obj, self) @@ -167,12 +177,28 @@ def validate_instance(self, obj, sampler=None): class OneOf(PythonType): + values: typing.Sequence + def __init__(self, values): self.values = values def __le__(self, other): if isinstance(other, OneOf): return set(self.values) <= set(other.values) + elif isinstance(other, PythonDataType): + try: + for v in self.values: + other.validate_instance(v) + except TypeMismatchError as e: + return False + return True + return NotImplemented + + def __ge__(self, other): + if isinstance(other, OneOf): + return set(self.values) >= set(other.values) + elif isinstance(other, PythonDataType): + return False return NotImplemented def validate_instance(self, obj, sampler=None): @@ -182,6 +208,10 @@ def validate_instance(self, obj, sampler=None): def __repr__(self): return 'Literal[%s]' % ', '.join(map(repr, self.values)) + def cast_from(self, obj): + if obj not in self.values: + raise TypeMismatchError(obj, self) + class GenericType(base_types.GenericType, PythonType): def __init__(self, base, item=Any): @@ -264,11 +294,6 @@ def cast_from(self, obj): Literal = OneOf -class _NoneType(PythonDataType): - def cast_from(self, obj): - if obj is not None: - raise TypeMismatchError(obj, self) - class _Number(PythonDataType): def __call__(self, min=None, max=None): predicates = [] @@ -312,10 +337,20 @@ def cast_from(self, obj): return super().cast_from(obj) +class _NoneType(OneOf): + def __init__(self): + super().__init__([None]) + def cast_from(self, obj): + assert self.values == [None] + if obj is not None: + raise TypeMismatchError(obj, self) + return None + + String = _String(str) Int = _Int(int) Float = _Float(float) -NoneType = _NoneType(type(None)) +NoneType = _NoneType() DateTime = _DateTime(datetime) @@ -421,7 +456,8 @@ def _to_canon(self, t): return ProductType([to_canon(x) for x in t.__args__]) elif t.__origin__ is typing.Union: - return SumType([to_canon(x) for x in t.__args__]) + res = [to_canon(x) for x in t.__args__] + return SumType(res) elif t.__origin__ is abc.Callable or t is typing.Callable: # return Callable[ProductType(to_canon(x) for x in t.__args__)] return Callable # TODO diff --git a/runtype/validation.py b/runtype/validation.py index 9828b61..5c95d28 100644 --- a/runtype/validation.py +++ b/runtype/validation.py @@ -20,9 +20,9 @@ def is_subtype(t1, t2): """Test if t1 is a subtype of t2 """ - t1 = type_caster.to_canon(t1) - t2 = type_caster.to_canon(t2) - return t1 <= t2 + ct1 = type_caster.to_canon(t1) + ct2 = type_caster.to_canon(t2) + return ct1 <= ct2 def isa(obj, t): diff --git a/tests/test_basic.py b/tests/test_basic.py index d8f31b3..adb7874 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -148,6 +148,19 @@ def test_typing_extensions(self): assert is_subtype(int, a) assert isa(1, a) + def test_literal_comparison(self): + t1 = typing.Literal[1,2] + t2 = Union[typing.Literal[1], typing.Literal[2]] + + assert is_subtype(t1, t2) + assert is_subtype(t2, t1) + + assert is_subtype(Optional[typing.Literal[1]], typing.Literal[None, 1]) + assert is_subtype(typing.Literal[None, 1], Optional[typing.Literal[1]]) + assert is_subtype(typing.Literal[True], bool) + assert is_subtype(typing.Literal[1,2,3], int) + assert is_subtype(typing.Literal["a","b"], str) + assert is_subtype(Tuple[typing.Literal[1,2,3], str], Tuple[int, str]) class TestDispatch(TestCase): From ffdde1f39bbfdcd6c6fa1bd41d6a73ac1e200273 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Thu, 29 Dec 2022 11:26:09 -0300 Subject: [PATCH 2/2] Tests: Skip test if no Literal (<3.8). For 3.8, avoid testing Literal[True] bug --- tests/test_basic.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index adb7874..cc0e54d 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -148,6 +148,7 @@ def test_typing_extensions(self): assert is_subtype(int, a) assert isa(1, a) + @unittest.skipIf(not hasattr(typing, 'Literal'), "Literals not supported in this Python version") def test_literal_comparison(self): t1 = typing.Literal[1,2] t2 = Union[typing.Literal[1], typing.Literal[2]] @@ -157,11 +158,15 @@ def test_literal_comparison(self): assert is_subtype(Optional[typing.Literal[1]], typing.Literal[None, 1]) assert is_subtype(typing.Literal[None, 1], Optional[typing.Literal[1]]) - assert is_subtype(typing.Literal[True], bool) assert is_subtype(typing.Literal[1,2,3], int) assert is_subtype(typing.Literal["a","b"], str) assert is_subtype(Tuple[typing.Literal[1,2,3], str], Tuple[int, str]) + if sys.version_info >= (3, 9): + # the following fails for Python 3.8, because Literal[1] == Literal[True] + # and our caching swaps between them. + assert is_subtype(typing.Literal[True], bool) + class TestDispatch(TestCase): def setUp(self):