Skip to content

Commit

Permalink
Add some type hints to the setuptools.config subpackage
Browse files Browse the repository at this point in the history
(Merge branch 'add-some-type-hints' into support-pyproject)

This helps to ensure confidence in the code and prevent unexpected bugs.

To typecheck the subpackage the following command can be used:

		mypy setuptools/config --tb --show-error-codes --show-error-context --pretty
  • Loading branch information
abravalheri committed Jan 7, 2022
2 parents 2ff6d2f + ede9590 commit 89a19a4
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 66 deletions.
12 changes: 7 additions & 5 deletions setuptools/config/_apply_pyprojecttoml.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@
from functools import partial
from itertools import chain
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple,
Type, Union)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
Set, Tuple, Type, Union, cast)

if TYPE_CHECKING:
from pkg_resources import EntryPoint # noqa
from setuptools.dist import Distribution # noqa

EMPTY = MappingProxyType({}) # Immutable dict-like
EMPTY: Mapping = MappingProxyType({}) # Immutable dict-like
_Path = Union[os.PathLike, str]
_DictOrStr = Union[dict, str]
_CorrespFn = Callable[["Distribution", Any, _Path], None]
_Correspondence = Union[str, _CorrespFn]
_IterEntryPoints = Callable[[str], List["EntryPoint"]]


def apply(dist: "Distribution", config: dict, filename: _Path) -> "Distribution":
Expand Down Expand Up @@ -186,8 +187,9 @@ def _valid_command_options(cmdclass: Mapping = EMPTY) -> Dict[str, Set[str]]:

valid_options = {"global": _normalise_cmd_options(Distribution.global_options)}

entry_points = (_load_ep(ep) for ep in iter_entry_points('distutils.commands'))
entry_points = (ep for ep in entry_points if ep)
iter_ep = cast(_IterEntryPoints, iter_entry_points)
_entry_points = (_load_ep(ep) for ep in iter_ep('distutils.commands'))
entry_points = (ep for ep in _entry_points if ep)
for cmd, cmd_class in chain(entry_points, cmdclass.items()):
opts = valid_options.get(cmd, set())
opts = opts | _normalise_cmd_options(getattr(cmd_class, "user_options", []))
Expand Down
78 changes: 52 additions & 26 deletions setuptools/config/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@
import sys
from glob import iglob
from configparser import ConfigParser
from importlib.machinery import ModuleSpec
from itertools import chain
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
from types import ModuleType

from distutils.errors import DistutilsOptionError

chain_iter = chain.from_iterable
_Path = Union[str, os.PathLike]


class StaticModule:
"""
Attempt to load the module by the name
"""

def __init__(self, name, spec):
with open(spec.origin) as strm:
def __init__(self, name: str, spec: ModuleSpec):
with open(spec.origin) as strm: # type: ignore
src = strm.read()
module = ast.parse(src)
vars(self).update(locals())
Expand All @@ -56,7 +60,9 @@ def __getattr__(self, attr):
) from e


