Skip to content

Commit

Permalink
[dataclasses plugin] Support kw_only=True
Browse files Browse the repository at this point in the history
  • Loading branch information
tgallant committed Jul 24, 2021
1 parent 97b3b90 commit dbfd425
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 20 deletions.
21 changes: 19 additions & 2 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
column: int,
type: Optional[Type],
info: TypeInfo,
kw_only: bool,
) -> None:
self.name = name
self.is_in_init = is_in_init
Expand All @@ -48,6 +49,7 @@ def __init__(
self.column = column
self.type = type
self.info = info
self.kw_only = kw_only

def to_argument(self) -> Argument:
return Argument(
Expand Down Expand Up @@ -77,6 +79,8 @@ def deserialize(
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
) -> 'DataclassAttribute':
data = data.copy()
if data.get('kw_only') is None:
data['kw_only'] = False
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)

Expand Down Expand Up @@ -215,6 +219,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
cls = self._ctx.cls
attrs: List[DataclassAttribute] = []
known_attrs: Set[str] = set()
kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False)
for stmt in cls.defs.body:
# Any assignment that doesn't use the new type declaration
# syntax can be ignored out of hand.
Expand Down Expand Up @@ -251,6 +256,10 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
is_init_var = True
node.type = node_type.args[0]

if (isinstance(node_type, Instance) and
node_type.type.fullname == 'dataclasses._KW_ONLY_TYPE'):
kw_only = True

has_field_call, field_args = _collect_field_args(stmt.rvalue)

is_in_init_param = field_args.get('init')
Expand All @@ -274,6 +283,13 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# on self in the generated __init__(), not in the class body.
sym.implicit = True

is_kw_only = kw_only
# Use the kw_only field arg if it is provided. Otherwise use the
# kw_only value from the decorator parameter.
field_kw_only_param = field_args.get('kw_only')
if field_kw_only_param is not None:
is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param))

known_attrs.add(lhs.name)
attrs.append(DataclassAttribute(
name=lhs.name,
Expand All @@ -284,6 +300,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
column=stmt.column,
type=sym.type,
info=cls.info,
kw_only=is_kw_only,
))

# Next, collect attributes belonging to any class in the MRO
Expand Down Expand Up @@ -323,10 +340,10 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# arguments that have a default.
found_default = False
for attr in all_attrs:
# If we find any attribute that is_in_init but that
# If we find any attribute that is_in_init, not kw_only, and that
# doesn't have a default after one that does have one,
# then that's an error.
if found_default and attr.is_in_init and not attr.has_default:
if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only:
# If the issue comes from merging different classes, report it
# at the class definition point.
context = (Context(line=attr.line, column=attr.column) if attr in attrs
Expand Down
59 changes: 58 additions & 1 deletion test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class Person:
name: str
age: int = field(init=None) # E: No overload variant of "field" matches argument type "None" \
# N: Possible overload variant: \
# N: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ...) -> Any \
# N: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any \
# N: <2 more non-matching overloads not shown>

[builtins fixtures/list.pyi]
Expand Down Expand Up @@ -311,6 +311,63 @@ class Application:

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnly]
# flags: --python-version 3.10
from dataclasses import dataclass

@dataclass(kw_only=True)
class Application:
name: str = 'Unnamed'
rating: int

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyOnField]
# flags: --python-version 3.10
from dataclasses import dataclass, field

@dataclass
class Application:
name: str = 'Unnamed'
rating: int = field(kw_only=True)

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyOnFieldFalse]
# flags: --python-version 3.10
from dataclasses import dataclass, field

@dataclass
class Application:
name: str = 'Unnamed'
rating: int = field(kw_only=False) # E: Attributes without a default cannot follow attributes with one

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyWithSentinel]
# flags: --python-version 3.10
from dataclasses import dataclass, KW_ONLY

@dataclass
class Application:
_: KW_ONLY
name: str = 'Unnamed'
rating: int

[builtins fixtures/list.pyi]

[case testDataclassesOrderingKwOnlyWithSentinelAndFieldOverride]
# flags: --python-version 3.10
from dataclasses import dataclass, field, KW_ONLY

@dataclass
class Application:
_: KW_ONLY
name: str = 'Unnamed'
rating: int = field(kw_only=False) # E: Attributes without a default cannot follow attributes with one

[builtins fixtures/list.pyi]

[case testDataclassesClassmethods]
# flags: --python-version 3.7
from dataclasses import dataclass
Expand Down
64 changes: 47 additions & 17 deletions test-data/unit/lib-stub/dataclasses.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,56 @@ _T = TypeVar('_T')
class InitVar(Generic[_T]):
...

class _KW_ONLY_TYPE: ...
KW_ONLY = _KW_ONLY_TYPE

@overload
def dataclass(_cls: Type[_T]) -> Type[_T]: ...

def dataclass(__cls: Type[_T]) -> Type[_T]: ...
@overload
def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool = ...,
unsafe_hash: bool = ..., frozen: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ...


def dataclass(__cls: None) -> Callable[[Type[_T]], Type[_T]]: ...
@overload
def field(*, default: _T,
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> _T: ...

def dataclass(
*,
init: bool = ...,
repr: bool = ...,
eq: bool = ...,
order: bool = ...,
unsafe_hash: bool = ...,
frozen: bool = ...,
match_args: bool = ...,
kw_only: bool = ...,
slots: bool = ...,
) -> Callable[[Type[_T]], Type[_T]]: ...

@overload # `default` and `default_factory` are optional and mutually exclusive.
def field(
*,
default: _T,
init: bool = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...,
kw_only: bool = ...,
) -> _T: ...
@overload
def field(*, default_factory: Callable[[], _T],
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> _T: ...

def field(
*,
default_factory: Callable[[], _T],
init: bool = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...,
kw_only: bool = ...,
) -> _T: ...
@overload
def field(*,
init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...) -> Any: ...
def field(
*,
init: bool = ...,
repr: bool = ...,
hash: Optional[bool] = ...,
compare: bool = ...,
metadata: Optional[Mapping[str, Any]] = ...,
kw_only: bool = ...,
) -> Any: ...

0 comments on commit dbfd425

Please sign in to comment.