Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requirement: Implement --fix for RequirementSource #225

Merged
merged 16 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ All versions prior to 0.0.9 are untracked.
* CLI: The `--fix` flag has been added, allowing users to attempt to
automatically upgrade any vulnerable dependencies to the first safe version
available ([#212](https://github.com/trailofbits/pip-audit/pull/212),
[#222](https://github.com/trailofbits/pip-audit/pull/222))
[#222](https://github.com/trailofbits/pip-audit/pull/222),
[#225](https://github.com/trailofbits/pip-audit/pull/225))

* CLI: The combination of `--fix` and `--dry-run` is now supported, causing
`pip-audit` to perform the auditing step but not any resulting fix steps
Expand Down
80 changes: 77 additions & 3 deletions pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
Collect dependencies from one or more `requirements.txt`-formatted files.
"""

import logging
import os
import shutil
from contextlib import ExitStack
from pathlib import Path
from typing import Iterator, List, Set, cast
from tempfile import NamedTemporaryFile
from typing import IO, Iterator, List, Set, cast

from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from pip_api import parse_requirements
from pip_api.exceptions import PipError

from pip_audit._dependency_source import (
DependencyFixError,
DependencyResolver,
DependencyResolverError,
DependencySource,
Expand All @@ -20,6 +27,8 @@
from pip_audit._service.interface import ResolvedDependency, SkippedDependency
from pip_audit._state import AuditState

logger = logging.getLogger(__name__)


class RequirementSource(DependencySource):
"""
Expand Down Expand Up @@ -79,14 +88,79 @@ def collect(self) -> Iterator[Dependency]:
except DependencyResolverError as dre:
raise RequirementSourceError("dependency resolver raised an error") from dre

def fix(self, fix_version: ResolvedFixVersion) -> None: # pragma: no cover
def fix(self, fix_version: ResolvedFixVersion) -> None:
"""
Fixes a dependency version for this `RequirementSource`.
"""
raise NotImplementedError
with ExitStack() as stack:
# Make temporary copies of the existing requirements files. If anything goes wrong, we
# want to copy them back into place and undo any partial application of the fix.
tmp_files: List[IO[str]] = [
stack.enter_context(NamedTemporaryFile(mode="w")) for _ in self.filenames
]
for (filename, tmp_file) in zip(self.filenames, tmp_files):
with filename.open("r") as f:
shutil.copyfileobj(f, tmp_file)

try:
# Now fix the files inplace
for filename in self.filenames:
self.state.update_state(
f"Fixing dependency {fix_version.dep.name} ({fix_version.dep.version} => "
f"{fix_version.version})"
)
self._fix_file(filename, fix_version)
except Exception as e:
logger.warning(
f"encountered an exception while applying fixes, recovering original files: {e}"
)
self._recover_files(tmp_files)
raise e

def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None:
# Reparse the requirements file. We want to rewrite each line to the new requirements file
# and only modify the lines that we're fixing.
try:
reqs = parse_requirements(filename=filename)
except PipError as pe:
raise RequirementFixError(f"requirement parsing raised an error: {filename}") from pe

# Convert requirements types from pip-api's vendored types to our own
req_list: List[Requirement] = [Requirement(str(req)) for req in reqs.values()]

# Now write out the new requirements file
with filename.open("w") as f:
for req in req_list:
if (
req.name == fix_version.dep.name
and req.specifier.contains(fix_version.dep.version)
and not req.specifier.contains(fix_version.version)
):
req.specifier = SpecifierSet(f"=={fix_version.version}")
assert req.marker is None or req.marker.evaluate()
f.write(str(req) + os.linesep)

def _recover_files(self, tmp_files: List[IO[str]]) -> None:
for (filename, tmp_file) in zip(self.filenames, tmp_files):
try:
os.replace(tmp_file.name, filename)
# We need to tinker with the internals to prevent the file wrapper from attempting
# to remove the temporary file like in the regular case.
tmp_file._closer.delete = False # type: ignore[attr-defined]
except Exception as e:
# Not much we can do at this point since we're already handling an exception. Just
# log the error and try to recover the rest of the files.
logger.warning(f"encountered an exception during file recovery: {e}")
continue


class RequirementSourceError(DependencySourceError):
"""A requirements-parsing specific `DependencySourceError`."""

pass


class RequirementFixError(DependencyFixError):
"""A requirements-fixing specific `DependencyFixError`."""

pass
16 changes: 16 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import tempfile
from pathlib import Path

import pytest
from packaging.version import Version
Expand Down Expand Up @@ -59,3 +61,17 @@ def cache_dir():
cache = tempfile.TemporaryDirectory()
yield cache.name
cache.cleanup()


@pytest.fixture
def req_file():
req_file = tempfile.NamedTemporaryFile(delete=False)
yield Path(req_file.name)
os.remove(req_file.name)


@pytest.fixture
def other_req_file():
req_file = tempfile.NamedTemporaryFile(delete=False)
yield Path(req_file.name)
os.remove(req_file.name)
woodruffw marked this conversation as resolved.
Show resolved Hide resolved
187 changes: 187 additions & 0 deletions test/dependency_source/test_requirement.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os
from pathlib import Path
from typing import List

import pretend # type: ignore
import pytest
from packaging.requirements import Requirement
from packaging.version import Version
from pip_api import _parse_requirements

from pip_audit._dependency_source import (
DependencyFixError,
DependencyResolver,
DependencyResolverError,
DependencySourceError,
ResolveLibResolver,
requirement,
)
from pip_audit._fix import ResolvedFixVersion
from pip_audit._service import Dependency, ResolvedDependency


Expand Down Expand Up @@ -94,3 +98,186 @@ def test_requirement_source_duplicate_dependencies(monkeypatch):
# If the dependency list has duplicates, then converting to a set will reduce the length of the
# collection
assert len(specs) == len(set(specs))


def _check_fixes(
input_reqs: List[str],
expected_reqs: List[str],
req_paths: List[Path],
fixes: List[ResolvedFixVersion],
) -> None:
# Populate the requirements files
for (input_req, req_path) in zip(input_reqs, req_paths):
with open(req_path, "w") as f:
f.write(input_req)

source = requirement.RequirementSource(req_paths, ResolveLibResolver())
for fix in fixes:
source.fix(fix)

# Check the requirements files
for (expected_req, req_path) in zip(expected_reqs, req_paths):
with open(req_path, "r") as f:
assert expected_req == f.read()


def test_requirement_source_fix(req_file):
_check_fixes(
["flask==0.5\n"],
["flask==1.0\n"],
[req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_multiple_files(req_file, other_req_file):
_check_fixes(
["flask==0.5\n", "requests==1.0\nflask==0.5\n"],
["flask==1.0\n", "requests==1.0\nflask==1.0\n"],
[req_file, other_req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_specifier_match(req_file, other_req_file):
_check_fixes(
["flask<1.0\n", "requests==1.0\nflask<=0.6\n"],
["flask==1.0\n", "requests==1.0\nflask==1.0\n"],
[req_file, other_req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_specifier_no_match(req_file, other_req_file):
# In order to make a fix, the specifier must match the current version and NOT the resolved fix
# version. If the specifier matches both, we don't apply the fix since installing from the given
# requirements file would already install the fixed version.
_check_fixes(
["flask>=0.5\n", "requests==1.0\nflask<2.0\n"],
["flask>=0.5\n", "requests==1.0\nflask<2.0\n"],
[req_file, other_req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_marker(req_file, other_req_file):
woodruffw marked this conversation as resolved.
Show resolved Hide resolved
# `pip-api` automatically filters out requirements with markers that don't apply to the current
# environment
_check_fixes(
[
'flask<1.0; python_version > "2.7"\n',
'requests==1.0\nflask<=0.6; python_version <= "2.7"\n',
],
[
'flask==1.0; python_version > "2.7"\n',
"requests==1.0\n",
],
[req_file, other_req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_comments(req_file, other_req_file):
# `pip-api` automatically filters out comments
_check_fixes(
[
"# comment here\nflask==0.5\n",
"requests==1.0\n# another comment\nflask==0.5",
woodruffw marked this conversation as resolved.
Show resolved Hide resolved
],
["flask==1.0\n", "requests==1.0\nflask==1.0\n"],
[req_file, other_req_file],
[
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
],
)


def test_requirement_source_fix_parse_failure(monkeypatch, req_file, other_req_file):
logger = pretend.stub(warning=pretend.call_recorder(lambda s: None))
monkeypatch.setattr(requirement, "logger", logger)

# If `pip-api` encounters multiple of the same package in the requirements file, it will throw a
# parsing error
input_reqs = ["flask==0.5\n", "flask==0.5\nrequests==1.0\nflask==0.3\n"]
req_paths = [req_file, other_req_file]

# Populate the requirements files
for (input_req, req_path) in zip(input_reqs, req_paths):
with open(req_path, "w") as f:
f.write(input_req)

source = requirement.RequirementSource(req_paths, ResolveLibResolver())
with pytest.raises(DependencyFixError):
source.fix(
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
)
assert len(logger.warning.calls) == 1

# Check that the requirements files remain unchanged
# If we encounter a failure while applying a fix, the fix should be rolled back from all files
for (expected_req, req_path) in zip(input_reqs, req_paths):
with open(req_path, "r") as f:
assert expected_req == f.read()


def test_requirement_source_fix_rollback_failure(monkeypatch, req_file, other_req_file):
logger = pretend.stub(warning=pretend.call_recorder(lambda s: None))
monkeypatch.setattr(requirement, "logger", logger)

# If `pip-api` encounters multiple of the same package in the requirements file, it will throw a
# parsing error
input_reqs = ["flask==0.5\n", "flask==0.5\nrequests==1.0\nflask==0.3\n"]
req_paths = [req_file, other_req_file]

# Populate the requirements files
for (input_req, req_path) in zip(input_reqs, req_paths):
with open(req_path, "w") as f:
f.write(input_req)

# Simulate an error being raised during file recovery
def mock_replace(*_args, **_kwargs):
raise OSError

monkeypatch.setattr(os, "replace", mock_replace)

source = requirement.RequirementSource(req_paths, ResolveLibResolver())
with pytest.raises(DependencyFixError):
source.fix(
ResolvedFixVersion(
dep=ResolvedDependency(name="flask", version=Version("0.5")), version=Version("1.0")
)
)
# One for the parsing error and one for each file that we failed to rollback
assert len(logger.warning.calls) == 3

# We couldn't move the original requirements files back so we should expect a partially applied
# fix. The first requirements file contains the fix, while the second one doesn't since we were
# in the process of writing it out and didn't flush.
expected_reqs = ["flask==1.0\n", "flask==0.5\nrequests==1.0\nflask==0.3\n"]
for (expected_req, req_path) in zip(expected_reqs, req_paths):
with open(req_path, "r") as f:
assert expected_req == f.read()