Skip to content

Commit

Permalink
Add support for decoding 16bits png (#8524)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Jul 15, 2024
1 parent 6344041 commit 33e47d8
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 79 deletions.
18 changes: 6 additions & 12 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import pytest
import requests
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as F
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -211,16 +210,11 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
img_pil = normalize_dimensions(img_pil)

if img_path.endswith("16.png"):
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)

img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16 bits pngs in uint8
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)
assert img_lpng.dtype == torch.uint16
# PIL converts 16 bits pngs to uint8
img_lpng = F.to_dtype(img_lpng, torch.uint8, scale=True)
else:
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)
Expand Down
30 changes: 27 additions & 3 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,15 +2076,17 @@ def fn(value):
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor

return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

input = make_image(dtype=input_dtype, device=device)

Expand Down Expand Up @@ -2171,6 +2173,28 @@ def test_errors_warnings(self, make_input):
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype

def test_uint16(self):
    """Explicitly check scaled ``F.to_dtype`` conversions to and from uint16.

    The parametrized correctness tests likely cover these paths already;
    uint16 is a newly supported dtype, so the conversions are spelled out
    here one by one as an extra safety net.
    """
    # NOTE(review): the upper bound of ``torch.randint`` is exclusive, so the
    # value 65535 itself is never sampled here -- confirm this is intended.
    source = torch.randint(0, 65535, (256, 512), dtype=torch.uint16)

    # Scaled conversions starting from the uint16 image.
    as_uint8 = F.to_dtype(source, torch.uint8, scale=True)
    as_float32 = F.to_dtype(source, torch.float32, scale=True)
    as_int32 = F.to_dtype(source, torch.int32, scale=True)

    # uint16 -> uint8 keeps the high byte, uint16 -> float divides by the max value.
    assert_equal(as_uint8, (source / 256).to(torch.uint8))
    assert_close(as_float32, (source / 65535))

    # float32 -> uint16 may be off by one unit because of rounding.
    float32_to_uint16 = F.to_dtype(as_float32, torch.uint16, scale=True)
    assert_close(float32_to_uint16, source, rtol=0, atol=1)

    # The natural reference would be (source & 0xFF00), but bitwise-and is not
    # supported for uint16 yet, so it is emulated by scaling down and up again.
    high_byte_only = (source / 256).to(torch.uint16) * 256
    uint8_to_uint16 = F.to_dtype(as_uint8, torch.uint16, scale=True)
    assert_equal(uint8_to_uint16, high_byte_only)

    # int32 -> uint16 must give back the original image exactly.
    int32_to_uint16 = F.to_dtype(as_int32, torch.uint16, scale=True)
    assert_equal(int32_to_uint16, source)

    # Conversions between the derived uint8 and float32 images.
    float32_to_uint8 = F.to_dtype(as_float32, torch.uint8, scale=True)
    assert_equal(float32_to_uint8, as_uint8)
    uint8_to_float32 = F.to_dtype(as_uint8, torch.float32, scale=True)
    assert_close(uint8_to_float32, as_float32, rtol=0, atol=1e-2)


class TestAdjustBrightness:
_CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
Expand Down
3 changes: 1 addition & 2 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ torch::Tensor decode_image(
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
return decode_png(data, mode, apply_exif_orientation);
} else if (
memcmp(gif_signature_1, datap, 6) == 0 ||
memcmp(gif_signature_2, datap, 6) == 0) {
Expand Down
59 changes: 16 additions & 43 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using namespace exif_private;
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
Expand All @@ -26,7 +25,6 @@ bool is_little_endian() {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
// Check that the input tensor dtype is uint8
Expand Down Expand Up @@ -99,12 +97,12 @@ torch::Tensor decode_png(
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
if (bit_depth > 8 && bit_depth != 16) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, err_msg)
TORCH_CHECK(
false,
"bit depth of png image is " + std::to_string(bit_depth) +
". Only <=8 and 16 are supported.")
}

int channels = png_get_channels(png_ptr, info_ptr);
Expand Down Expand Up @@ -199,45 +197,20 @@ torch::Tensor decode_png(
}

auto num_pixels_per_row = width * channels;
auto is_16_bits = bit_depth == 16;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
auto t_ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<uint8_t, 3>().data();
}
} else {
// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
png_set_swap(png_ptr);
}
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();