def glob_relative(patterns, root_dir=None):
def glob_relative(
patterns: Iterable[str], root_dir: Optional[_Path] = None
) -> List[str]:
"""Expand the list of glob patterns, but preserving relative paths.
:param list[str] patterns: List of glob patterns
Expand Down Expand Up @@ -85,15 +91,15 @@ def glob_relative(patterns, root_dir=None):
return expanded_values


def read_files(filepaths, root_dir=None):
def read_files(filepaths: Union[str, bytes, Iterable[_Path]], root_dir=None) -> str:
"""Return the content of the files concatenated using ``\n`` as str
This function is sandboxed and won't reach anything outside ``root_dir``
(By default ``root_dir`` is the current directory).
"""
if isinstance(filepaths, (str, bytes)):
filepaths = [filepaths]
filepaths = [filepaths] # type: ignore

root_dir = os.path.abspath(root_dir or os.getcwd())
_filepaths = (os.path.join(root_dir, path) for path in filepaths)
Expand All @@ -104,20 +110,24 @@ def read_files(filepaths, root_dir=None):
)


def _read_file(filepath):
def _read_file(filepath: Union[bytes, _Path]) -> str:
with io.open(filepath, encoding='utf-8') as f:
return f.read()


def _assert_local(filepath, root_dir):
def _assert_local(filepath: _Path, root_dir: str):
if not os.path.abspath(filepath).startswith(root_dir):
msg = f"Cannot access {filepath!r} (or anything outside {root_dir!r})"
raise DistutilsOptionError(msg)

return True


def read_attr(attr_desc, package_dir=None, root_dir=None):
def read_attr(
attr_desc: str,
package_dir: Optional[dict] = None,
root_dir: Optional[_Path] = None
):
"""Reads the value of an attribute from a module.
This function will try to read the attributed statically first
Expand All @@ -140,8 +150,8 @@ def read_attr(attr_desc, package_dir=None, root_dir=None):
attr_name = attrs_path.pop()
module_name = '.'.join(attrs_path)
module_name = module_name or '__init__'
parent_path, path, module_name = _find_module(module_name, package_dir, root_dir)
spec = _find_spec(module_name, path, parent_path)
_parent_path, path, module_name = _find_module(module_name, package_dir, root_dir)
spec = _find_spec(module_name, path)

try:
return getattr(StaticModule(module_name, spec), attr_name)
Expand All @@ -151,7 +161,7 @@ def read_attr(attr_desc, package_dir=None, root_dir=None):
return getattr(module, attr_name)


def _find_spec(module_name, module_path, parent_path):
def _find_spec(module_name: str, module_path: Optional[_Path]) -> ModuleSpec:
spec = importlib.util.spec_from_file_location(module_name, module_path)
spec = spec or importlib.util.find_spec(module_name)

Expand All @@ -161,17 +171,19 @@ def _find_spec(module_name, module_path, parent_path):
return spec


def _load_spec(spec, module_name):
def _load_spec(spec: ModuleSpec, module_name: str) -> ModuleType:
name = getattr(spec, "__name__", module_name)
if name in sys.modules:
return sys.modules[name]
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module # cache (it also ensures `==` works on loaded items)
spec.loader.exec_module(module)
spec.loader.exec_module(module) # type: ignore
return module


def _find_module(module_name, package_dir, root_dir):
def _find_module(
module_name: str, package_dir: Optional[dict], root_dir: _Path
) -> Tuple[_Path, Optional[str], str]:
"""Given a module (that could normally be imported by ``module_name``
after the build is complete), find the path to the parent directory where
it is contained and the canonical name that could be used to import it
Expand Down Expand Up @@ -203,26 +215,36 @@ def _find_module(module_name, package_dir, root_dir):
return parent_path, module_path, module_name


def resolve_class(qualified_class_name, package_dir=None, root_dir=None):
def resolve_class(
qualified_class_name: str,
package_dir: Optional[dict] = None,
root_dir: Optional[_Path] = None
) -> Callable:
"""Given a qualified class name, return the associated class object"""
root_dir = root_dir or os.getcwd()
idx = qualified_class_name.rfind('.')
class_name = qualified_class_name[idx + 1 :]
pkg_name = qualified_class_name[:idx]

parent_path, path, module_name = _find_module(pkg_name, package_dir, root_dir)
module = _load_spec(_find_spec(module_name, path, parent_path), module_name)
_parent_path, path, module_name = _find_module(pkg_name, package_dir, root_dir)
module = _load_spec(_find_spec(module_name, path), module_name)
return getattr(module, class_name)


def cmdclass(values, package_dir=None, root_dir=None):
def cmdclass(
values: Dict[str, str],
package_dir: Optional[dict] = None,
root_dir: Optional[_Path] = None
) -> Dict[str, Callable]:
"""Given a dictionary mapping command names to strings for qualified class
names, apply :func:`resolve_class` to the dict values.
"""
return {k: resolve_class(v, package_dir, root_dir) for k, v in values.items()}


def find_packages(*, namespaces=False, root_dir=None, **kwargs):
def find_packages(
*, namespaces=False, root_dir: Optional[_Path] = None, **kwargs
) -> List[str]:
"""Works similarly to :func:`setuptools.find_packages`, but with all
arguments given as keyword arguments. Moreover, ``where`` can be given
as a list (the results will be simply concatenated).
Expand All @@ -237,7 +259,7 @@ def find_packages(*, namespaces=False, root_dir=None, **kwargs):
if namespaces:
from setuptools import PEP420PackageFinder as PackageFinder
else:
from setuptools import PackageFinder
from setuptools import PackageFinder # type: ignore

