Skip to content

Commit

Permalink
[red-knot] Support for not-equal narrowing (#13749)
Browse files Browse the repository at this point in the history
Add type narrowing for `!=` expression as stated in
#13694.

###  Test Plan

Add tests in new md format.

---------

Co-authored-by: David Peter <[email protected]>
  • Loading branch information
Lexxxzy and sharkdp authored Oct 21, 2024
1 parent e39110e commit 9d10279
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Narrowing for nested conditionals

```py
def int_instance() -> int: ...


x = int_instance()

if x != 1:
if x != 2:
if x != 3:
reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3]
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Narrowing for `!=` conditionals

## `x != None`

```py
x = None if flag else 1

if x != None:
reveal_type(x) # revealed: Literal[1]
```

## `!=` for other singleton types

```py
x = True if flag else False

if x != False:
reveal_type(x) # revealed: Literal[True]
```

## `x != y` where `y` is of literal type

```py
x = 1 if flag else 2

if x != 1:
reveal_type(x) # revealed: Literal[2]
```

## `x != y` where `y` is a single-valued type

```py
class A: ...


class B: ...


C = A if flag else B

if C != A:
reveal_type(C) # revealed: Literal[B]
```

## `!=` for non-single-valued types

Only single-valued types should narrow the type:

```py
def int_instance() -> int: ...


x = int_instance() if flag else None
y = int_instance()

if x != y:
reveal_type(x) # revealed: int | None
```
74 changes: 74 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,55 @@ impl<'db> Type<'db> {
}
}

/// Return true if this type is non-empty and all inhabitants of this type compare equal.
pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool {
match self {
Type::None
| Type::Function(..)
| Type::Module(..)
| Type::Class(..)
| Type::IntLiteral(..)
| Type::BooleanLiteral(..)
| Type::StringLiteral(..)
| Type::BytesLiteral(..) => true,

Type::Tuple(tuple) => tuple
.elements(db)
.iter()
.all(|elem| elem.is_single_valued(db)),

Type::Instance(class_type) => match class_type.known(db) {
Some(KnownClass::NoneType) => true,
Some(
KnownClass::Bool
| KnownClass::Object
| KnownClass::Bytes
| KnownClass::Type
| KnownClass::Int
| KnownClass::Float
| KnownClass::Str
| KnownClass::List
| KnownClass::Tuple
| KnownClass::Set
| KnownClass::Dict
| KnownClass::GenericAlias
| KnownClass::ModuleType
| KnownClass::FunctionType,
) => false,
None => false,
},

Type::Any
| Type::Never
| Type::Unknown
| Type::Unbound
| Type::Todo
| Type::Union(..)
| Type::Intersection(..)
| Type::LiteralString => false,
}
}

/// Resolve a member access of a type.
///
/// For example, if `foo` is `Type::Instance(<Bar>)`,
Expand Down Expand Up @@ -1973,6 +2022,31 @@ mod tests {
assert!(from.into_type(&db).is_singleton());
}

#[test_case(Ty::None)]
#[test_case(Ty::BooleanLiteral(true))]
#[test_case(Ty::IntLiteral(1))]
#[test_case(Ty::StringLiteral("abc"))]
#[test_case(Ty::BytesLiteral("abc"))]
#[test_case(Ty::Tuple(vec![]))]
#[test_case(Ty::Tuple(vec![Ty::BooleanLiteral(true), Ty::IntLiteral(1)]))]
fn is_single_valued(from: Ty) {
let db = setup_db();

assert!(from.into_type(&db).is_single_valued(&db));
}

#[test_case(Ty::Never)]
#[test_case(Ty::Any)]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
#[test_case(Ty::Tuple(vec![Ty::None, Ty::BuiltinInstance("int")]))]
#[test_case(Ty::BuiltinInstance("str"))]
#[test_case(Ty::LiteralString)]
fn is_not_single_valued(from: Ty) {
let db = setup_db();

assert!(!from.into_type(&db).is_single_valued(&db));
}

#[test_case(Ty::Never)]
#[test_case(Ty::IntLiteral(345))]
#[test_case(Ty::BuiltinInstance("str"))]
Expand Down
8 changes: 8 additions & 0 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
ast::CmpOp::Is => {
self.constraints.insert(symbol, comp_ty);
}
ast::CmpOp::NotEq => {
if comp_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(comp_ty)
.build();
self.constraints.insert(symbol, ty);
}
}
_ => {
// TODO other comparison types
}
Expand Down

0 comments on commit 9d10279

Please sign in to comment.