Skip to content

Commit

Permalink
Add type stubs for common.py funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Sep 25, 2024
1 parent f324274 commit 020a0ce
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 36 deletions.
2 changes: 2 additions & 0 deletions av/filter/filter.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ class Filter:
inputs: tuple[FilterPad, ...]
outputs: tuple[FilterPad, ...]

def __init__(self, name: str) -> None: ...

filters_available: set[str]
56 changes: 37 additions & 19 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import datetime
import errno
import functools
import os
import types
from typing import TYPE_CHECKING
from unittest import TestCase as _Base

import numpy as np
Expand All @@ -16,14 +19,22 @@
except ImportError:
has_pillow = False

if TYPE_CHECKING:
from typing import Any, Callable, TypeVar

from PIL.Image import Image

T = TypeVar("T")


__all__ = ("fate_suite",)


is_windows = os.name == "nt"
skip_tests = frozenset(os.environ.get("PYAV_SKIP_TESTS", "").split(","))


def makedirs(path: str) -> None:
def safe_makedirs(path: str) -> None:
try:
os.makedirs(path)
except OSError as e:
Expand Down Expand Up @@ -61,22 +72,20 @@ def fate_png() -> str:
return fate_suite("png1/55c99e750a5fd6_50314226.png")


def sandboxed(*args, **kwargs) -> str:
do_makedirs = kwargs.pop("makedirs", True)
base = kwargs.pop("sandbox", None)
timed = kwargs.pop("timed", False)
if kwargs:
raise TypeError("extra kwargs: %s" % ", ".join(sorted(kwargs)))
path = os.path.join(_sandbox(timed=timed) if base is None else base, *args)
if do_makedirs:
makedirs(os.path.dirname(path))
def sandboxed(
*args: str, makedirs: bool = True, sandbox: str | None = None, timed: bool = False
) -> str:
path = os.path.join(_sandbox(timed) if sandbox is None else sandbox, *args)
if makedirs:
safe_makedirs(os.path.dirname(path))

return path