root_dir = root_dir or "."
where = kwargs.pop('where', ['.'])
Expand All @@ -247,18 +269,20 @@ def find_packages(*, namespaces=False, root_dir=None, **kwargs):
return list(chain_iter(PackageFinder.find(x, **kwargs) for x in target))


def _nest_path(parent, path):
def _nest_path(parent: _Path, path: _Path) -> str:
path = parent if path == "." else os.path.join(parent, path)
return os.path.normpath(path)


def version(value):
def version(value: Union[Callable, Iterable[Union[str, int]], str]) -> str:
"""When getting the version directly from an attribute,
it should be normalised to string.
"""
if callable(value):
value = value()

value = cast(Iterable[Union[str, int]], value)

if not isinstance(value, str):
if hasattr(value, '__iter__'):
value = '.'.join(map(str, value))
Expand All @@ -268,13 +292,15 @@ def version(value):
return value


def canonic_package_data(package_data):
def canonic_package_data(package_data: dict) -> dict:
if "*" in package_data:
package_data[""] = package_data.pop("*")
return package_data


def canonic_data_files(data_files, root_dir=None):
def canonic_data_files(
data_files: Union[list, dict], root_dir: Optional[_Path] = None
) -> List[Tuple[str, List[str]]]:
"""For compatibility with ``setup.py``, ``data_files`` should be a list
of pairs instead of a dict.
Expand All @@ -289,14 +315,14 @@ def canonic_data_files(data_files, root_dir=None):
]


def entry_points(text, text_source="entry-points"):
def entry_points(text: str, text_source="entry-points") -> Dict[str, dict]:
"""Given the contents of entry-points file,
process it into a 2-level dictionary (``dict[str, dict[str, str]]``).
The first level keys are entry-point groups, the second level keys are
entry-point names, and the second level values are references to objects
(that correspond to the entry-point value).
"""
parser = ConfigParser(default_section=None, delimiters=("=",))
parser = ConfigParser(default_section=None, delimiters=("=",)) # type: ignore
parser.optionxform = str # case sensitive
parser.read_string(text, text_source)
groups = {k: dict(v.items()) for k, v in parser.items()}
Expand Down
38 changes: 24 additions & 14 deletions setuptools/config/pyprojecttoml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import contextmanager
from distutils import log
from functools import partial
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Callable, Optional, Union

from setuptools.errors import FileError, OptionError

Expand All @@ -18,15 +18,15 @@
_Path = Union[str, os.PathLike]


def load_file(filepath: _Path):
def load_file(filepath: _Path) -> dict:
try:
from setuptools.extern import tomli
from setuptools.extern import tomli # type: ignore
except ImportError: # Bootstrap problem (?) diagnosed by test_distutils_adoption
sys_path = sys.path.copy()
try:
from setuptools import _vendor
sys.path.append(_vendor.__path__[0])
import tomli
sys.path.append(_vendor.__path__[0]) # type: ignore
import tomli # type: ignore
finally:
sys.path = sys_path

Expand Down Expand Up @@ -61,7 +61,7 @@ def apply_configuration(dist: "Distribution", filepath: _Path) -> "Distribution"
return apply(dist, config, filepath)


def read_configuration(filepath, expand=True, ignore_option_errors=False):
def read_configuration(filepath: _Path, expand=True, ignore_option_errors=False):
"""Read given configuration file and returns options from it as a dict.
:param str|unicode filepath: Path to configuration file in the ``pyproject.toml``
Expand Down Expand Up @@ -104,7 +104,9 @@ def read_configuration(filepath, expand=True, ignore_option_errors=False):
return asdict


