Skip to content

Commit

Permalink
Allow decode_image to support paths (#8624)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Sep 3, 2024
1 parent c36025a commit d0ebeb5
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 59 deletions.
8 changes: 7 additions & 1 deletion docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
:toctree: generated/
:template: function.rst

read_image
decode_image
encode_jpeg
decode_jpeg
Expand All @@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.

ImageReadMode

Obsolete decoding function:

.. autosummary::
:toctree: generated/
:template: function.rst

read_image


Video
Expand Down
16 changes: 8 additions & 8 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ Here is an example of how to use the pre-trained image classification models:

.. code:: python
from torchvision.io import read_image
from torchvision.io import decode_image
from torchvision.models import resnet50, ResNet50_Weights
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
Expand Down Expand Up @@ -283,10 +283,10 @@ Here is an example of how to use the pre-trained quantized image classification

.. code:: python
from torchvision.io import read_image
from torchvision.io import decode_image
from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = ResNet50_QuantizedWeights.DEFAULT
Expand Down Expand Up @@ -339,11 +339,11 @@ Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python
from torchvision.io.image import read_image
from torchvision.io.image import decode_image
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
from torchvision.transforms.functional import to_pil_image
img = read_image("gallery/assets/dog1.jpg")
img = decode_image("gallery/assets/dog1.jpg")
# Step 1: Initialize model with the best available weights
weights = FCN_ResNet50_Weights.DEFAULT
Expand Down Expand Up @@ -411,12 +411,12 @@ Here is an example of how to use the pre-trained object detection models:
.. code:: python
from torchvision.io.image import read_image
from torchvision.io.image import decode_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
# Step 1: Initialize model with the best available weights
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
Expand Down
10 changes: 5 additions & 5 deletions gallery/others/plot_repurposing_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def show(imgs):
# We will take images and masks from the `PenFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.


from torchvision.io import read_image
from torchvision.io import decode_image

img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
img = read_image(img_path)
mask = read_image(mask_path)
img = decode_image(img_path)
mask = decode_image(mask_path)


# %%
Expand Down Expand Up @@ -181,8 +181,8 @@ def __getitem__(self, idx):
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

img = read_image(img_path)
mask = read_image(mask_path)
img = decode_image(img_path)
mask = decode_image(mask_path)

img = F.convert_image_dtype(img, dtype=torch.float)
mask = F.convert_image_dtype(mask, dtype=torch.float)
Expand Down
6 changes: 3 additions & 3 deletions gallery/others/plot_scripted_tensor_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image
from torchvision.io import decode_image

plt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)
Expand All @@ -39,8 +39,8 @@
# :class:`torch.nn.Sequential` instead of
# :class:`~torchvision.transforms.v2.Compose`:

dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))

transforms = torch.nn.Sequential(
v1.RandomCrop(224),
Expand Down
10 changes: 5 additions & 5 deletions gallery/others/plot_visualization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def show(imgs):
# image of dtype ``uint8`` as input.

from torchvision.utils import make_grid
from torchvision.io import read_image
from torchvision.io import decode_image
from pathlib import Path

dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
dog1_int = decode_image(str(Path('../assets') / 'dog1.jpg'))
dog2_int = decode_image(str(Path('../assets') / 'dog2.jpg'))
dog_list = [dog1_int, dog2_int]

grid = make_grid(dog_list)
Expand Down Expand Up @@ -362,9 +362,9 @@ def show(imgs):
#

from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.io import read_image
from torchvision.io import decode_image

person_int = read_image(str(Path("../assets") / "person1.jpg"))
person_int = decode_image(str(Path("../assets") / "person1.jpg"))

weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()
Expand Down
4 changes: 2 additions & 2 deletions gallery/transforms/plot_transforms_getting_started.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import read_image
from torchvision.io import decode_image

torch.manual_seed(1)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = read_image(str(Path('../assets') / 'astronaut.jpg'))
img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

# %%
Expand Down
10 changes: 5 additions & 5 deletions test/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
import torchvision
from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
from torchvision.models import resnet50, ResNet50_Weights


Expand All @@ -21,13 +21,13 @@ def smoke_test_torchvision() -> None:


def smoke_test_torchvision_read_decode() -> None:
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
if img_jpg.shape != (3, 606, 517):
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
if img_png.shape != (4, 471, 354):
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
if img_webp.shape != (3, 100, 100):
raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")

Expand All @@ -54,7 +54,7 @@ def smoke_test_compile() -> None:


def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
Expand Down
21 changes: 21 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted):
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_image_path(input_type, scripted):
    """Check that ``decode_image`` accepts a ``Path``, a ``str`` path, or a tensor.

    The test only verifies that the call completes without raising; the
    decoded output is not inspected.
    """
    path = next(get_images(IMAGE_ROOT, ".jpg"))

    # One lazy builder per accepted input flavour, so that only the requested
    # one is evaluated (e.g. the file is read only for the "tensor" case).
    builders = {
        "Path": lambda: Path(path),
        "str": lambda: path,
        "tensor": lambda: read_file(path),
    }
    if input_type not in builders:
        raise ValueError("Oops")
    image_source = builders[input_type]()

    # The Path + scripted combination is reported as an expected failure
    # up front instead of being executed.
    if input_type == "Path" and scripted:
        pytest.xfail(reason="Can't pass a Path when scripting")

    decoder = decode_image if not scripted else torch.jit.script(decode_image)
    decoder(image_source)


if __name__ == "__main__":
pytest.main([__file__])
40 changes: 10 additions & 30 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):


def decode_image(
input: torch.Tensor,
input: Union[torch.Tensor, str],
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the
appropriate operation to decode the image into a Tensor.
"""Decode an image into a tensor.
Currently supported image formats are jpeg, png, gif and webp.
The values of the output tensor are in uint8 in [0, 255] for most cases.
Expand All @@ -295,8 +295,9 @@ def decode_image(
tensor.
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
image.
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
tensor is passed, it must be a one-dimensional uint8 tensor containing
the raw bytes of the image. Otherwise, this must be a path to the image file.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
Expand All @@ -309,6 +310,8 @@ def decode_image(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_image)
if not isinstance(input, torch.Tensor):
input = read_file(str(input))
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output

Expand All @@ -318,30 +321,7 @@ def read_image(
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Reads a JPEG, PNG, WEBP, or GIF image into a Tensor.
The values of the output tensor are in uint8 in [0, 255] for most cases.
If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.
Args:
path (str or ``pathlib.Path``): path of the image.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes. Only applies to JPEG and PNG images.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Only applies to JPEG and PNG images. Default: False.
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
"""[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_image)
data = read_file(path)
Expand Down

0 comments on commit d0ebeb5

Please sign in to comment.