Skip to content

Commit

Permalink
[red-knot] Narrowing for type(x) is C checks (#14432)
Browse files Browse the repository at this point in the history
## Summary

Add type narrowing for `type(x) is C` conditions (and `else` clauses of
`type(x) is not C` conditionals):

```py
if type(x) is A:
    reveal_type(x)  # revealed: A
else:
    reveal_type(x)  # revealed: A | B
```

closes: #14431, part of: #13694

## Test Plan

New Markdown-based tests.
  • Loading branch information
sharkdp authored Nov 18, 2024
1 parent 3642381 commit d8538d8
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 34 deletions.
152 changes: 152 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/narrow/type.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Narrowing for checks involving `type(x)`

## `type(x) is C`

```py
class A: ...
class B: ...

def get_a_or_b() -> A | B:
return A()

x = get_a_or_b()

if type(x) is A:
reveal_type(x) # revealed: A
else:
# It would be wrong to infer `B` here. The type
# of `x` could be a subclass of `A`, so we need
# to infer the full union type:
reveal_type(x) # revealed: A | B
```

## `type(x) is not C`

```py
class A: ...
class B: ...

def get_a_or_b() -> A | B:
return A()

x = get_a_or_b()

if type(x) is not A:
# Same reasoning as above: no narrowing should occur here.
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: A
```

## `type(x) == C`, `type(x) != C`

No narrowing can occur for equality comparisons, since there might be a custom `__eq__`
implementation on the metaclass.

TODO: Narrowing might be possible in some cases where the classes themselves are `@final` or their
metaclass is `@final`.

```py
class IsEqualToEverything(type):
def __eq__(cls, other):
return True

class A(metaclass=IsEqualToEverything): ...
class B(metaclass=IsEqualToEverything): ...

def get_a_or_b() -> A | B:
return B()

x = get_a_or_b()

if type(x) == A:
reveal_type(x) # revealed: A | B

if type(x) != A:
reveal_type(x) # revealed: A | B
```

## No narrowing for custom `type` callable

```py
class A: ...
class B: ...

def type(x):
return int

def get_a_or_b() -> A | B:
return A()

x = get_a_or_b()

if type(x) is A:
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: A | B
```

## No narrowing for multiple arguments

No narrowing should occur if `type` is used to dynamically create a class:

```py
def get_str_or_int() -> str | int:
return "test"

x = get_str_or_int()

if type(x, (), {}) is str:
reveal_type(x) # revealed: str | int
else:
reveal_type(x) # revealed: str | int
```

## No narrowing for keyword arguments

`type` can't be used with a keyword argument:

```py
def get_str_or_int() -> str | int:
return "test"

x = get_str_or_int()

# TODO: we could issue a diagnostic here
if type(object=x) is str:
reveal_type(x) # revealed: str | int
```

## Narrowing if `type` is aliased

```py
class A: ...
class B: ...

alias_for_type = type

def get_a_or_b() -> A | B:
return A()

x = get_a_or_b()

if alias_for_type(x) is A:
reveal_type(x) # revealed: A
```

## Limitations

```py
class Base: ...
class Derived(Base): ...

def get_base() -> Base:
return Base()

x = get_base()

if type(x) is Base:
# Ideally, this could be narrower, but there is now way to
# express a constraint like `Base & ~ProperSubtypeOf[Base]`.
reveal_type(x) # revealed: Base
```
120 changes: 86 additions & 34 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,26 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
}

let ast::ExprCompare {
range: _,
left,
ops,
comparators,
} = expr_compare;
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
// If none of the comparators are name expressions,
// we have no symbol to narrow down the type of.

// Performance optimization: early return if there are no potential narrowing targets.
if !is_narrowing_target_candidate(left)
&& comparators
.iter()
.all(|c| !is_narrowing_target_candidate(c))
{
return None;
}

if !is_positive && comparators.len() > 1 {
// We can't negate a constraint made by a multi-comparator expression, since we can't
// know which comparison part is the one being negated.
Expand All @@ -283,42 +292,85 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
let mut constraints = NarrowingConstraints::default();
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
if let ast::Expr::Name(ast::ExprName {
range: _,
id,
ctx: _,
}) = left
{
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));

match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
} else {
// Non-singletons cannot be safely narrowed using `is not`
let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope));

match left {
ast::Expr::Name(ast::ExprName {
range: _,
id,
ctx: _,
}) => {
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");

match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
} else {
// Non-singletons cannot be safely narrowed using `is not`
}
}
}
ast::CmpOp::Is => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
ast::CmpOp::Is => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
}
}
_ => {
// TODO other comparison types
}
}
_ => {
// TODO other comparison types
}
ast::Expr::Call(ast::ExprCall {
range: _,
func: callable,
arguments:
ast::Arguments {
args,
keywords,
range: _,
},
}) if rhs_ty.is_class_literal() && keywords.is_empty() => {
let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else {
continue;
};

let is_valid_constraint = if is_positive {
op == &ast::CmpOp::Is
} else {
op == &ast::CmpOp::IsNot
};

if !is_valid_constraint {
continue;
}

let callable_ty =
inference.expression_ty(callable.scoped_expression_id(self.db, scope));

if callable_ty
.into_class_literal()
.is_some_and(|c| c.class.is_known(self.db, KnownClass::Type))
{
let symbol = self
.symbols()
.symbol_id_by_name(id)
.expect("Should always have a symbol for every Name node");
constraints.insert(symbol, rhs_ty.to_instance(self.db));
}
}
_ => {}
}
}
Some(constraints)
Expand Down

0 comments on commit d8538d8

Please sign in to comment.