Skip to content

Commit

Permalink
test, Makefile: Run MyPy across unit tests (#158)
Browse files Browse the repository at this point in the history
* Makefile: Run MyPy on unit tests

* test: Fix MyPy errors

* Makefile: Remove accidental addition of --show-error-codes flag

* test: blacken

* test: fix import

* test: type fixups

Co-authored-by: William Woodruff <[email protected]>
Co-authored-by: William Woodruff <[email protected]>
  • Loading branch information
3 people committed Dec 3, 2021
1 parent e5bcac6 commit 2bcea96
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ lint:
black $(ALL_PY_SRCS) && \
isort $(ALL_PY_SRCS) && \
flake8 $(ALL_PY_SRCS) && \
mypy $(PY_MODULE) && \
mypy $(PY_MODULE) test/ && \
interrogate -c pyproject.toml . && \
git diff --exit-code

Expand Down
2 changes: 1 addition & 1 deletion test/dependency_source/test_pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict

import pip_api
import pretend
import pretend # type: ignore
import pytest
from packaging.version import Version

Expand Down
36 changes: 18 additions & 18 deletions test/dependency_source/test_resolvelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pip_audit._dependency_source import resolvelib
from pip_audit._dependency_source.resolvelib import pypi_provider
from pip_audit._service.interface import Dependency, ResolvedDependency, SkippedDependency
from pip_audit._service.interface import ResolvedDependency, SkippedDependency


def get_package_mock(data):
Expand All @@ -29,7 +29,7 @@ def get_metadata_mock():
return EmailMessage()


def check_deps(resolved_deps: List[Dependency], expected_deps: List[Dependency]):
def check_deps(resolved_deps: List[ResolvedDependency], expected_deps: List[ResolvedDependency]):
# We don't want to just check that the two lists are equal because:
# - Some packages install additional dependencies for specific versions of Python. It's only
# practical to check that the resolved dependencies contain all the expected ones (but there
Expand All @@ -50,7 +50,7 @@ def check_deps(resolved_deps: List[Dependency], expected_deps: List[Dependency])
def test_resolvelib():
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
expected_deps = [
ResolvedDependency("flask", Version("2.0.1")),
Expand All @@ -63,15 +63,15 @@ def test_resolvelib():
assert req in resolved_deps
# Earlier Python versions have some extra dependencies. To avoid conditionals here, let's just
# check that the dependencies we specify are a subset.
check_deps(resolved_deps[req], expected_deps)
check_deps(resolved_deps[req], expected_deps) # type: ignore


def test_resolvelib_extras():
resolver = resolvelib.ResolveLibResolver()

# First check the dependencies without extras and as a basis for comparison
req = Requirement("requests>=2.8.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
expected_deps = [
ResolvedDependency("requests", Version("2.26.0")),
Expand All @@ -81,11 +81,11 @@ def test_resolvelib_extras():
ResolvedDependency("urllib3", Version("1.26.7")),
]
assert req in resolved_deps
check_deps(resolved_deps[req], expected_deps)
check_deps(resolved_deps[req], expected_deps) # type: ignore

# Check that using the `socks` and `use_chardet_on_py3` extras pulls in additional dependencies
req = Requirement("requests[socks,use_chardet_on_py3]>=2.8.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
expected_deps.extend(
[
Expand All @@ -94,13 +94,13 @@ def test_resolvelib_extras():
]
)
assert req in resolved_deps
check_deps(resolved_deps[req], expected_deps)
check_deps(resolved_deps[req], expected_deps) # type: ignore


def test_resolvelib_sdist():
resolver = resolvelib.ResolveLibResolver()
req = Requirement("ansible-core==2.11.5")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
expected_deps = [
ResolvedDependency("ansible-core", Version("2.11.5")),
Expand All @@ -115,7 +115,7 @@ def test_resolvelib_sdist():
ResolvedDependency("markupsafe", Version("2.0.1")),
]
assert req in resolved_deps
check_deps(resolved_deps[req], expected_deps)
check_deps(resolved_deps[req], expected_deps) # type: ignore


def test_resolvelib_wheel_patched(monkeypatch):
Expand All @@ -136,7 +136,7 @@ def test_resolvelib_wheel_patched(monkeypatch):

resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert req in resolved_deps
assert resolved_deps[req] == [ResolvedDependency("flask", Version("2.0.1"))]

Expand All @@ -156,7 +156,7 @@ def test_resolvelib_sdist_patched(monkeypatch, suffix):

resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert req in resolved_deps
assert resolved_deps[req] == [ResolvedDependency("flask", Version("2.0.1"))]

Expand All @@ -177,7 +177,7 @@ def test_resolvelib_wheel_python_version(monkeypatch):
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all([req]))
dict(resolver.resolve_all(iter([req])))


def test_resolvelib_wheel_canonical_name_mismatch(monkeypatch):
Expand All @@ -198,7 +198,7 @@ def test_resolvelib_wheel_canonical_name_mismatch(monkeypatch):
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
with pytest.raises(InconsistentCandidate):
dict(resolver.resolve_all([req]))
dict(resolver.resolve_all(iter([req])))


def test_resolvelib_wheel_invalid_version(monkeypatch):
Expand All @@ -219,7 +219,7 @@ def test_resolvelib_wheel_invalid_version(monkeypatch):
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all([req]))
dict(resolver.resolve_all(iter([req])))


def test_resolvelib_sdist_invalid_suffix(monkeypatch):
Expand All @@ -234,7 +234,7 @@ def test_resolvelib_sdist_invalid_suffix(monkeypatch):
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
with pytest.raises(ResolutionImpossible):
dict(resolver.resolve_all([req]))
dict(resolver.resolve_all(iter([req])))


def test_resolvelib_http_error(monkeypatch):
Expand All @@ -253,7 +253,7 @@ def raise_for_status(self):
resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
with pytest.raises(resolvelib.ResolveLibResolverError):
dict(resolver.resolve_all([req]))
dict(resolver.resolve_all(iter([req])))


def test_resolvelib_http_notfound(monkeypatch):
Expand All @@ -268,7 +268,7 @@ def __init__(self):

resolver = resolvelib.ResolveLibResolver()
req = Requirement("flask==2.0.1")
resolved_deps = dict(resolver.resolve_all([req]))
resolved_deps = dict(resolver.resolve_all(iter([req])))
assert len(resolved_deps) == 1
expected_deps = [
SkippedDependency(name="flask", skip_reason='Could not find project "flask" on PyPI')
Expand Down
6 changes: 3 additions & 3 deletions test/format/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pip_audit._service as service

_TEST_VULN_DATA: Dict[service.Dependency, List[service.VulnerabilityResult]] = {
service.ResolvedDependency(name="foo", version="1.0"): [
service.ResolvedDependency(name="foo", version=Version("1.0")): [
service.VulnerabilityResult(
id="VULN-0",
description="The first vulnerability",
Expand All @@ -21,7 +21,7 @@
fix_versions=[Version("1.0")],
),
],
service.ResolvedDependency(name="bar", version="0.1"): [
service.ResolvedDependency(name="bar", version=Version("0.1")): [
service.VulnerabilityResult(
id="VULN-2",
description="The third vulnerability",
Expand All @@ -31,7 +31,7 @@
}

_TEST_VULN_DATA_SKIPPED_DEP: Dict[service.Dependency, List[service.VulnerabilityResult]] = {
service.ResolvedDependency(name="foo", version="1.0"): [
service.ResolvedDependency(name="foo", version=Version("1.0")): [
service.VulnerabilityResult(
id="VULN-0",
description="The first vulnerability",
Expand Down
14 changes: 7 additions & 7 deletions test/service/test_osv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_osv(self):
osv = service.OsvService()
dep = service.ResolvedDependency("jinja2", Version("2.4.1"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
osv.query_all([dep])
osv.query_all(iter([dep]))
)
self.assertEqual(len(results), 1)
self.assertTrue(dep in results)
Expand All @@ -25,7 +25,7 @@ def test_osv_uses_canonical_package_name(self):
# that our adapter is canonicalizing any dependencies passed into it.
osv = service.OsvService()
dep = service.ResolvedDependency("PyYAML", Version("5.3"))
results: List[service.VulnerabilityResult] = osv.query(dep)
_, results = osv.query(dep)

self.assertGreater(len(results), 0)

Expand All @@ -35,7 +35,7 @@ def test_osv_version_ranges(self):
osv = service.OsvService()
dep = service.ResolvedDependency("ansible", Version("2.8.0"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
osv.query_all([dep])
osv.query_all(iter([dep]))
)
self.assertEqual(len(results), 1)
self.assertTrue(dep in results)
Expand All @@ -49,7 +49,7 @@ def test_osv_multiple_pkg(self):
service.ResolvedDependency("flask", Version("0.5")),
]
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
osv.query_all(deps)
osv.query_all(iter(deps))
)
self.assertEqual(len(results), 2)
self.assertTrue(deps[0] in results and deps[1] in results)
Expand All @@ -60,7 +60,7 @@ def test_osv_no_vuln(self):
osv = service.OsvService()
dep = service.ResolvedDependency("foo", Version("1.0.0"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
osv.query_all([dep])
osv.query_all(iter([dep]))
)
self.assertEqual(len(results), 1)
self.assertTrue(dep in results)
Expand All @@ -78,13 +78,13 @@ def raise_for_status(self):
def test_osv_error_response(self, mock_post):
osv = service.OsvService()
dep = service.ResolvedDependency("jinja2", Version("2.4.1"))
self.assertRaises(service.ServiceError, lambda: dict(osv.query_all([dep])))
self.assertRaises(service.ServiceError, lambda: dict(osv.query_all(iter([dep]))))

def test_osv_skipped_dep(self):
osv = service.OsvService()
dep = service.SkippedDependency(name="foo", skip_reason="skip-reason")
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
osv.query_all([dep])
osv.query_all(iter([dep]))
)
self.assertEqual(len(results), 1)
self.assertTrue(dep in results)
Expand Down
23 changes: 12 additions & 11 deletions test/service/test_pypi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Dict, List

import pretend
import pretend # type: ignore
import pytest
import requests
from packaging.version import Version
Expand All @@ -24,7 +25,7 @@ def test_pypi(cache_dir):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("jinja2", Version("2.4.1"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all([dep])
pypi.query_all(iter([dep]))
)
assert len(results) == 1
assert dep in results
Expand All @@ -39,7 +40,7 @@ def test_pypi_multiple_pkg(cache_dir):
service.ResolvedDependency("flask", Version("0.5")),
]
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all(deps)
pypi.query_all(iter(deps))
)
assert len(results) == 2
assert deps[0] in results and deps[1] in results
Expand Down Expand Up @@ -70,7 +71,7 @@ def raise_for_status(self):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("jinja2", Version("2.4.1"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all([dep])
pypi.query_all(iter([dep]))
)
assert len(results) == 1
skipped_dep = service.SkippedDependency(
Expand Down Expand Up @@ -102,7 +103,7 @@ def raise_for_status(self):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("jinja2", Version("2.4.1"))
with pytest.raises(service.ServiceError):
dict(pypi.query_all([dep]))
dict(pypi.query_all(iter([dep])))


def test_pypi_mocked_response(monkeypatch, cache_dir):
Expand Down Expand Up @@ -131,7 +132,7 @@ def json(self):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("foo", Version("1.0"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all([dep])
pypi.query_all(iter([dep]))
)
assert len(results) == 1
assert dep in results
Expand Down Expand Up @@ -161,7 +162,7 @@ def json(self):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("foo", Version("1.0"))
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all([dep])
pypi.query_all(iter([dep]))
)
assert len(results) == 1
assert dep in results
Expand Down Expand Up @@ -194,7 +195,7 @@ def json(self):
pypi = service.PyPIService(cache_dir)
dep = service.ResolvedDependency("foo", Version("1.0"))
with pytest.raises(service.ServiceError):
dict(pypi.query_all([dep]))
dict(pypi.query_all(iter([dep])))


def test_pypi_warns_about_old_pip(monkeypatch, cache_dir):
Expand All @@ -215,7 +216,7 @@ def test_pypi_warns_about_old_pip(monkeypatch, cache_dir):

def test_pypi_cache_dir(monkeypatch):
# When we supply a cache directory, always use that
cache_dir: str = _get_cache_dir("/tmp/foo/cache_dir")
cache_dir: str = _get_cache_dir(Path("/tmp/foo/cache_dir"))
assert cache_dir == "/tmp/foo/cache_dir"

def run_mock(args, **kwargs):
Expand All @@ -237,7 +238,7 @@ def test_pypi_cache_dir_old_pip(monkeypatch):
monkeypatch.setattr(service.pypi, "_PIP_VERSION", Version("1.0.0"))

# When we supply a cache directory, always use that
cache_dir: str = _get_cache_dir("/tmp/foo/cache_dir")
cache_dir: str = _get_cache_dir(Path("/tmp/foo/cache_dir"))
assert cache_dir == "/tmp/foo/cache_dir"

# Mock out home since this is going to be different across systems
Expand All @@ -258,7 +259,7 @@ def test_pypi_skipped_dep(cache_dir):
pypi = service.PyPIService(cache_dir)
dep = service.SkippedDependency(name="foo", skip_reason="skip-reason")
results: Dict[service.Dependency, List[service.VulnerabilityResult]] = dict(
pypi.query_all([dep])
pypi.query_all(iter([dep]))
)
assert len(results) == 1
assert dep in results
Expand Down
2 changes: 1 addition & 1 deletion test/test_audit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pretend
import pretend # type: ignore
import pytest
from packaging.version import Version

Expand Down

0 comments on commit 2bcea96

Please sign in to comment.