Skip to content

Commit

Permalink
pip_audit, test: perform even more PyPI caching (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
woodruffw authored Dec 7, 2021
1 parent 0557f49 commit 74d3d7c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 27 deletions.
4 changes: 3 additions & 1 deletion pip_audit/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ def audit() -> None:
source: DependencySource
if args.requirements is not None:
req_files: List[Path] = [Path(req.name) for req in args.requirements]
source = RequirementSource(req_files, ResolveLibResolver(args.timeout, state), state)
source = RequirementSource(
req_files, ResolveLibResolver(args.timeout, args.cache_dir, state), state
)
else:
source = PipSource(local=args.local, paths=args.paths)

Expand Down
46 changes: 31 additions & 15 deletions pip_audit/_dependency_source/resolvelib/pypi_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

import html5lib
import requests
from cachecontrol import CacheControl # type: ignore
from packaging.requirements import Requirement
from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name, parse_sdist_filename, parse_wheel_filename
from packaging.version import Version
from resolvelib.providers import AbstractProvider

from pip_audit._cache import caching_session
from pip_audit._state import AuditState
from pip_audit._util import python_version
from pip_audit._virtual_env import VirtualEnv
Expand All @@ -48,6 +50,7 @@ def __init__(
url: str,
extras: Set[str],
is_wheel: bool,
session: CacheControl,
timeout: Optional[int] = None,
state: AuditState = AuditState(),
) -> None:
Expand All @@ -61,8 +64,9 @@ def __init__(
self.url = url
self.extras = extras
self.is_wheel = is_wheel
self.timeout = timeout
self.state = state
self._session = session
self._timeout = timeout
self._state = state

self._metadata: Optional[EmailMessage] = None
self._dependencies: Optional[List[Requirement]] = None
Expand All @@ -82,7 +86,7 @@ def metadata(self) -> EmailMessage:
"""

if self._metadata is None:
self.state.update_state(f"Fetching metadata for {self.name} ({self.version})")
self._state.update_state(f"Fetching metadata for {self.name} ({self.version})")

if self.is_wheel:
self._metadata = self._get_metadata_for_wheel()
Expand Down Expand Up @@ -119,9 +123,9 @@ def _get_metadata_for_wheel(self):
"""
Extracts the metadata for this candidate, if it's a wheel.
"""
data = requests.get(self.url, timeout=self.timeout).content
data = self._session.get(self.url, timeout=self._timeout).content

self.state.update_state(f"Extracting wheel for {self.name} ({self.version})")
self._state.update_state(f"Extracting wheel for {self.name} ({self.version})")

with ZipFile(BytesIO(data)) as z:
for n in z.namelist():
Expand All @@ -139,7 +143,7 @@ def _get_metadata_for_sdist(self):
Extracts the metadata for this candidate, if it's a source distribution.
"""

response: requests.Response = requests.get(self.url, timeout=self.timeout)
response: requests.Response = self._session.get(self.url, timeout=self._timeout)
response.raise_for_status()
sdist_data = response.content
metadata = EmailMessage()
Expand All @@ -148,16 +152,16 @@ def _get_metadata_for_sdist(self):
sdist = Path(pkg_dir) / self.filename.name
sdist.write_bytes(sdist_data)

self.state.update_state(
self._state.update_state(
f"Installing source distribution in isolated environment for {self.name} "
f"({self.version})"
)

with TemporaryDirectory() as ve_dir:
ve = VirtualEnv([str(sdist)], self.state)
ve = VirtualEnv([str(sdist)], self._state)
ve.create(ve_dir)

self.state.update_state(
self._state.update_state(
f"Querying installed packages for {self.name} ({self.version})"
)

Expand All @@ -169,11 +173,11 @@ def _get_metadata_for_sdist(self):


def get_project_from_pypi(
project, extras, timeout: Optional[int], state: AuditState
session, project, extras, timeout: Optional[int], state: AuditState
) -> Iterator[Candidate]:
"""Return candidates created from the project name and extras."""
url = "https://pypi.org/simple/{}".format(project)
response: requests.Response = requests.get(url, timeout=timeout)
response: requests.Response = session.get(url, timeout=timeout)
if response.status_code == 404:
raise PyPINotFoundError(f'Could not find project "{project}" on PyPI')
response.raise_for_status()
Expand Down Expand Up @@ -213,6 +217,7 @@ def get_project_from_pypi(
is_wheel=is_wheel,
timeout=timeout,
state=state,
session=session,
)
except Exception:
continue
Expand All @@ -224,16 +229,25 @@ class PyPIProvider(AbstractProvider):
the official Python Package Index.
"""

def __init__(self, timeout: Optional[int] = None, state: AuditState = AuditState()):
def __init__(
self,
timeout: Optional[int] = None,
cache_dir: Optional[Path] = None,
state: AuditState = AuditState(),
):
"""
Create a new `PyPIProvider`.
`timeout` is an optional argument to control how many seconds the component should wait for
responses to network requests.
`cache_dir` is an optional argument to override the default HTTP caching directory.
`state` is an `AuditState` to use for state callbacks.
"""
self.timeout = timeout
self.state = state
self.session = caching_session(cache_dir, use_pip=True)
self._state = state

def identify(self, requirement_or_candidate):
"""
Expand All @@ -251,7 +265,7 @@ def find_matches(self, identifier, requirements, incompatibilities):
"""
See `resolvelib.providers.AbstractProvider.find_matches`.
"""
self.state.update_state(f"Resolving {identifier}")
self._state.update_state(f"Resolving {identifier}")

requirements = list(requirements[identifier])

Expand All @@ -268,7 +282,9 @@ def find_matches(self, identifier, requirements, incompatibilities):
candidates = sorted(
[
candidate
for candidate in get_project_from_pypi(identifier, extras, self.timeout, self.state)
for candidate in get_project_from_pypi(
self.session, identifier, extras, self.timeout, self._state
)
if candidate.version not in bad_versions
and all(candidate.version in r.specifier for r in requirements)
],
Expand Down
13 changes: 11 additions & 2 deletions pip_audit/_dependency_source/resolvelib/resolvelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import logging
from pathlib import Path
from typing import List, Optional

from packaging.requirements import Requirement
Expand All @@ -25,13 +26,21 @@ class ResolveLibResolver(DependencyResolver):
backend dependency resolution strategy.
"""

def __init__(self, timeout: Optional[int] = None, state: AuditState = AuditState()) -> None:
def __init__(
self,
timeout: Optional[int] = None,
cache_dir: Optional[Path] = None,
state: AuditState = AuditState(),
) -> None:
"""
Create a new `ResolveLibResolver`.
`timeout` and `cache_dir` are optional arguments for HTTP timeouts
and caching, respectively.
`state` is an `AuditState` to use for state callbacks.
"""
self.provider = PyPIProvider(timeout, state)
self.provider = PyPIProvider(timeout, cache_dir, state)
self.reporter = BaseReporter()
self.resolver: Resolver = Resolver(self.provider, self.reporter)

Expand Down
42 changes: 33 additions & 9 deletions test/dependency_source/test_resolvelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,16 @@ def test_resolvelib_wheel_patched(monkeypatch):
"Flask-2.0.1-py3-none-any.whl</a><br/>"
)

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
# monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
monkeypatch.setattr(
pypi_provider.Candidate, "_get_metadata_for_wheel", lambda _: get_metadata_mock()
)

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert req in resolved_deps
Expand All @@ -149,12 +153,15 @@ def test_resolvelib_sdist_patched(monkeypatch, suffix):
# everything works end-to-end.
data = f'<a href="https://example.com/Flask-2.0.1.{suffix}">Flask-2.0.1.{suffix}</a><br/>'

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
monkeypatch.setattr(
pypi_provider.Candidate, "_get_metadata_for_sdist", lambda _: get_metadata_mock()
)

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert req in resolved_deps
Expand All @@ -172,9 +179,11 @@ def test_resolvelib_wheel_python_version(monkeypatch):
'data-requires-python="&lt;=2.7">Flask-2.0.1-py3-none-any.whl</a><br/>'
)

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all(iter([req])))
Expand All @@ -190,12 +199,15 @@ def test_resolvelib_wheel_canonical_name_mismatch(monkeypatch):
"Mask-2.0.1-py3-none-any.whl</a><br/>"
)

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
monkeypatch.setattr(
pypi_provider.Candidate, "_get_metadata_for_wheel", lambda _: get_metadata_mock()
)

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
with pytest.raises(InconsistentCandidate):
dict(resolver.resolve_all(iter([req])))
Expand All @@ -211,12 +223,15 @@ def test_resolvelib_wheel_invalid_version(monkeypatch):
"Flask-INVALID.VERSION-py3-none-any.whl</a><br/>"
)

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
monkeypatch.setattr(
pypi_provider.Candidate, "_get_metadata_for_wheel", lambda _: get_metadata_mock()
)

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all(iter([req])))
Expand All @@ -226,12 +241,15 @@ def test_resolvelib_sdist_invalid_suffix(monkeypatch):
# Give the sdist an invalid suffix like ".foo" and ensure that it gets skipped.
data = '<a href="https://example.com/Flask-2.0.1.foo">Flask-2.0.1.foo</a><br/>'

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_package_mock(data))
monkeypatch.setattr(
pypi_provider.Candidate, "_get_metadata_for_wheel", lambda _: get_metadata_mock()
)

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_package_mock(data)
)

req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all(iter([req])))
Expand All @@ -251,6 +269,10 @@ def raise_for_status(self):
monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_http_error_mock())

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_http_error_mock()
)

req = Requirement("flask==2.0.1")
with pytest.raises(resolvelib.ResolveLibResolverError):
dict(resolver.resolve_all(iter([req])))
Expand All @@ -264,9 +286,11 @@ def __init__(self):

return Doc()

monkeypatch.setattr(requests, "get", lambda _url, **kwargs: get_http_not_found_mock())

resolver = resolvelib.ResolveLibResolver()
monkeypatch.setattr(
resolver.provider.session, "get", lambda _url, **kwargs: get_http_not_found_mock()
)

req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
Expand Down

0 comments on commit 74d3d7c

Please sign in to comment.