Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

requirement, pypi: Add a --require-hashes flag #229

Merged
merged 11 commits into from
Feb 6, 2022
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ python -m pip_audit --help
usage: pip-audit [-h] [-V] [-l] [-r REQUIREMENTS] [-f FORMAT] [-s SERVICE]
[-d] [-S] [--desc [{on,off,auto}]] [--cache-dir CACHE_DIR]
[--progress-spinner {on,off}] [--timeout TIMEOUT]
[--path PATHS] [-v] [--fix]
[--path PATHS] [-v] [--fix] [--require-hashes]
audit the Python environment for dependencies with known vulnerabilities
Expand Down Expand Up @@ -115,6 +115,10 @@ optional arguments:
setting it to `debug` (default: False)
--fix automatically upgrade dependencies with known
vulnerabilities (default: False)
--require-hashes require a hash to check each requirement against, for
repeatable audits; this option is implied when any
package in a requirements file has a `--hash` option.
(default: False)
```
<!-- @end-pip-audit-help@ -->

Expand Down
15 changes: 14 additions & 1 deletion pip_audit/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,12 @@ def _parser() -> argparse.ArgumentParser:
action="store_true",
help="automatically upgrade dependencies with known vulnerabilities",
)
parser.add_argument(
"--require-hashes",
action="store_true",
help="require a hash to check each requirement against, for repeatable audits; this option "
"is implied when any package in a requirements file has a `--hash` option.",
)
return parser


Expand All @@ -262,6 +268,10 @@ def audit() -> None:
output_desc = args.desc.to_bool(args.format)
formatter = args.format.to_format(output_desc)

# The `--require-hashes` flag is only valid with requirements files
if args.require_hashes and args.requirements is None:
parser.error("The --require-hashes flag can only be used with --requirement (-r)")

with ExitStack() as stack:
actors = []
if args.progress_spinner:
Expand All @@ -272,7 +282,10 @@ def audit() -> None:
if args.requirements is not None:
req_files: List[Path] = [Path(req.name) for req in args.requirements]
source = RequirementSource(
req_files, ResolveLibResolver(args.timeout, args.cache_dir, state), state
req_files,
ResolveLibResolver(args.timeout, args.cache_dir, state),
args.require_hashes,
state,
)
else:
source = PipSource(local=args.local, paths=args.paths, state=state)
Expand Down
41 changes: 40 additions & 1 deletion pip_audit/_dependency_source/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

import logging
import os
import re
import shutil
from contextlib import ExitStack
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import IO, Iterator, List, Set, cast
from typing import IO, Iterator, List, Set, Union, cast

from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import Version
from pip_api import parse_requirements
from pip_api._parse_requirements import Requirement as ParsedRequirement
from pip_api._parse_requirements import UnparsedRequirement
from pip_api.exceptions import PipError

from pip_audit._dependency_source import (
Expand All @@ -29,6 +33,8 @@

logger = logging.getLogger(__name__)

PINNED_SPECIFIER_RE = re.compile(r"==(?P<version>.+?)$", re.VERBOSE)


class RequirementSource(DependencySource):
"""
Expand All @@ -39,6 +45,7 @@ def __init__(
self,
filenames: List[Path],
resolver: DependencyResolver,
require_hashes: bool = False,
state: AuditState = AuditState(),
) -> None:
"""
Expand All @@ -52,6 +59,7 @@ def __init__(
"""
self.filenames = filenames
self.resolver = resolver
self.require_hashes = require_hashes
self.state = state

def collect(self) -> Iterator[Dependency]:
Expand All @@ -67,6 +75,19 @@ def collect(self) -> Iterator[Dependency]:
except PipError as pe:
raise RequirementSourceError("requirement parsing raised an error") from pe

# If we're requiring hashes, we skip dependency resolution and check that each
# requirement is accompanied by a hash and is pinned. Files that include hashes must
# explicitly list all transitive dependencies so assuming that the requirements file is
# valid and able to be installed with `-r`, we can skip dependency resolution.
#
# If at least one requirement has a hash, it implies that we require hashes for all
# requirements
if self.require_hashes or any(
isinstance(req, ParsedRequirement) and req.hashes for req in reqs.values()
):
yield from self._collect_hashed_deps(iter(reqs.values()))
continue

# Invoke the dependency resolver to turn requirements into dependencies
req_values: List[Requirement] = [Requirement(str(req)) for req in reqs.values()]
try:
Expand Down Expand Up @@ -153,6 +174,24 @@ def _recover_files(self, tmp_files: List[IO[str]]) -> None:
logger.warning(f"encountered an exception during file recovery: {e}")
continue

def _collect_hashed_deps(
self, reqs: Iterator[Union[ParsedRequirement, UnparsedRequirement]]
) -> Iterator[Dependency]:
for req in reqs:
req = cast(ParsedRequirement, req)
if not req.hashes:
raise RequirementSourceError(
f"requirement {req.name} does not contain a hash: {str(req)}"
)
if req.specifier is not None:
pinned_specifier_info = PINNED_SPECIFIER_RE.match(str(req.specifier))
if pinned_specifier_info is not None:
# Yield a dependency with the hash
pinned_version = pinned_specifier_info.group("version")
yield ResolvedDependency(req.name, Version(pinned_version), req.hashes)
continue
raise RequirementSourceError(f"requirement {req.name} is not pinned: {str(req)}")


class RequirementSourceError(DependencySourceError):
"""A requirements-parsing specific `DependencySourceError`."""
Expand Down
5 changes: 3 additions & 2 deletions pip_audit/_service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, List, Tuple
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Tuple

from packaging.utils import canonicalize_name
from packaging.version import Version
Expand Down Expand Up @@ -54,6 +54,7 @@ class ResolvedDependency(Dependency):
"""

version: Version
hashes: Dict[str, List[str]] = field(default_factory=dict, hash=False)


@dataclass(frozen=True)
Expand Down
24 changes: 24 additions & 0 deletions pip_audit/_service/pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ def query(self, spec: Dependency) -> Tuple[Dependency, List[VulnerabilityResult]
response_json = response.json()
results: List[VulnerabilityResult] = []

# If the dependency has a hash explicitly listed, check it against the PyPI data
if spec.hashes:
releases = response_json["releases"]
release = releases.get(str(spec.version))
if release is None:
raise ServiceError(
"Could not find release to compare hashes: "
f"{spec.canonical_name} ({spec.version})"
)
for hash_type, hash_values in spec.hashes.items():
for hash_value in hash_values:
found = False
for dist in release:
digests = dist["digests"]
pypi_hash = digests.get(hash_type)
if pypi_hash is not None and pypi_hash == hash_value:
found = True
break
if not found:
raise ServiceError(
f"Mismatched hash for {spec.canonical_name} ({spec.version}): listed "
f"{hash_value} of type {hash_type} could not be found in PyPI releases"
)

vulns = response_json.get("vulnerabilities")

# No `vulnerabilities` key means that there are no vulnerabilities for any version
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
platforms="any",
python_requires=">=3.7",
install_requires=[
"pip-api>=0.0.26",
"pip-api>=0.0.27",
"packaging>=21.0.0",
"progress>=1.6",
"resolvelib>=0.8.0",
Expand Down
67 changes: 67 additions & 0 deletions test/dependency_source/test_requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,70 @@ def mock_replace(*_args, **_kwargs):
for (expected_req, req_path) in zip(expected_reqs, req_paths):
with open(req_path, "r") as f:
assert expected_req == f.read().strip()


def test_requirement_source_require_hashes(monkeypatch):
source = requirement.RequirementSource(
[Path("requirements.txt")], ResolveLibResolver(), require_hashes=True
)

monkeypatch.setattr(
_parse_requirements, "_read_file", lambda _: ["flask==2.0.1 --hash=sha256:flask-hash"]
)

# The hash should be populated in the resolved dependency. Additionally, the source should not
# calculate and resolve transitive dependencies since requirements files with hashes must
# explicitly list all dependencies.
specs = list(source.collect())
assert specs == [
ResolvedDependency("flask", Version("2.0.1"), hashes={"sha256": ["flask-hash"]})
]


def test_requirement_source_require_hashes_missing(monkeypatch):
source = requirement.RequirementSource(
[Path("requirements.txt")], ResolveLibResolver(), require_hashes=True
)

monkeypatch.setattr(
_parse_requirements,
"_read_file",
lambda _: ["flask==2.0.1"],
)

# All requirements must be hashed when collecting with `require-hashes`
with pytest.raises(DependencySourceError):
list(source.collect())


def test_requirement_source_require_hashes_inferred(monkeypatch):
source = requirement.RequirementSource([Path("requirements.txt")], ResolveLibResolver())

monkeypatch.setattr(
_parse_requirements,
"_read_file",
lambda _: ["flask==2.0.1 --hash=sha256:flask-hash\nrequests==1.0"],
)

# If at least one requirement is hashed, this infers `require-hashes`
with pytest.raises(DependencySourceError):
list(source.collect())


def test_requirement_source_require_hashes_unpinned(monkeypatch):
source = requirement.RequirementSource(
[Path("requirements.txt")], ResolveLibResolver(), require_hashes=True
)

monkeypatch.setattr(
_parse_requirements,
"_read_file",
lambda _: [
"flask==2.0.1 --hash=sha256:flask-hash\nrequests>=1.0 --hash=sha256:requests-hash"
],
)

# When hashed dependencies are provided, all dependencies must be explicitly pinned to an exact
# version number
with pytest.raises(DependencySourceError):
list(source.collect())
55 changes: 55 additions & 0 deletions test/service/test_pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,58 @@ def test_pypi_skipped_dep(cache_dir):
assert dep in results
vulns = results[dep]
assert len(vulns) == 0


def test_pypi_hashed_dep(cache_dir):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency(
"flask",
Version("2.0.1"),
hashes={"sha256": ["a6209ca15eb63fc9385f38e452704113d679511d9574d09b2cf9183ae7d20dc9"]},
)
results = dict(pypi.query_all(iter([dep])))
assert len(results) == 1
assert dep in results
vulns = results[dep]
assert len(vulns) == 0


def test_pypi_hashed_dep_mismatch(cache_dir):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency(
"flask",
Version("2.0.1"),
hashes={"sha256": ["mismatched-hash"]},
)
with pytest.raises(service.ServiceError):
dict(pypi.query_all(iter([dep])))


def test_pypi_hashed_dep_no_release_data(cache_dir, monkeypatch):
def get_mock_response():
class MockResponse:
def raise_for_status(self):
pass

def json(self):
return {
"releases": {},
"vulnerabilities": [
{
"id": "VULN-0",
"details": "The first vulnerability",
"fixed_in": ["1.1"],
}
],
}

return MockResponse()

monkeypatch.setattr(
service.pypi, "caching_session", lambda _: get_mock_session(get_mock_response)
)

pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("foo", Version("1.0"), hashes={"sha256": ["package-hash"]})
with pytest.raises(service.ServiceError):
dict(pypi.query_all(iter([dep])))