def expand_configuration(config, root_dir=None, ignore_option_errors=False):
def expand_configuration(
config: dict, root_dir: Optional[_Path] = None, ignore_option_errors=False
) -> dict:
"""Given a configuration with unresolved fields (e.g. dynamic, cmdclass, ...)
find their final values.
Expand Down Expand Up @@ -134,7 +136,9 @@ def expand_configuration(config, root_dir=None, ignore_option_errors=False):
return config


def _expand_all_dynamic(project_cfg, setuptools_cfg, root_dir, ignore_option_errors):
def _expand_all_dynamic(
project_cfg: dict, setuptools_cfg: dict, root_dir: _Path, ignore_option_errors: bool
):
silent = ignore_option_errors
dynamic_cfg = setuptools_cfg.get("dynamic", {})
package_dir = setuptools_cfg.get("package-dir", None)
Expand All @@ -161,7 +165,10 @@ def _expand_all_dynamic(project_cfg, setuptools_cfg, root_dir, ignore_option_err
project_cfg.update(_expand_entry_points(value, dynamic))


def _expand_dynamic(dynamic_cfg, field, package_dir, root_dir, ignore_option_errors):
def _expand_dynamic(
dynamic_cfg: dict, field: str, package_dir: Optional[dict],
root_dir: _Path, ignore_option_errors: bool
):
if field in dynamic_cfg:
directive = dynamic_cfg[field]
if "file" in directive:
Expand All @@ -175,15 +182,15 @@ def _expand_dynamic(dynamic_cfg, field, package_dir, root_dir, ignore_option_err
return None


def _expand_readme(dynamic_cfg, root_dir, ignore_option_errors):
def _expand_readme(dynamic_cfg: dict, root_dir: _Path, ignore_option_errors: bool):
silent = ignore_option_errors
return {
"text": _expand_dynamic(dynamic_cfg, "readme", None, root_dir, silent),
"content-type": dynamic_cfg["readme"].get("content-type", "text/x-rst")
}


def _expand_entry_points(text, dynamic):
def _expand_entry_points(text: str, dynamic: set):
groups = _expand.entry_points(text)
expanded = {"entry-points": groups}
if "scripts" in dynamic and "console_scripts" in groups:
Expand All @@ -193,7 +200,7 @@ def _expand_entry_points(text, dynamic):
return expanded


def _expand_packages(setuptools_cfg, root_dir, ignore_option_errors=False):
def _expand_packages(setuptools_cfg: dict, root_dir: _Path, ignore_option_errors=False):
packages = setuptools_cfg.get("packages")
if packages is None or isinstance(packages, (list, tuple)):
return
Expand All @@ -205,7 +212,10 @@ def _expand_packages(setuptools_cfg, root_dir, ignore_option_errors=False):
setuptools_cfg["packages"] = _expand.find_packages(**find)


def _process_field(container, field, fn, ignore_option_errors=False):
def _process_field(
container: dict, field: str,
fn: Callable, ignore_option_errors=False
):
if field in container:
with _ignore_errors(ignore_option_errors):
container[field] = fn(container[field])
Expand All @@ -217,7 +227,7 @@ def _canonic_package_data(setuptools_cfg, field="package-data"):


@contextmanager
def _ignore_errors(ignore_option_errors):
def _ignore_errors(ignore_option_errors: bool):
if not ignore_option_errors:
yield
return
Expand Down
Loading

0 comments on commit 89a19a4

Please sign in to comment.