diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 2172361ea2f0..78cc74effe44 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -7,7 +7,7 @@ Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarLikeType, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType, ParamSpecType, - ENUM_REMOVED_PROPS + SelfType, ENUM_REMOVED_PROPS, ) from mypy.nodes import ( TypeInfo, FuncBase, Var, FuncDef, SymbolNode, SymbolTable, Context, @@ -145,6 +145,9 @@ def _analyze_member_access(name: str, typ = get_proper_type(typ) if isinstance(typ, Instance): return analyze_instance_member_access(name, typ, mx, override_info) + elif isinstance(typ, SelfType): + mx.self_type = typ.instance + return analyze_instance_member_access(name, typ.instance, mx, override_info) elif isinstance(typ, AnyType): # The base object has dynamic type. return AnyType(TypeOfAny.from_another_any, source_any=typ) @@ -495,6 +498,9 @@ def analyze_descriptor_access(descriptor_type: Type, descriptor_type = get_proper_type(descriptor_type) if isinstance(descriptor_type, UnionType): + for idx, item in enumerate(descriptor_type.items): + if isinstance(get_proper_type(item), SelfType): + descriptor_type.items[idx] = instance_type # Map the access over union types return make_simplified_union([ analyze_descriptor_access(typ, mx) @@ -502,6 +508,8 @@ def analyze_descriptor_access(descriptor_type: Type, ]) elif not isinstance(descriptor_type, Instance): return descriptor_type + elif isinstance(descriptor_type, SelfType): + return instance_type if not descriptor_type.type.has_readable_member('__get__'): return descriptor_type diff --git a/mypy/constraints.py b/mypy/constraints.py index 2f071e13a002..98051a3fec48 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -7,7 +7,7 @@ CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any, + ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any, SelfType, UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES, TypeVarTupleType, ) from mypy.maptype import map_instance_to_supertype @@ -399,6 +399,9 @@ def visit_type_var(self, template: TypeVarType) -> List[Constraint]: assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor" " (should have been handled in infer_constraints)") + def visit_self_type(self, template: SelfType) -> List[Constraint]: + return self.visit_instance(template.instance) + def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]: # Can't infer ParamSpecs from component values (only via Callable[P, T]). return [] diff --git a/mypy/copytype.py b/mypy/copytype.py index 85d7d531c5a3..0592c3b3e862 100644 --- a/mypy/copytype.py +++ b/mypy/copytype.py @@ -4,7 +4,7 @@ ProperType, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, ParamSpecType, PartialType, CallableType, TupleType, TypedDictType, LiteralType, UnionType, Overloaded, TypeType, TypeAliasType, UnpackType, Parameters, - TypeVarTupleType + TypeVarTupleType, SelfType ) from mypy.type_visitor import TypeVisitor @@ -75,6 +75,9 @@ def visit_unpack_type(self, t: UnpackType) -> ProperType: dup = UnpackType(t.type) return self.copy_common(t, dup) + def visit_self_type(self, t: SelfType) -> ProperType: + return self.copy_common(t, SelfType(t.instance, t.fullname, t.line, t.column)) + def visit_partial_type(self, t: PartialType) -> ProperType: return self.copy_common(t, PartialType(t.type, t.var, t.value_type)) diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 21ca5771b32e..b54de98331e6 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -5,7 +5,7 @@ CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, Parameters, UnpackType, - TypeVarTupleType + TypeVarTupleType, SelfType, ) from mypy.nodes import ARG_STAR, ARG_STAR2 @@ -57,6 +57,9 @@ def visit_instance(self, t: Instance) -> ProperType: def visit_type_var(self, t: TypeVarType) -> ProperType: return AnyType(TypeOfAny.special_form) + def visit_self_type(self, t: SelfType) -> ProperType: + return self.visit_instance(t.instance) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: return AnyType(TypeOfAny.special_form) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index ce43aeaeb6e5..5ba0e986d380 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -6,7 +6,7 @@ ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType, TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor, - UnpackType, TypeVarTupleType + UnpackType, TypeVarTupleType, SelfType, ) @@ -100,6 +100,9 @@ def visit_type_var(self, t: TypeVarType) -> Type: else: return repl + def visit_self_type(self, t: SelfType) -> Type: + return t + def visit_param_spec(self, t: ParamSpecType) -> Type: repl = get_proper_type(self.variables.get(t.id, t)) if isinstance(repl, Instance): diff --git a/mypy/fixup.py b/mypy/fixup.py index 85c1df079a5a..ac3b70422eab 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -11,7 +11,7 @@ CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType, - Parameters, UnpackType, TypeVarTupleType + Parameters, UnpackType, TypeVarTupleType, SelfType, ) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -246,6 +246,9 @@ def visit_type_var(self, tvt: TypeVarType) -> None: if tvt.upper_bound is not None: tvt.upper_bound.accept(self) + def visit_self_type(self, t: SelfType) -> None: + return t.instance.accept(self) + def visit_param_spec(self, p: ParamSpecType) -> None: p.upper_bound.accept(self) diff --git a/mypy/indirection.py b/mypy/indirection.py index 56c1f97928f2..bf897e70a5e2 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -64,6 +64,9 @@ def visit_deleted_type(self, t: types.DeletedType) -> Set[str]: def visit_type_var(self, t: types.TypeVarType) -> Set[str]: return self._visit(t.values) | self._visit(t.upper_bound) + def visit_self_type(self, t: types.SelfType) -> Set[str]: + return set() + def visit_param_spec(self, t: types.ParamSpecType) -> Set[str]: return set() diff --git a/mypy/join.py b/mypy/join.py index 70c250a7703c..5e2893712107 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -8,7 +8,7 @@ TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + UnpackType, TypeVarTupleType, SelfType, ) from mypy.maptype import map_instance_to_supertype from mypy.subtypes import ( @@ -256,6 +256,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: else: return self.default(self.s) + def visit_self_type(self, t: SelfType) -> ProperType: + return self.join(self.s, t.instance) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: if self.s == t: return t diff --git a/mypy/meet.py b/mypy/meet.py index ebaf0f675ef1..c608f3d0533b 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -6,7 +6,7 @@ TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType, - ParamSpecType, Parameters, UnpackType, TypeVarTupleType, TypeVarLikeType + ParamSpecType, Parameters, UnpackType, TypeVarTupleType, SelfType, TypeVarLikeType ) from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype from mypy.erasetype import erase_type @@ -537,6 +537,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType: else: return self.default(self.s) + def visit_self_type(self, t: SelfType) -> ProperType: + return self.meet(self.s, t.instance) + def visit_param_spec(self, t: ParamSpecType) -> ProperType: if self.s == t: return self.s diff --git a/mypy/mixedtraverser.py b/mypy/mixedtraverser.py index c14648cdf654..e12931921e17 100644 --- a/mypy/mixedtraverser.py +++ b/mypy/mixedtraverser.py @@ -5,7 +5,7 @@ CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr, PromoteExpr, NewTypeExpr ) -from mypy.types import Type +from mypy.types import Type, SelfType from mypy.traverser import TraverserVisitor from mypy.typetraverser import TypeTraverserVisitor @@ -41,6 +41,10 @@ def visit_type_var_expr(self, o: TypeVarExpr) -> None: for value in o.values: value.accept(self) + def visit_self_type(self, o: SelfType) -> None: + super().visit_self_type(o) + o.instance.accept(self) + def visit_typeddict_expr(self, o: TypedDictExpr) -> None: super().visit_typeddict_expr(o) self.visit_optional_type(o.info.typeddict_type) diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 4fbc9bfc4801..caada44c3aaa 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -5,7 +5,7 @@ UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + UnpackType, TypeVarTupleType, SelfType, ) from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal @@ -114,6 +114,9 @@ def visit_type_var(self, left: TypeVarType) -> bool: return (isinstance(self.right, TypeVarType) and left.id == self.right.id) + def visit_self_type(self, left: SelfType) -> bool: + return isinstance(self.right, SelfType) and self.right.instance == left.instance + def visit_param_spec(self, left: ParamSpecType) -> bool: # Ignore upper bound since it's derived from flavor. return (isinstance(self.right, ParamSpecType) and diff --git a/mypy/semanal.py b/mypy/semanal.py index e00913a8cde4..6b11ae890c95 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -49,88 +49,88 @@ """ from contextlib import contextmanager +from typing import (Any, Callable, Dict, Iterable, Iterator, List, Optional, + Set, Tuple, TypeVar, Union, cast) -from typing import ( - Any, List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator, Iterable -) -from typing_extensions import Final, TypeAlias as _TypeAlias - -from mypy.nodes import ( - AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef, - ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue, - ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr, - IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, - RaiseStmt, AssertStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, - GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, - SliceExpr, CastExpr, RevealExpr, TypeApplication, Context, SymbolTable, - SymbolTableNode, ListComprehension, GeneratorExpr, - LambdaExpr, MDEF, Decorator, SetExpr, TypeVarExpr, - StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, - ComparisonExpr, StarExpr, ArgKind, ARG_POS, ARG_NAMED, type_aliases, - YieldFromExpr, NamedTupleExpr, NonlocalDecl, SymbolNode, - SetComprehension, DictionaryComprehension, TypeAlias, TypeAliasExpr, - YieldExpr, ExecStmt, BackquoteExpr, ImportBase, AwaitExpr, - IntExpr, FloatExpr, UnicodeExpr, TempNode, OverloadPart, - PlaceholderNode, COVARIANT, CONTRAVARIANT, INVARIANT, - get_nongen_builtins, get_member_expr_fullname, REVEAL_TYPE, - REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions, - typing_extensions_aliases, - EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, - ParamSpecExpr, EllipsisExpr, TypeVarLikeExpr, implicit_module_attrs, - MatchStmt, FuncBase, TypeVarTupleExpr -) -from mypy.patterns import ( - AsPattern, OrPattern, ValuePattern, SequencePattern, - StarredPattern, MappingPattern, ClassPattern, -) -from mypy.tvar_scope import TypeVarLikeScope -from mypy.typevars import fill_typevars -from mypy.visitor import NodeVisitor -from mypy.errors import Errors, report_internal_error -from mypy.messages import ( - best_matches, MessageBuilder, pretty_seq, SUGGESTED_TEST_FIXTURES, TYPES_FOR_UNIMPORTED_HINTS -) +from typing_extensions import Final +from typing_extensions import TypeAlias as _TypeAlias + +from mypy import errorcodes as codes +from mypy import message_registry from mypy.errorcodes import ErrorCode -from mypy import message_registry, errorcodes as codes -from mypy.types import ( - NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType, - CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, - TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, - get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType, - PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES, - ASSERT_TYPE_NAMES, OVERLOAD_NAMES, is_named_instance, -) -from mypy.typeops import function_type, get_type_vars -from mypy.type_visitor import TypeQuery -from mypy.typeanal import ( - TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, - TypeVarLikeQuery, TypeVarLikeList, remove_dups, has_any_from_unimported_type, - check_for_explicit_any, type_constructors, fix_instance_types -) -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError +from mypy.errors import Errors, report_internal_error +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.messages import (SUGGESTED_TEST_FIXTURES, TYPES_FOR_UNIMPORTED_HINTS, + MessageBuilder, best_matches, pretty_seq) +from mypy.mro import MroError, calculate_mro +from mypy.nodes import (ARG_NAMED, ARG_POS, CONTRAVARIANT, COVARIANT, GDEF, + INVARIANT, LDEF, MDEF, REVEAL_LOCALS, REVEAL_TYPE, + RUNTIME_PROTOCOL_DECOS, ArgKind, AssertStmt, + AssertTypeExpr, AssignmentExpr, AssignmentStmt, + AwaitExpr, BackquoteExpr, Block, BreakStmt, BytesExpr, + CallExpr, CastExpr, ClassDef, ComparisonExpr, + ConditionalExpr, Context, ContinueStmt, Decorator, + DelStmt, DictExpr, DictionaryComprehension, + EllipsisExpr, EnumCallExpr, ExecStmt, Expression, + ExpressionStmt, FakeExpression, FloatExpr, ForStmt, + FuncBase, FuncDef, FuncItem, GeneratorExpr, GlobalDecl, + IfStmt, Import, ImportAll, ImportBase, ImportFrom, + IndexExpr, IntExpr, LambdaExpr, ListComprehension, + ListExpr, Lvalue, MatchStmt, MemberExpr, MypyFile, + NamedTupleExpr, NameExpr, Node, NonlocalDecl, + OperatorAssignmentStmt, OpExpr, OverloadedFuncDef, + OverloadPart, ParamSpecExpr, PlaceholderNode, + PrintStmt, PromoteExpr, RaiseStmt, RefExpr, ReturnStmt, + RevealExpr, SetComprehension, SetExpr, SliceExpr, + StarExpr, Statement, StrExpr, SuperExpr, SymbolNode, + SymbolTable, SymbolTableNode, TempNode, TryStmt, + TupleExpr, TypeAlias, TypeAliasExpr, TypeApplication, + TypedDictExpr, TypeInfo, TypeVarExpr, TypeVarLikeExpr, + TypeVarTupleExpr, UnaryExpr, UnicodeExpr, Var, + WhileStmt, WithStmt, YieldExpr, YieldFromExpr, + get_member_expr_fullname, get_nongen_builtins, + implicit_module_attrs, is_final_node, type_aliases, + type_aliases_source_versions, + typing_extensions_aliases) from mypy.options import Options -from mypy.plugin import ( - Plugin, ClassDefContext, SemanticAnalyzerPluginInterface, - DynamicClassDefContext -) -from mypy.util import ( - correct_relative_import, unmangle, module_prefix, is_typeshed_file, unnamed_function, - is_dunder, -) +from mypy.patterns import (AsPattern, ClassPattern, MappingPattern, OrPattern, + SequencePattern, StarredPattern, ValuePattern) +from mypy.plugin import (ClassDefContext, DynamicClassDefContext, Plugin, + SemanticAnalyzerPluginInterface) +from mypy.reachability import (ALWAYS_FALSE, ALWAYS_TRUE, MYPY_FALSE, + MYPY_TRUE, infer_condition_value, + infer_reachability_of_if_statement, + infer_reachability_of_match_statement) from mypy.scope import Scope -from mypy.semanal_shared import ( - SemanticAnalyzerInterface, set_callable_name, calculate_tuple_fallback, PRIORITY_FALLBACKS -) -from mypy.semanal_namedtuple import NamedTupleAnalyzer -from mypy.semanal_typeddict import TypedDictAnalyzer from mypy.semanal_enum import EnumCallAnalyzer +from mypy.semanal_namedtuple import NamedTupleAnalyzer from mypy.semanal_newtype import NewTypeAnalyzer -from mypy.reachability import ( - infer_reachability_of_if_statement, infer_reachability_of_match_statement, - infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, MYPY_TRUE, MYPY_FALSE -) -from mypy.mro import calculate_mro, MroError +from mypy.semanal_shared import (PRIORITY_FALLBACKS, SemanticAnalyzerInterface, + calculate_tuple_fallback, set_callable_name) +from mypy.semanal_typeddict import TypedDictAnalyzer +from mypy.tvar_scope import TypeVarLikeScope +from mypy.type_visitor import TypeQuery +from mypy.typeanal import (TypeAnalyser, TypeVarLikeList, TypeVarLikeQuery, + analyze_type_alias, check_for_explicit_any, + fix_instance_types, has_any_from_unimported_type, + no_subscript_builtin_alias, remove_dups, + type_constructors) +from mypy.typeops import function_type, get_type_vars +from mypy.types import (ASSERT_TYPE_NAMES, FINAL_DECORATOR_NAMES, + FINAL_TYPE_NAMES, NEVER_NAMES, OVERLOAD_NAMES, + PROTOCOL_NAMES, REVEAL_TYPE_NAMES, SELF_TYPE_NAMES, + TPDICT_NAMES, TYPE_ALIAS_NAMES, AnyType, CallableType, + FunctionLike, Instance, LiteralType, LiteralValue, + NoneType, Overloaded, Parameters, ParamSpecType, + PlaceholderType, ProperType, SelfType, StarType, + TupleType, Type, TypeAliasType, TypeOfAny, + TypeTranslator, TypeType, TypeVarLikeType, TypeVarType, + UnboundType, UnionType, get_proper_type, + get_proper_types, is_named_instance) +from mypy.typevars import fill_typevars +from mypy.util import (correct_relative_import, is_dunder, is_typeshed_file, + module_prefix, unmangle, unnamed_function) +from mypy.visitor import NodeVisitor T = TypeVar('T') @@ -697,19 +697,83 @@ def analyze_func_def(self, defn: FuncDef) -> None: def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None: """Check basic signature validity and tweak annotation of self/cls argument.""" # Only non-static methods are special. - functype = func.type if not func.is_static: if func.name in ['__init_subclass__', '__class_getitem__']: func.is_class = True if not func.arguments: self.fail('Method must have at least one argument', func) - elif isinstance(functype, CallableType): - self_type = get_proper_type(functype.arg_types[0]) + elif isinstance(func.type, CallableType): + self_type = get_proper_type(func.type.arg_types[0]) if isinstance(self_type, AnyType): leading_type: Type = fill_typevars(info) if func.is_class or func.name == '__new__': leading_type = self.class_type(leading_type) - func.type = replace_implicit_first_type(functype, leading_type) + if not self_type.type_of_any == TypeOfAny.explicit: + func.type = replace_implicit_first_type(func.type, leading_type) + + assert isinstance(func.type, CallableType) + leading_type = func.type.arg_types[0] + proper_leading_type = get_proper_type(leading_type) + if self.is_self_type(proper_leading_type): # method[[Self, ...], Self] case + proper_leading_type = func.type.arg_types[0] = self.named_type(info.fullname) + elif isinstance(proper_leading_type, UnboundType): + # classmethod[[type[Self], ...], Self] case + node = self.lookup(proper_leading_type.name, func) + if ( + node is not None + and node.fullname in {"typing.Type", "builtins.type"} + and proper_leading_type.args + and self.is_self_type(proper_leading_type.args[0]) + ): + proper_leading_type = func.type.arg_types[0] = get_proper_type( + self.class_type(self.named_type(info.fullname)) + ) + if not isinstance(proper_leading_type, (Instance, TypeType)): + return + if isinstance(proper_leading_type, TypeType): + self_type = proper_leading_type.item + else: + self_type = proper_leading_type + fullname: Optional[str] = None + # bind any SelfTypes + for idx, arg in enumerate(func.type.arg_types): + if self.is_self_type(arg): + if func.is_static: + self.fail( + "Self-type annotations of staticmethods are not supported, " + "please replace the type with {}".format(self_type.type.name), + func + ) + func.type.arg_types[idx] = self.named_type( + self_type.type.name + ) # we replace them here for them + continue + if fullname is None: + assert isinstance(arg, UnboundType) + table_node = self.lookup(arg.name, func) + assert isinstance(table_node, SymbolTableNode) and table_node.node + fullname = table_node.node.fullname + assert isinstance(self_type, Instance) + func.type.arg_types[idx] = SelfType(self_type, fullname=fullname) + + if self.is_self_type(func.type.ret_type): + if fullname is None: + assert isinstance(func.type.ret_type, UnboundType) + table_node = self.lookup_qualified( + func.type.ret_type.name, func.type.ret_type + ) + assert isinstance(table_node, SymbolTableNode) and table_node.node + fullname = table_node.node.fullname + if func.is_static: + self.fail( + "Self-type annotations of staticmethods are not supported, " + "please replace the type with {}".format(self_type.type.name), + func, + ) + func.type.ret_type = self.named_type(self_type.type.name) + return + assert isinstance(self_type, Instance) + func.type.ret_type = SelfType(self_type, fullname=fullname) def set_original_def(self, previous: Optional[Node], new: Union[FuncDef, Decorator]) -> bool: """If 'new' conditionally redefine 'previous', set 'previous' as original @@ -1609,6 +1673,19 @@ def configure_base_classes(self, return self.calculate_class_mro(defn, self.object_type) + # return? + # for base in info.mro: + # for name, type in base.names.items(): + # if isinstance(type, SelfType): # bind Self + # info.names[name] = SelfType(self.named_type(defn.fullname), type.fullname) + # elif isinstance(type, UnionType): + # info.names[name] = UnionType([ + # item + # if not isinstance(item, SelfType) + # else SelfType(self.named_type(defn.fullname), type.fullname) + # for item in type.items + # ]) + def configure_tuple_base_class(self, defn: ClassDef, base: TupleType, @@ -3458,6 +3535,14 @@ def is_final_type(self, typ: Optional[Type]) -> bool: return False return sym.node.fullname in FINAL_TYPE_NAMES + def is_self_type(self, typ: Optional[Type]) -> bool: + if not isinstance(typ, UnboundType): + return False + sym = self.lookup_qualified(typ.name, typ) + if not sym or not sym.node: + return False + return sym.node.fullname in SELF_TYPE_NAMES + def fail_invalid_classvar(self, context: Context) -> None: self.fail(message_registry.CLASS_VAR_OUTSIDE_OF_CLASS, context) @@ -3858,6 +3943,8 @@ def visit_name_expr(self, expr: NameExpr) -> None: def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None: """Bind name expression to a symbol table node.""" + if sym.node and sym.node.fullname in SELF_TYPE_NAMES and not self.is_class_scope(): + self.fail('{} is unbound'.format(expr.name), expr) if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym): self.fail('"{}" is a type variable and only valid in type ' 'context'.format(expr.name), expr) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 1f1c6b65f385..8f94c9240898 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -60,7 +60,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType, ParamSpecType, - Parameters, UnpackType, TypeVarTupleType, + Parameters, UnpackType, TypeVarTupleType, SelfType, ) from mypy.util import get_prefix @@ -333,6 +333,12 @@ def visit_parameters(self, typ: Parameters) -> SnapshotItem: tuple(encode_optional_str(name) for name in typ.arg_names), tuple(typ.arg_kinds)) + def visit_self_type(self, typ: SelfType) -> SnapshotItem: + return ('SelfType', + typ.fullname, + snapshot_type(typ.instance) + ) + def visit_callable_type(self, typ: CallableType) -> SnapshotItem: # FIX generics return ('CallableType', diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index be69b3c00d97..47c177a12d3f 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -60,7 +60,7 @@ TupleType, TypeType, TypedDictType, UnboundType, UninhabitedType, UnionType, Overloaded, TypeVarType, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, RawExpressionType, PartialType, PlaceholderType, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + UnpackType, TypeVarTupleType, SelfType, ) from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState @@ -427,6 +427,9 @@ def visit_parameters(self, typ: Parameters) -> None: for arg in typ.arg_types: arg.accept(self) + def visit_self_type(self, typ: SelfType) -> None: + typ.instance.accept(self) + def visit_typeddict_type(self, typ: TypedDictType) -> None: for value_type in typ.items.values(): value_type.accept(self) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index f339344e79b5..c5ee2fd5be78 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -100,7 +100,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType, Parameters, UnpackType, TypeVarTupleType, + TypeAliasType, ParamSpecType, Parameters, UnpackType, TypeVarTupleType, SelfType, ) from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.util import correct_relative_import @@ -982,6 +982,14 @@ def visit_parameters(self, typ: Parameters) -> List[str]: triggers.extend(self.get_type_triggers(arg)) return triggers + def visit_self_type(self, typ: SelfType) -> List[str]: + triggers = [] + if typ.fullname: + triggers.append(make_trigger(typ.fullname)) + if typ.instance: + triggers.extend(self.get_type_triggers(typ.instance)) + return triggers + def visit_typeddict_type(self, typ: TypedDictType) -> List[str]: triggers = [] for item in typ.items.values(): diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 8b7b3153ecaf..6050760ebf5d 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -8,7 +8,7 @@ Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType, ParamSpecType, - Parameters, UnpackType, TUPLE_LIKE_INSTANCE_NAMES, TypeVarTupleType, + Parameters, UnpackType, TUPLE_LIKE_INSTANCE_NAMES, TypeVarTupleType, SelfType, ) import mypy.applytype import mypy.constraints @@ -263,6 +263,8 @@ def visit_instance(self, left: Instance) -> bool: return False return True right = self.right + if isinstance(right, SelfType): + return self._is_subtype(left, right.instance) if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum: return self._is_subtype(left, mypy.typeops.tuple_fallback(right)) if isinstance(right, Instance): @@ -336,6 +338,9 @@ def visit_type_var(self, left: TypeVarType) -> bool: return True return self._is_subtype(left.upper_bound, self.right) + def visit_self_type(self, left: SelfType) -> bool: + return self._is_subtype(left.instance, self.right) + def visit_param_spec(self, left: ParamSpecType) -> bool: right = self.right if ( @@ -1471,6 +1476,9 @@ def visit_type_var(self, left: TypeVarType) -> bool: return True return self._is_proper_subtype(left.upper_bound, self.right) + def visit_self_type(self, t: SelfType) -> bool: + return self._is_proper_subtype(t.instance, self.right) + def visit_param_spec(self, left: ParamSpecType) -> bool: right = self.right if ( diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index ddcb78df8100..fd764daab3a3 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -71,6 +71,7 @@ 'check-newtype.test', 'check-class-namedtuple.test', 'check-selftype.test', + 'check-selftyping.test', 'check-python2.test', 'check-columns.test', 'check-functions.test', diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 79b4cb12d512..298d7b9dedac 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -13,7 +13,8 @@ from abc import abstractmethod from mypy.backports import OrderedDict -from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set, Sequence +from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional,\ + Set, Sequence from mypy_extensions import trait, mypyc_attr T = TypeVar('T') @@ -24,7 +25,7 @@ UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeType, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, PlaceholderType, TypeAliasType, ParamSpecType, UnpackType, TypeVarTupleType, - get_proper_type + get_proper_type, SelfType, ) @@ -65,6 +66,9 @@ def visit_type_var(self, t: TypeVarType) -> T: pass @abstractmethod + def visit_self_type(self, t: SelfType) -> T: + pass + def visit_param_spec(self, t: ParamSpecType) -> T: pass @@ -196,6 +200,9 @@ def visit_instance(self, t: Instance) -> Type: def visit_type_var(self, t: TypeVarType) -> Type: return t + def visit_self_type(self, t: SelfType) -> Type: + return t + def visit_param_spec(self, t: ParamSpecType) -> Type: return t @@ -320,6 +327,9 @@ def visit_deleted_type(self, t: DeletedType) -> T: def visit_type_var(self, t: TypeVarType) -> T: return self.query_types([t.upper_bound] + t.values) + def visit_self_type(self, t: SelfType) -> T: + return self.strategy([]) + def visit_param_spec(self, t: ParamSpecType) -> T: return self.strategy([]) @@ -392,3 +402,87 @@ def query_types(self, types: Iterable[Type]) -> T: self.seen_aliases.add(t) res.append(t.accept(self)) return self.strategy(res) + + +class SelfTypeVisitor(TypeVisitor[Any]): + def __init__(self, self_type: Type) -> None: + self.self_type = self_type + + def visit_unbound_type(self, t: UnboundType) -> None: + pass + + def visit_any(self, t: AnyType) -> None: + pass + + def visit_none_type(self, t: NoneType) -> None: + pass + + def visit_uninhabited_type(self, t: UninhabitedType) -> None: + pass + + def visit_erased_type(self, t: ErasedType) -> None: + pass + + def visit_deleted_type(self, t: DeletedType) -> None: + pass + + def visit_type_var(self, t: TypeVarType) -> None: + pass + + def visit_self_type(self, t: SelfType) -> None: + pass # should this raise? + + def visit_param_spec(self, t: ParamSpecType) -> None: + pass + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> T: + pass + + def visit_unpack_type(self, t: UnpackType) -> T: + pass + + def visit_parameters(self, t: Parameters) -> T: + pass + + def visit_instance(self, t: Instance) -> None: + t.args = tuple(self.replace_types(t.args)) + + def visit_callable_type(self, t: CallableType) -> None: + t.arg_types = self.replace_types(t.arg_types) + t.ret_type = self.replace_type(t.ret_type) + + def visit_overloaded(self, t: Overloaded) -> None: + for item in t.items: + item.accept(self) + + def visit_tuple_type(self, t: TupleType) -> None: + t.items = self.replace_types(t.items) + + def visit_typeddict_type(self, t: TypedDictType) -> None: + for key, value in zip(t.items, self.replace_types(t.items.values())): + t.items[key] = value + + def visit_literal_type(self, t: LiteralType) -> None: + pass + + def visit_union_type(self, t: UnionType) -> None: + t.items = self.replace_types(t.items) + + def visit_partial_type(self, t: PartialType) -> None: + pass + + def visit_type_type(self, t: TypeType) -> None: + t.item = get_proper_type(self.replace_type(t.item)) + + def visit_type_alias_type(self, t: TypeAliasType) -> None: + pass # TODO this is probably invalid + + def replace_types(self, types: Iterable[Type]) -> List[Type]: + return [self.replace_type(typ) for typ in types] + + def replace_type(self, typ: Type) -> Type: + if isinstance(typ, SelfType): # type: ignore + typ = self.self_type + else: + typ.accept(self) + return typ diff --git a/mypy/typeanal.py b/mypy/typeanal.py index bd0f684653b2..0c41764c4d81 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -17,8 +17,8 @@ Parameters, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, PlaceholderType, Overloaded, get_proper_type, TypeAliasType, RequiredType, TypeVarLikeType, ParamSpecType, ParamSpecFlavor, UnpackType, TypeVarTupleType, - callable_with_ellipsis, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, - LITERAL_TYPE_NAMES, ANNOTATED_TYPE_NAMES, + callable_with_ellipsis, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, SELF_TYPE_NAMES, + LITERAL_TYPE_NAMES, ANNOTATED_TYPE_NAMES, SelfType, ) from mypy.nodes import ( @@ -428,6 +428,17 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt self.fail("NotRequired[] must have exactly one type argument", t) return AnyType(TypeOfAny.from_error) return RequiredType(self.anal_type(t.args[0]), required=False) + elif fullname in SELF_TYPE_NAMES: + from mypy.semanal import SemanticAnalyzer # circular import + + if not isinstance(self.api, SemanticAnalyzer): + self.fail("Self is unbound", t) + return AnyType(TypeOfAny.from_error) + if not isinstance(self.api.type, TypeInfo): + self.fail("Self is not enclosed in a class", t) + return AnyType(TypeOfAny.from_error) + bound = self.named_type(self.api.type.fullname) + return SelfType(bound, fullname, line=t.line, column=t.column) elif self.anal_type_guard_arg(t, fullname) is not None: # In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args) return self.named_type('builtins.bool') @@ -646,6 +657,9 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: def visit_type_var(self, t: TypeVarType) -> Type: return t + def visit_self_type(self, t: SelfType) -> Type: + return t + def visit_param_spec(self, t: ParamSpecType) -> Type: return t diff --git a/mypy/typeops.py b/mypy/typeops.py index 835c8f0a7229..3f75c5b5603d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,29 +5,29 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union -from typing_extensions import Type as TypingType import itertools import sys +from typing import (Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, + TypeVar, Union, cast) -from mypy.types import ( - TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded, - TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, - AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, - TypeAliasType, TypeQuery, ParamSpecType, Parameters, UnpackType, TypeVarTupleType, - ENUM_REMOVED_PROPS, -) -from mypy.nodes import ( - FuncBase, FuncItem, FuncDef, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, - Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES -) -from mypy.maptype import map_instance_to_supertype -from mypy.expandtype import expand_type_by_instance, expand_type -from mypy.copytype import copy_type - -from mypy.typevars import fill_typevars +from typing_extensions import Type as TypingType +from mypy.copytype import copy_type +from mypy.expandtype import expand_type, expand_type_by_instance +from mypy.maptype import map_instance_to_supertype +from mypy.nodes import (ARG_POS, ARG_STAR, ARG_STAR2, SYMBOL_FUNCBASE_TYPES, + Decorator, Expression, FuncBase, FuncDef, FuncItem, + OverloadedFuncDef, StrExpr, TypeInfo, Var) from mypy.state import state +from mypy.type_visitor import SelfTypeVisitor +from mypy.types import (ENUM_REMOVED_PROPS, AnyType, CallableType, + FormalArgument, FunctionLike, Instance, LiteralType, + NoneType, Overloaded, Parameters, ParamSpecType, + ProperType, TupleType, Type, TypeAliasType, TypeOfAny, + TypeQuery, TypeType, TypeVarLikeType, TypeVarTupleType, + TypeVarType, UninhabitedType, UnionType, UnpackType, + get_proper_type, get_proper_types) +from mypy.typevars import fill_typevars def is_recursive_pair(s: Type, t: Type) -> bool: @@ -274,6 +274,10 @@ def expand(target: Type) -> Type: variables=variables, ret_type=ret_type, bound_args=[original_type]) + if original_type: + if isinstance(original_type, TypeType): + original_type = original_type.item + res.accept(SelfTypeVisitor(original_type)) return cast(F, res) diff --git a/mypy/types.py b/mypy/types.py index f0f7add2d92f..6653b71da5e4 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -104,6 +104,12 @@ 'typing_extensions.final', ) +# Supported Self type names. +SELF_TYPE_NAMES: Final = ( + 'typing.Self', + 'typing_extensions.Self', +) + # Supported Literal type names. LITERAL_TYPE_NAMES: Final = ( 'typing.Literal', @@ -549,6 +555,34 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarType': ) +class SelfType(ProperType): + name: ClassVar[str] = 'Self' + + __slots__ = ('fullname', 'instance') + + def __init__(self, instance: 'Instance', fullname: str, + line: int = -1, column: int = -1) -> None: + super().__init__(line, column) + self.fullname = fullname + self.instance = instance + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_self_type(self) + + def serialize(self) -> JsonDict: + return { + '.class': 'SelfType', + 'fullname': self.fullname, + 'instance': self.instance.serialize(), + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'SelfType': + assert data['.class'] == 'SelfType' + assert isinstance(data['instance'], str) + return SelfType(Instance.deserialize(data['instance']), data['fullname']) + + class ParamSpecFlavor: # Simple ParamSpec reference such as "P" BARE: Final = 0 @@ -2659,6 +2693,9 @@ def visit_type_var(self, t: TypeVarType) -> str: s += f'(upper_bound={t.upper_bound.accept(self)})' return s + def visit_self_type(self, t: SelfType) -> str: + return "Self@{}".format(t.instance.accept(self)) if t.instance else "Self@unbound" + def visit_param_spec(self, t: ParamSpecType) -> str: # prefixes are displayed as Concatenate s = '' diff --git a/mypy/typeshed/stdlib/typing.pyi b/mypy/typeshed/stdlib/typing.pyi index 37ea55c9f2ef..45ff25bf411c 100644 --- a/mypy/typeshed/stdlib/typing.pyi +++ b/mypy/typeshed/stdlib/typing.pyi @@ -581,6 +581,11 @@ if sys.version_info >= (3, 10): else: def NewType(name: str, tp: Any) -> Any: ... +if sys.version_info >= (3, 11): + # Self is also a (non-subscriptable) special form. + ... +Self: object = ... + # These type variables are used by the container types. _S = TypeVar("_S") _KT = TypeVar("_KT") # Key type. diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 7d959c97b66b..57d20b6a3f50 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -7,7 +7,7 @@ TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + UnpackType, TypeVarTupleType, SelfType, ) @@ -55,6 +55,9 @@ def visit_literal_type(self, t: LiteralType) -> None: def visit_instance(self, t: Instance) -> None: self.traverse_types(t.args) + def visit_self_type(self, t: SelfType) -> None: + self.visit_instance(t.instance) + def visit_callable_type(self, t: CallableType) -> None: # FIX generics self.traverse_types(t.arg_types) diff --git a/runner.py b/runner.py new file mode 100644 index 000000000000..6960269c86be --- /dev/null +++ b/runner.py @@ -0,0 +1,13 @@ +import sys + +from mypy.version import __version__ +from mypy.build import build, BuildSource, Options + +print(__version__) + +options = Options() +options.show_traceback = True +options.raise_exceptions = True +# options.verbosity = 10 +result = build([BuildSource("test.py", None, )], options, stderr=sys.stderr, stdout=sys.stdout) +print(*result.errors, sep="\n") diff --git a/test-data/unit/check-selftyping.test b/test-data/unit/check-selftyping.test new file mode 100644 index 000000000000..e46d8a3cc524 --- /dev/null +++ b/test-data/unit/check-selftyping.test @@ -0,0 +1,14 @@ +-- PEP 673 -- +[case testSelfTypeMethodReturn] +from typing import Self +class C: + def m(self) -> Self: + return self + +reveal_type(C().m()) +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-full.pyi] +[out] +main:6: note: Revealed type is "__main__.C" + + diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 66b02638ebc7..65310e083e63 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -31,6 +31,7 @@ Literal = 0 TypedDict = 0 NoReturn = 0 NewType = 0 +Self = 0 T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) diff --git a/test.py b/test.py new file mode 100644 index 000000000000..1339739459ec --- /dev/null +++ b/test.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Generic, TypeVar +from typing_extensions import Self +from abc import ABC + +T = TypeVar("T") +K = TypeVar("K") + + +class ItemSet(Generic[T]): + def first(self) -> T: ... + + +class BaseItem(ABC): + @property + def set(self) -> ItemSet[Self]: ... + + +class FooItem(BaseItem): + name: str + + def test(self) -> None: ... + + +reveal_type(FooItem().set.first().name) +reveal_type(BaseItem().set) +reveal_type(FooItem().set.first().test())