# Decorator for running a test in the sandbox directory
def run_in_sandbox(func):
def run_in_sandbox(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def _inner(self, *args, **kwargs):
def _inner(self: Any, *args: Any, **kwargs: Any) -> T:
current_dir = os.getcwd()
try:
os.chdir(self.sandbox)
Expand Down Expand Up @@ -104,13 +113,13 @@ def assertNdarraysEqual(a: np.ndarray, b: np.ndarray) -> None:
assert False, f"ndarrays contents differ\n{msg}"


def assertImagesAlmostEqual(a, b, epsilon=0.1):
def assertImagesAlmostEqual(a: Image, b: Image, epsilon: float = 0.1) -> None:
import PIL.ImageFilter as ImageFilter

assert a.size == b.size
a = a.filter(ImageFilter.BLUR).getdata()
b = b.filter(ImageFilter.BLUR).getdata()
for i, ax, bx in zip(range(len(a)), a, b):
for i, ax, bx in zip(range(len(a)), a, b): # type: ignore
diff = sum(abs(ac / 256 - bc / 256) for ac, bc in zip(ax, bx)) / 3
assert diff < epsilon, f"images differed by {diff} at index {i}; {ax} {bx}"

Expand All @@ -119,14 +128,23 @@ class TestCase(_Base):
@classmethod
def _sandbox(cls, timed: bool = True) -> str:
path = os.path.join(_sandbox(timed=timed), cls.__name__)
makedirs(path)
safe_makedirs(path)
return path

@property
def sandbox(self) -> str:
return self._sandbox(timed=True)

def sandboxed(self, *args, **kwargs) -> str:
kwargs.setdefault("sandbox", self.sandbox)
kwargs.setdefault("timed", True)
return sandboxed(*args, **kwargs)
def sandboxed(
self,
*args: str,
makedirs: bool = True,
timed: bool = True,
sandbox: str | None = None,
) -> str:
if sandbox is None:
return sandboxed(
*args, makedirs=makedirs, timed=timed, sandbox=self.sandbox
)
else:
return sandboxed(*args, makedirs=makedirs, timed=timed, sandbox=sandbox)
15 changes: 8 additions & 7 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ def generate_audio_frame(
return frame


def pull_until_blocked(graph):
frames = []
def pull_until_blocked(graph: Graph) -> list[av.VideoFrame]:
frames: list[av.VideoFrame] = []
while True:
try:
frames.append(graph.pull())
frames.append(graph.vpull())
except av.AVError as e:
if e.errno != errno.EAGAIN:
raise
return frames


class TestFilters(TestCase):
def test_filter_descriptor(self):
def test_filter_descriptor(self) -> None:
f = Filter("testsrc")
assert f.name == "testsrc"
assert f.description == "Generate test pattern."
Expand Down Expand Up @@ -86,24 +86,25 @@ def test_generator_graph(self):
if has_pillow:
frame.to_image().save(self.sandboxed("mandelbrot2.png"))

def test_auto_find_sink(self):
def test_auto_find_sink(self) -> None:
graph = Graph()
src = graph.add("testsrc")
src.link_to(graph.add("buffersink"))
graph.configure()

frame = graph.pull()
frame = graph.vpull()

if has_pillow:
frame.to_image().save(self.sandboxed("mandelbrot3.png"))

def test_delegate_sink(self):
def test_delegate_sink(self) -> None:
graph = Graph()
src = graph.add("testsrc")
src.link_to(graph.add("buffersink"))
graph.configure()

frame = src.pull()
assert isinstance(frame, av.VideoFrame)

if has_pillow:
frame.to_image().save(self.sandboxed("mandelbrot4.png"))
Expand Down
18 changes: 8 additions & 10 deletions tests/test_videoframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ def test_opaque() -> None:
assert type(frame.opaque) is tuple and len(frame.opaque) == 2


def test_invalid_pixel_format():
with pytest.raises(
ValueError, match="not a pixel format: '__unknown_pix_fmt'"
) as cm:
def test_invalid_pixel_format() -> None:
with pytest.raises(ValueError, match="not a pixel format: '__unknown_pix_fmt'"):
VideoFrame(640, 480, "__unknown_pix_fmt")


Expand Down Expand Up @@ -90,7 +88,7 @@ def test_yuv420p_planes() -> None:
assert frame.planes[i].buffer_size == 320 * 240


def test_yuv420p_planes_align():
def test_yuv420p_planes_align() -> None:
# If we request 8-byte alignment for a width which is not a multiple of 8,
# the line sizes are larger than the plane width.
frame = VideoFrame(318, 238, "yuv420p")
Expand All @@ -106,7 +104,7 @@ def test_yuv420p_planes_align():
assert frame.planes[i].buffer_size == 160 * 119


def test_rgb24_planes():
def test_rgb24_planes() -> None:
frame = VideoFrame(640, 480, "rgb24")
assert len(frame.planes) == 1
assert frame.planes[0].width == 640
Expand All @@ -115,7 +113,7 @@ def test_rgb24_planes():
assert frame.planes[0].buffer_size == 640 * 480 * 3


def test_memoryview_read():
def test_memoryview_read() -> None:
frame = VideoFrame(640, 480, "rgb24")
frame.planes[0].update(b"01234" + (b"x" * (640 * 480 * 3 - 5)))
mem = memoryview(frame.planes[0])
Expand All @@ -129,11 +127,11 @@ def test_memoryview_read():


class TestVideoFrameImage(TestCase):
def setUp(self):
def setUp(self) -> None:
if not has_pillow:
pytest.skip()

def test_roundtrip(self):
def test_roundtrip(self) -> None:
import PIL.Image as Image

image = Image.open(fate_png())
Expand All @@ -142,7 +140,7 @@ def test_roundtrip(self):
img.save(self.sandboxed("roundtrip-high.jpg"))
assertImagesAlmostEqual(image, img)

def test_to_image_rgb24(self):
def test_to_image_rgb24(self) -> None:
sizes = [(318, 238), (320, 240), (500, 500)]
for width, height in sizes:
frame = VideoFrame(width, height, format="rgb24")
Expand Down

0 comments on commit 020a0ce

Please sign in to comment.