diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 60fccc7e357c..1e6921814b20 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -20,6 +20,7 @@ COVARIANT, Decorator, FuncBase, + FuncDef, OverloadedFuncDef, TypeInfo, Var, @@ -405,8 +406,25 @@ def visit_none_type(self, left: NoneType) -> bool: members = self.right.type.protocol_members # None is compatible with Hashable (and other similar protocols). This is # slightly sloppy since we don't check the signature of "__hash__". - # None is also compatible with `SupportsStr` protocol. - return not members or all(member in ("__hash__", "__str__") for member in members) + # None is also compatible with `SupportsStr` and `SupportsBool` protocols. + if self.right.type.defn.info and "__bool__" in self.right.type.defn.info.names: + bool_method = self.right.type.defn.info.names["__bool__"] + assert bool_method.node is not None + assert isinstance(bool_method.node, FuncDef) + assert isinstance(bool_method.node.type, CallableType) + assert bool_method.node.type.items is not None + bool_method_types_info = bool_method.node.type.items[0] + # None should probably be incompatible with Literal[True] + if ( + isinstance( + get_proper_type(bool_method.node.type.items[0].ret_type), LiteralType + ) + and bool_method_types_info.ret_type.can_be_true + ): + return False + return not members or all( + member in ("__hash__", "__str__", "__bool__") for member in members + ) return False else: return True diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index dba01be50fee..4a44bdfeff98 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -2857,7 +2857,7 @@ c1: SupportsClassGetItem = C() [case testNoneVsProtocol] # mypy: strict-optional -from typing_extensions import Protocol +from typing_extensions import Literal, Protocol class MyHashable(Protocol): def __hash__(self) -> int: ... @@ -2890,6 +2890,22 @@ class SupportsStr(Protocol): def ss(s: SupportsStr) -> None: pass ss(None) +class SupportsBool(Protocol): + def __bool__(self) -> bool: ... + +class SupportsBoolLiteralTrue(Protocol): + def __bool__(self) -> Literal[True]: ... + +class SupportsBoolLiteralFalse(Protocol): + def __bool__(self) -> Literal[False]: ... + +def sb(s: SupportsBool) -> None: pass +sb(None) +def sblt(s: SupportsBoolLiteralTrue) -> None: pass +sblt(None) # E: Argument 1 to "sblt" has incompatible type "None"; expected "SupportsBoolLiteralTrue" +def sblf(s: SupportsBoolLiteralFalse) -> None: pass +sblf(None) + class HashableStr(Protocol): def __str__(self) -> str: ... def __hash__(self) -> int: ...