// We create a tensor instead of malloc-ing for automatic memory management
auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();

for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
// Now we copy the uint16 values into the int32 tensor.
for (size_t j = 0; j < num_pixels_per_row; ++j) {
t_ptr[j] = (int32_t)tmp_buffer[j];
}
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<int32_t, 3>().data();
is_16_bits ? at::kUInt16 : torch::kU8);
if (is_little_endian()) {
png_set_swap(png_ptr);
}
auto t_ptr = (uint8_t*)tensor.data_ptr();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
}
t_ptr = (uint8_t*)tensor.data_ptr();
}

int exif_orientation = -1;
Expand Down
1 change: 0 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_png.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false,
bool apply_exif_orientation = false);

} // namespace image
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace image {
static auto registry =
torch::RegisterOperators()
.op("image::decode_gif", &decode_gif)
.op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor",
.op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from PIL import Image

from ..io.image import _read_png_16
from ..io.image import decode_png, read_file
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

Expand Down Expand Up @@ -481,7 +481,7 @@ def _read_flo(file_name: str) -> np.ndarray:

def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:

flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive
valid_flow_mask = valid_flow_mask.bool()
Expand Down
38 changes: 25 additions & 13 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@ def decode_png(
) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].
The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
Expand All @@ -93,7 +99,7 @@ def decode_png(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_png)
output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
return output


Expand Down Expand Up @@ -144,7 +150,7 @@ def decode_jpeg(
) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
Args:
Expand Down Expand Up @@ -248,8 +254,13 @@ def decode_image(
Detect whether an image is a JPEG, PNG or GIF and performs the appropriate
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].
The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
Expand Down Expand Up @@ -277,8 +288,14 @@ def read_image(
) -> torch.Tensor:
"""
Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].
The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.
Args:
path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
Expand All @@ -298,11 +315,6 @@ def read_image(
return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """Read the file at ``path`` and decode it as a PNG with 16-bit decoding allowed.

    Private helper: the file's raw bytes are passed to the ``image::decode_png``
    operator with its third positional argument set to ``True``.

    Args:
        path: location of the PNG file to read.
        mode: read mode whose ``value`` is forwarded to the operator.

    Returns:
        The tensor produced by ``torch.ops.image.decode_png``.
    """
    raw_bytes = read_file(path)
    # Third positional argument of the operator (the 16-bit opt-in flag).
    allow_16_bits = True
    decoded = torch.ops.image.decode_png(raw_bytes, mode.value, allow_16_bits)
    return decoded


def decode_gif(input: torch.Tensor) -> torch.Tensor:
"""
Decode a GIF image into a 3 or 4 dimensional RGB Tensor.
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def _max_value(dtype: torch.dtype) -> int:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.uint16:
return 65535
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
Expand Down
14 changes: 12 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def _num_value_bits(dtype: torch.dtype) -> int:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.uint16:
return 16
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
Expand Down Expand Up @@ -293,10 +295,18 @@ def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale:
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

# TODO: Remove if/else inner blocks once uint16 dtype supports bitwise shift operations.
shift_by = abs(num_value_bits_input - num_value_bits_output)
if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
if image.dtype == torch.uint16:
return (image / 2 ** (shift_by)).to(dtype)
else:
return image.bitwise_right_shift(shift_by).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
if dtype == torch.uint16:
return image.to(dtype) * 2 ** (shift_by)
else:
return image.to(dtype).bitwise_left_shift_(shift_by)


# We encourage users to use to_dtype() instead but we keep this for BC
Expand Down

0 comments on commit 33e47d8

Please sign in to comment.