diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 54ed18394cd..4bb18cf6b48 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -350,6 +350,7 @@ Color
     v2.RGB
     v2.RandomGrayscale
     v2.GaussianBlur
+    v2.GaussianNoise
     v2.RandomInvert
     v2.RandomPosterize
     v2.RandomSolarize
@@ -368,6 +369,7 @@ Functionals
     v2.functional.grayscale_to_rgb
     v2.functional.to_grayscale
     v2.functional.gaussian_blur
+    v2.functional.gaussian_noise
     v2.functional.invert
     v2.functional.posterize
     v2.functional.solarize
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index b0c1659f253..8a47a589508 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -111,8 +111,10 @@ def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
 
     input = input.as_subclass(torch.Tensor)
     with ignore_jit_no_profile_information_warning():
-        actual = kernel_scripted(input, *args, **kwargs)
-    expected = kernel(input, *args, **kwargs)
+        with freeze_rng_state():
+            actual = kernel_scripted(input, *args, **kwargs)
+    with freeze_rng_state():
+        expected = kernel(input, *args, **kwargs)
 
     assert_close(actual, expected, rtol=rtol, atol=atol)
 
@@ -3238,6 +3240,78 @@ def test_functional_image_correctness(self, dimensions, kernel_size, sigma, dtyp
         torch.testing.assert_close(actual, expected, rtol=0, atol=1)
 
 
+class TestGaussianNoise:
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_kernel(self, make_input):
+        check_kernel(
+            F.gaussian_noise,
+            make_input(dtype=torch.float32),
+            # This cannot pass because the noise on a batch is not per-image
+            check_batched_vs_unbatched=False,
+        )
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_functional(self, make_input):
+        check_functional(F.gaussian_noise, make_input(dtype=torch.float32))
+
+    @pytest.mark.parametrize(
+        ("kernel", "input_type"),
+        [
+            (F.gaussian_noise_image, torch.Tensor),
+            (F.gaussian_noise_image, tv_tensors.Image),
+            (F.gaussian_noise_video, tv_tensors.Video),
+        ],
+    )
+    def test_functional_signature(self, kernel, input_type):
+        check_functional_kernel_signature_match(F.gaussian_noise, kernel=kernel, input_type=input_type)
+
+    @pytest.mark.parametrize(
+        "make_input",
+        [make_image_tensor, make_image, make_video],
+    )
+    def test_transform(self, make_input):
+        def adapter(_, input, __):
+            # This transform doesn't support uint8 so we have to convert the auto-generated uint8 tensors to float32
+            # Same for PIL images
+            for key, value in input.items():
+                if isinstance(value, torch.Tensor) and not value.is_floating_point():
+                    input[key] = value.to(torch.float32)
+                if isinstance(value, PIL.Image.Image):
+                    input[key] = F.pil_to_tensor(value).to(torch.float32)
+            return input
+
+        check_transform(transforms.GaussianNoise(), make_input(dtype=torch.float32), check_sample_input=adapter)
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="Gaussian Noise is not implemented for PIL images."):
+            F.gaussian_noise(make_image_pil())
+        with pytest.raises(ValueError, match="Input tensor is expected to be in float dtype"):
+            F.gaussian_noise(make_image(dtype=torch.uint8))
+        with pytest.raises(ValueError, match="sigma shouldn't be negative"):
+            F.gaussian_noise(make_image(dtype=torch.float32), sigma=-1)
+
+    def test_clip(self):
+        img = make_image(dtype=torch.float32)
+
+        out = F.gaussian_noise(img, mean=100, clip=False)
+        assert out.min() > 50
+
+        out = F.gaussian_noise(img, mean=100, clip=True)
+        assert (out == 1).all()
+
+        out = F.gaussian_noise(img, mean=-100, clip=False)
+        assert out.min() < -50
+
+        out = F.gaussian_noise(img, mean=-100, clip=True)
+        assert (out == 0).all()
+
+
 class TestAutoAugmentTransforms:
     # These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
     # It's typically very hard to test the effect on some parameters without heavy mocking logic.
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 6dccb8a5b78..33d83f1fe3f 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -45,6 +45,7 @@
 from ._misc import (
     ConvertImageDtype,
     GaussianBlur,
+    GaussianNoise,
     Identity,
     Lambda,
     LinearTransformation,
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index ad2c08150cc..6d62539ccd7 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -205,6 +205,33 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)
 
 
+class GaussianNoise(Transform):
+    """Add gaussian noise to images or videos.
+
+    The input tensor is expected to be in [..., 1 or 3, H, W] format,
+    where ... means it can have an arbitrary number of leading dimensions.
+    Each image or frame in a batch will be transformed independently i.e. the
+    noise added to each image will be different.
+
+    The input tensor is also expected to be of float dtype in ``[0, 1]``.
+    This transform does not support PIL images.
+
+    Args:
+        mean (float): Mean of the sampled normal distribution. Default is 0.
+        sigma (float): Standard deviation of the sampled normal distribution. Default is 0.1.
+        clip (bool, optional): Whether to clip the values in ``[0, 1]`` after adding noise. Default is True.
+    """
+
+    def __init__(self, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> None:
+        super().__init__()
+        self.mean = mean
+        self.sigma = sigma
+        self.clip = clip
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return self._call_kernel(F.gaussian_noise, inpt, mean=self.mean, sigma=self.sigma, clip=self.clip)
+
+
 class ToDtype(Transform):
     """Converts the input to a specific dtype, optionally scaling the values for images or videos.
 
diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py
index 4d4bbf2e86d..d5705d55c4b 100644
--- a/torchvision/transforms/v2/functional/__init__.py
+++ b/torchvision/transforms/v2/functional/__init__.py
@@ -136,6 +136,9 @@
     gaussian_blur,
     gaussian_blur_image,
     gaussian_blur_video,
+    gaussian_noise,
+    gaussian_noise_image,
+    gaussian_noise_video,
     normalize,
     normalize_image,
     normalize_video,
diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index 12d064f6638..84b686d50f9 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -181,6 +181,45 @@ def gaussian_blur_video(
     return gaussian_blur_image(video, kernel_size, sigma)
 
 
+def gaussian_noise(inpt: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    """See :class:`~torchvision.transforms.v2.GaussianNoise`"""
+    if torch.jit.is_scripting():
+        # Forward ``clip`` too, so a scripted call honours ``clip=False`` exactly like the eager path below.
+        return gaussian_noise_image(inpt, mean=mean, sigma=sigma, clip=clip)
+
+    _log_api_usage_once(gaussian_noise)
+
+    kernel = _get_kernel(gaussian_noise, type(inpt))
+    return kernel(inpt, mean=mean, sigma=sigma, clip=clip)
+
+
+@_register_kernel_internal(gaussian_noise, torch.Tensor)
+@_register_kernel_internal(gaussian_noise, tv_tensors.Image)
+def gaussian_noise_image(image: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    if not image.is_floating_point():
+        raise ValueError(f"Input tensor is expected to be in float dtype, got dtype={image.dtype}")
+    if sigma < 0:
+        raise ValueError(f"sigma shouldn't be negative. Got {sigma}")
+
+    noise = mean + torch.randn_like(image) * sigma
+    out = image + noise
+    if clip:
+        out = torch.clamp(out, 0, 1)
+    return out
+
+
+@_register_kernel_internal(gaussian_noise, tv_tensors.Video)
+def gaussian_noise_video(video: torch.Tensor, mean: float = 0.0, sigma: float = 0.1, clip: bool = True) -> torch.Tensor:
+    return gaussian_noise_image(video, mean=mean, sigma=sigma, clip=clip)
+
+
+@_register_kernel_internal(gaussian_noise, PIL.Image.Image)
+def _gaussian_noise_pil(
+    image: PIL.Image.Image, mean: float = 0.0, sigma: float = 0.1, clip: bool = True
+) -> PIL.Image.Image:
+    raise ValueError("Gaussian Noise is not implemented for PIL images.")
+
+
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     """See :func:`~torchvision.transforms.v2.ToDtype` for details."""
     if torch.jit.is_scripting():