Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pyproject.toml #246

Merged
merged 16 commits into from
Mar 4, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion pip_audit/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PYPI_URL,
DependencySource,
PipSource,
PyProjectSource,
RequirementSource,
ResolveLibResolver,
)
Expand Down Expand Up @@ -159,6 +160,9 @@ def _parser() -> argparse.ArgumentParser:
dest="requirements",
help="audit the given requirements file; this option can be used multiple times",
)
dep_source_args.add_argument(
"project_path", type=Path, nargs="?", help="audit a local Python project at the given path"
)
parser.add_argument(
"-f",
"--format",
Expand Down Expand Up @@ -269,6 +273,15 @@ def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
return parser.parse_args()


def _create_project_source(project_path: Path, state: AuditState) -> Optional[DependencySource]:
# Check for a `pyproject.toml`
pyproject_path = project_path.joinpath("pyproject.toml")
if pyproject_path.is_file():
return PyProjectSource(pyproject_path, ResolveLibResolver(), state)
else:
return None


def audit() -> None:
"""
The primary entrypoint for `pip-audit`.
Expand Down Expand Up @@ -301,15 +314,21 @@ def audit() -> None:
state = stack.enter_context(AuditState(members=actors))

source: DependencySource
index_urls = [args.index_url] + args.extra_index_urls
if args.requirements is not None:
index_urls = [args.index_url] + args.extra_index_urls
req_files: List[Path] = [Path(req.name) for req in args.requirements]
source = RequirementSource(
req_files,
ResolveLibResolver(index_urls, args.timeout, args.cache_dir, state),
args.require_hashes,
state,
)
elif args.project_path is not None:
# Determine which kind of project file exists in the project path
_source = _create_project_source(args.project_path, state)
if _source is None:
parser.error(f"couldn't find project file at path: {args.project_path}")
source = _source
woodruffw marked this conversation as resolved.
Show resolved Hide resolved
else:
source = PipSource(local=args.local, paths=args.paths, state=state)

Expand Down
2 changes: 2 additions & 0 deletions pip_audit/_dependency_source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DependencySourceError,
)
from .pip import PipSource, PipSourceError
from .pyproject import PyProjectSource
from .requirement import RequirementSource
from .resolvelib import PYPI_URL, ResolveLibResolver

Expand All @@ -22,6 +23,7 @@
"DependencySourceError",
"PipSource",
"PipSourceError",
"PyProjectSource",
"RequirementSource",
"ResolveLibResolver",
]
147 changes: 147 additions & 0 deletions pip_audit/_dependency_source/pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
Collect dependencies from `pyproject.toml` files.
"""

import logging
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Iterator, List, Set, cast

import toml
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet

from pip_audit._dependency_source import (
DependencyFixError,
DependencyResolver,
DependencyResolverError,
DependencySource,
DependencySourceError,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState

logger = logging.getLogger(__name__)


class PyProjectSource(DependencySource):
"""
Wraps `pyproject.toml` dependency resolution as a dependency source.
"""

def __init__(
self, filename: Path, resolver: DependencyResolver, state: AuditState = AuditState()
) -> None:
"""
Create a new `PyProjectSource`.

`filename` provides a path to a `pyproject.toml` file

`resolver` is the `DependencyResolver` to use.

`state` is an `AuditState` to use for state callbacks.
"""
self.filename = filename
self.resolver = resolver
self.state = state

def collect(self) -> Iterator[Dependency]:
"""
Collect all of the dependencies discovered by this `PyProjectSource`.

Raises a `PyProjectSourceError` on any errors.
"""

collected: Set[Dependency] = set()
with open(self.filename, "r") as f:
pyproject_data = toml.load(f)
if "project" not in pyproject_data:
raise PyProjectSourceError(
f"pyproject file {self.filename} does not contain `project` section"
)
project = pyproject_data["project"]
if "dependencies" not in project:
# Projects without dependencies aren't an error case
logger.warning(
f"pyproject file {self.filename} does not contain `dependencies` list"
)
return
deps = project["dependencies"]
reqs: List[Requirement] = [Requirement(dep) for dep in deps]
try:
for _, deps in self.resolver.resolve_all(iter(reqs)):
for dep in deps:
# Don't allow duplicate dependencies to be returned
if dep in collected:
continue

if dep.is_skipped(): # pragma: no cover
dep = cast(SkippedDependency, dep)
self.state.update_state(f"Skipping {dep.name}: {dep.skip_reason}")
else:
dep = cast(ResolvedDependency, dep)
self.state.update_state(f"Collecting {dep.name} ({dep.version})")

collected.add(dep)
yield dep
except DependencyResolverError as dre:
raise PyProjectSourceError("dependency resolver raised an error") from dre

def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version for this `PyProjectSource`.
"""

with open(self.filename, "r+") as f, NamedTemporaryFile(mode="r+") as tmp:
pyproject_data = toml.load(f)
if "project" not in pyproject_data:
raise PyProjectFixError(
f"pyproject file {self.filename} does not contain `project` section"
)
project = pyproject_data["project"]
if "dependencies" not in project:
# Projects without dependencies aren't an error case
logger.warning(
f"pyproject file {self.filename} does not contain `dependencies` list"
)
return

deps = project["dependencies"]
reqs: List[Requirement] = [Requirement(dep) for dep in deps]
for i in range(len(reqs)):
# When we find a requirement that matches the provided fix version, we need to edit
# the requirement's specifier and then write it back to the underlying TOML data.
req = reqs[i]
if (
req.name == fix_version.dep.name
and req.specifier.contains(fix_version.dep.version)
and not req.specifier.contains(fix_version.version)
):
req.specifier = SpecifierSet(f"=={fix_version.version}")
deps[i] = str(req)
assert req.marker is None or req.marker.evaluate()

# Now dump the new edited TOML to the temporary file.
toml.dump(pyproject_data, tmp)

# And replace the original `pyproject.toml` file.
os.replace(tmp.name, self.filename)

# Stop the file wrapper from attempting to cleanup if we've successfully moved the
# temporary file into place.
tmp._closer.delete = False


class PyProjectSourceError(DependencySourceError):
"""A `pyproject.toml` specific `DependencySourceError`."""

pass


class PyProjectFixError(DependencyFixError):
"""A `pyproject.toml` specific `DependencyFixError`."""

pass
110 changes: 110 additions & 0 deletions test/dependency_source/test_pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from pathlib import Path
from typing import List

import pretend # type: ignore
import pytest
from packaging.requirements import Requirement
from packaging.version import Version

from pip_audit._dependency_source import (
DependencyResolver,
DependencyResolverError,
DependencySourceError,
ResolveLibResolver,
pyproject,
)
from pip_audit._service import Dependency, ResolvedDependency


def _init_pyproject(filename: Path, contents: str) -> Path:
with open(filename, mode="w") as f:
f.write(contents)
return filename


@pytest.mark.online
def test_pyproject_source(req_file):
filename = _init_pyproject(
req_file(),
"""
[project]
dependencies = [
"flask==2.0.1"
]
""",
)
source = pyproject.PyProjectSource(filename, ResolveLibResolver())
specs = list(source.collect())
assert ResolvedDependency("flask", Version("2.0.1")) in specs


def test_pyproject_source_no_project_section(req_file):
filename = _init_pyproject(
req_file(),
"""
[some_other_section]
dependencies = [
"flask==2.0.1"
]
""",
)
source = pyproject.PyProjectSource(filename, ResolveLibResolver())
with pytest.raises(DependencySourceError):
list(source.collect())


def test_pyproject_source_no_deps(monkeypatch, req_file):
logger = pretend.stub(warning=pretend.call_recorder(lambda s: None))
monkeypatch.setattr(pyproject, "logger", logger)

filename = _init_pyproject(
req_file(),
"""
[project]
""",
)
source = pyproject.PyProjectSource(filename, ResolveLibResolver())
specs = list(source.collect())
assert not specs

# We log a warning when we find a `pyproject.toml` file with no dependencies
assert len(logger.warning.calls) == 1


def test_pyproject_source_duplicate_deps(req_file):
# Click is a dependency of Flask. We should check that the dependencies of Click aren't returned
# twice.
filename = _init_pyproject(
req_file(),
"""
[project]
dependencies = [
"flask",
"click",
]
""",
)
source = pyproject.PyProjectSource(filename, ResolveLibResolver())
specs = list(source.collect())

# Check that the list of dependencies is already deduplicated
assert len(specs) == len(set(specs))


def test_pyproject_source_resolver_error(monkeypatch, req_file):
class MockResolver(DependencyResolver):
def resolve(self, req: Requirement) -> List[Dependency]:
raise DependencyResolverError

filename = _init_pyproject(
req_file(),
"""
[project]
dependencies = [
"flask==2.0.1"
]
""",
)
source = pyproject.PyProjectSource(filename, MockResolver())
with pytest.raises(DependencySourceError):
list(source.collect())