diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md new file mode 100644 index 0000000000000..f0857fe267f17 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_nested.md @@ -0,0 +1,13 @@ +# Narrowing for nested conditionals + +```py +def int_instance() -> int: ... + + +x = int_instance() + +if x != 1: + if x != 2: + if x != 3: + reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md new file mode 100644 index 0000000000000..3811d381f91e9 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_not_eq.md @@ -0,0 +1,58 @@ +# Narrowing for `!=` conditionals + +## `x != None` + +```py +x = None if flag else 1 + +if x != None: + reveal_type(x) # revealed: Literal[1] +``` + +## `!=` for other singleton types + +```py +x = True if flag else False + +if x != False: + reveal_type(x) # revealed: Literal[True] +``` + +## `x != y` where `y` is of literal type + +```py +x = 1 if flag else 2 + +if x != 1: + reveal_type(x) # revealed: Literal[2] +``` + +## `x != y` where `y` is a single-valued type + +```py +class A: ... + + +class B: ... + + +C = A if flag else B + +if C != A: + reveal_type(C) # revealed: Literal[B] +``` + +## `!=` for non-single-valued types + +Only single-valued types should narrow the type: + +```py +def int_instance() -> int: ... + + +x = int_instance() if flag else None +y = int_instance() + +if x != y: + reveal_type(x) # revealed: int | None +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index dd66bf173bc23..c93fc150a35b7 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -667,6 +667,55 @@ impl<'db> Type<'db> { } } + /// Return true if this type is non-empty and all inhabitants of this type compare equal. + pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { + match self { + Type::None + | Type::Function(..) + | Type::Module(..) + | Type::Class(..) + | Type::IntLiteral(..) + | Type::BooleanLiteral(..) + | Type::StringLiteral(..) + | Type::BytesLiteral(..) => true, + + Type::Tuple(tuple) => tuple + .elements(db) + .iter() + .all(|elem| elem.is_single_valued(db)), + + Type::Instance(class_type) => match class_type.known(db) { + Some(KnownClass::NoneType) => true, + Some( + KnownClass::Bool + | KnownClass::Object + | KnownClass::Bytes + | KnownClass::Type + | KnownClass::Int + | KnownClass::Float + | KnownClass::Str + | KnownClass::List + | KnownClass::Tuple + | KnownClass::Set + | KnownClass::Dict + | KnownClass::GenericAlias + | KnownClass::ModuleType + | KnownClass::FunctionType, + ) => false, + None => false, + }, + + Type::Any + | Type::Never + | Type::Unknown + | Type::Unbound + | Type::Todo + | Type::Union(..) + | Type::Intersection(..) + | Type::LiteralString => false, + } + } + /// Resolve a member access of a type. /// /// For example, if `foo` is `Type::Instance()`, @@ -1965,6 +2014,31 @@ mod tests { assert!(from.into_type(&db).is_singleton()); } + #[test_case(Ty::None)] + #[test_case(Ty::BooleanLiteral(true))] + #[test_case(Ty::IntLiteral(1))] + #[test_case(Ty::StringLiteral("abc"))] + #[test_case(Ty::BytesLiteral("abc"))] + #[test_case(Ty::Tuple(vec![]))] + #[test_case(Ty::Tuple(vec![Ty::BooleanLiteral(true), Ty::IntLiteral(1)]))] + fn is_single_valued(from: Ty) { + let db = setup_db(); + + assert!(from.into_type(&db).is_single_valued(&db)); + } + + #[test_case(Ty::Never)] + #[test_case(Ty::Any)] + #[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))] + #[test_case(Ty::Tuple(vec![Ty::None, Ty::BuiltinInstance("int")]))] + #[test_case(Ty::BuiltinInstance("str"))] + #[test_case(Ty::LiteralString)] + fn is_not_single_valued(from: Ty) { + let db = setup_db(); + + assert!(!from.into_type(&db).is_single_valued(&db)); + } + #[test_case(Ty::Never)] #[test_case(Ty::IntLiteral(345))] #[test_case(Ty::BuiltinInstance("str"))] diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 7d9d58995320b..23086de56b726 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -169,6 +169,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { ast::CmpOp::Is => { self.constraints.insert(symbol, comp_ty); } + ast::CmpOp::NotEq => { + if comp_ty.is_single_valued(self.db) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(comp_ty) + .build(); + self.constraints.insert(symbol, ty); + } + } _ => { // TODO other comparison types }