From d8538d8c9816a4d09af6089e54c5232bae5998fc Mon Sep 17 00:00:00 2001 From: David Peter Date: Mon, 18 Nov 2024 16:21:46 +0100 Subject: [PATCH] [red-knot] Narrowing for `type(x) is C` checks (#14432) ## Summary Add type narrowing for `type(x) is C` conditions (and `else` clauses of `type(x) is not C` conditionals): ```py if type(x) is A: reveal_type(x) # revealed: A else: reveal_type(x) # revealed: A | B ``` closes: #14431, part of: #13694 ## Test Plan New Markdown-based tests. --- .../resources/mdtest/narrow/type.md | 152 ++++++++++++++++++ .../src/types/narrow.rs | 120 ++++++++++---- 2 files changed, 238 insertions(+), 34 deletions(-) create mode 100644 crates/red_knot_python_semantic/resources/mdtest/narrow/type.md diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md new file mode 100644 index 0000000000000..01d2646198f80 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md @@ -0,0 +1,152 @@ +# Narrowing for checks involving `type(x)` + +## `type(x) is C` + +```py +class A: ... +class B: ... + +def get_a_or_b() -> A | B: + return A() + +x = get_a_or_b() + +if type(x) is A: + reveal_type(x) # revealed: A +else: + # It would be wrong to infer `B` here. The type + # of `x` could be a subclass of `A`, so we need + # to infer the full union type: + reveal_type(x) # revealed: A | B +``` + +## `type(x) is not C` + +```py +class A: ... +class B: ... + +def get_a_or_b() -> A | B: + return A() + +x = get_a_or_b() + +if type(x) is not A: + # Same reasoning as above: no narrowing should occur here. + reveal_type(x) # revealed: A | B +else: + reveal_type(x) # revealed: A +``` + +## `type(x) == C`, `type(x) != C` + +No narrowing can occur for equality comparisons, since there might be a custom `__eq__` +implementation on the metaclass. 
+ +TODO: Narrowing might be possible in some cases where the classes themselves are `@final` or their +metaclass is `@final`. + +```py +class IsEqualToEverything(type): + def __eq__(cls, other): + return True + +class A(metaclass=IsEqualToEverything): ... +class B(metaclass=IsEqualToEverything): ... + +def get_a_or_b() -> A | B: + return B() + +x = get_a_or_b() + +if type(x) == A: + reveal_type(x) # revealed: A | B + +if type(x) != A: + reveal_type(x) # revealed: A | B +``` + +## No narrowing for custom `type` callable + +```py +class A: ... +class B: ... + +def type(x): + return int + +def get_a_or_b() -> A | B: + return A() + +x = get_a_or_b() + +if type(x) is A: + reveal_type(x) # revealed: A | B +else: + reveal_type(x) # revealed: A | B +``` + +## No narrowing for multiple arguments + +No narrowing should occur if `type` is used to dynamically create a class: + +```py +def get_str_or_int() -> str | int: + return "test" + +x = get_str_or_int() + +if type(x, (), {}) is str: + reveal_type(x) # revealed: str | int +else: + reveal_type(x) # revealed: str | int +``` + +## No narrowing for keyword arguments + +`type` can't be used with a keyword argument: + +```py +def get_str_or_int() -> str | int: + return "test" + +x = get_str_or_int() + +# TODO: we could issue a diagnostic here +if type(object=x) is str: + reveal_type(x) # revealed: str | int +``` + +## Narrowing if `type` is aliased + +```py +class A: ... +class B: ... + +alias_for_type = type + +def get_a_or_b() -> A | B: + return A() + +x = get_a_or_b() + +if alias_for_type(x) is A: + reveal_type(x) # revealed: A +``` + +## Limitations + +```py +class Base: ... +class Derived(Base): ... + +def get_base() -> Base: + return Base() + +x = get_base() + +if type(x) is Base: + # Ideally, this could be narrower, but there is no way to + # express a constraint like `Base & ~ProperSubtypeOf[Base]`. 
+ reveal_type(x) # revealed: Base +``` diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 94efa3c4651d9..7ce84cb7fb553 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -257,17 +257,26 @@ impl<'db> NarrowingConstraintsBuilder<'db> { expression: Expression<'db>, is_positive: bool, ) -> Option> { + fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { + matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_)) + } + let ast::ExprCompare { range: _, left, ops, comparators, } = expr_compare; - if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) { - // If none of the comparators are name expressions, - // we have no symbol to narrow down the type of. + + // Performance optimization: early return if there are no potential narrowing targets. + if !is_narrowing_target_candidate(left) + && comparators + .iter() + .all(|c| !is_narrowing_target_candidate(c)) + { return None; } + if !is_positive && comparators.len() > 1 { // We can't negate a constraint made by a multi-comparator expression, since we can't // know which comparison part is the one being negated. @@ -283,42 +292,85 @@ impl<'db> NarrowingConstraintsBuilder<'db> { .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); let mut constraints = NarrowingConstraints::default(); for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { - if let ast::Expr::Name(ast::ExprName { - range: _, - id, - ctx: _, - }) = left - { - // SAFETY: we should always have a symbol for every Name node. 
- let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope)); - - match if is_positive { *op } else { op.negate() } { - ast::CmpOp::IsNot => { - if rhs_ty.is_singleton(self.db) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(rhs_ty) - .build(); - constraints.insert(symbol, ty); - } else { - // Non-singletons cannot be safely narrowed using `is not` + let rhs_ty = inference.expression_ty(right.scoped_expression_id(self.db, scope)); + + match left { + ast::Expr::Name(ast::ExprName { + range: _, + id, + ctx: _, + }) => { + let symbol = self + .symbols() + .symbol_id_by_name(id) + .expect("Should always have a symbol for every Name node"); + + match if is_positive { *op } else { op.negate() } { + ast::CmpOp::IsNot => { + if rhs_ty.is_singleton(self.db) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(rhs_ty) + .build(); + constraints.insert(symbol, ty); + } else { + // Non-singletons cannot be safely narrowed using `is not` + } } - } - ast::CmpOp::Is => { - constraints.insert(symbol, rhs_ty); - } - ast::CmpOp::NotEq => { - if rhs_ty.is_single_valued(self.db) { - let ty = IntersectionBuilder::new(self.db) - .add_negative(rhs_ty) - .build(); - constraints.insert(symbol, ty); + ast::CmpOp::Is => { + constraints.insert(symbol, rhs_ty); + } + ast::CmpOp::NotEq => { + if rhs_ty.is_single_valued(self.db) { + let ty = IntersectionBuilder::new(self.db) + .add_negative(rhs_ty) + .build(); + constraints.insert(symbol, ty); + } + } + _ => { + // TODO other comparison types } } - _ => { - // TODO other comparison types + } + ast::Expr::Call(ast::ExprCall { + range: _, + func: callable, + arguments: + ast::Arguments { + args, + keywords, + range: _, + }, + }) if rhs_ty.is_class_literal() && keywords.is_empty() => { + let [ast::Expr::Name(ast::ExprName { id, .. 
})] = &**args else { + continue; + }; + + let is_valid_constraint = if is_positive { + op == &ast::CmpOp::Is + } else { + op == &ast::CmpOp::IsNot + }; + + if !is_valid_constraint { + continue; + } + + let callable_ty = + inference.expression_ty(callable.scoped_expression_id(self.db, scope)); + + if callable_ty + .into_class_literal() + .is_some_and(|c| c.class.is_known(self.db, KnownClass::Type)) + { + let symbol = self + .symbols() + .symbol_id_by_name(id) + .expect("Should always have a symbol for every Name node"); + constraints.insert(symbol, rhs_ty.to_instance(self.db)); } } + _ => {} } } Some(constraints)