From 44db71c0772e5ef5758c38d0e4e8ad9995946c80 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:14:49 -0800 Subject: [PATCH 01/18] implement additional cvcuda infra for all branches to avoid duplicate setup --- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 3 ++- .../transforms/v2/functional/__init__.py | 2 +- .../transforms/v2/functional/_augment.py | 11 ++++++++++- .../transforms/v2/functional/_color.py | 12 +++++++++++- .../transforms/v2/functional/_geometry.py | 19 +++++++++++++++++-- torchvision/transforms/v2/functional/_misc.py | 11 +++++++++-- .../transforms/v2/functional/_utils.py | 16 ++++++++++++++++ 8 files changed, 68 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..bec9ffcf714 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel +from .functional._utils import _get_kernel, is_cvcuda_tensor class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index bb6051b4e61..765a772fe41 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,7 +15,7 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, + is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..52181e4624b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a904d8d7cbd..7ce5bdc7b7e 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,4 +1,5 @@ import io +from typing import TYPE_CHECKING import PIL.Image @@ -8,7 +9,15 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index be254c0d63a..5be9c62902a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import PIL.Image import torch from torch.nn.functional import conv2d @@ -9,7 +11,15 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _register_kernel_internal +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..c029488001c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,22 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf263df046..0fa05a2113c 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -1,5 +1,5 @@ import math -from typing import Optional +from typing import Optional, TYPE_CHECKING import PIL.Image import torch @@ -13,7 +13,14 @@ from ._meta import _convert_bounding_box_format -from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor +from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor + +CVCUDA_AVAILABLE = _is_cvcuda_available() + +if TYPE_CHECKING: + import cvcuda # type: ignore[import-not-found] +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() # noqa: F811 def normalize( diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index ad1eddd258b..73fafaf7425 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -169,3 +169,19 @@ def _is_cvcuda_available(): return True except ImportError: return False + + +def is_cvcuda_tensor(inpt: Any) -> bool: + """ + Check if the input is a CVCUDA tensor. + + Args: + inpt: The input to check. + + Returns: + True if the input is a CV-CUDA tensor, False otherwise. + """ + if _is_cvcuda_available(): + cvcuda = _import_cvcuda() + return isinstance(inpt, cvcuda.Tensor) + return False From e3dd70022fa1c87aca7a9a98068b6e13e802a375 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 25 Nov 2025 09:26:19 -0800 Subject: [PATCH 02/18] update make_image_cvcuda to have default batch dim --- test/common_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..e7bae60c41b 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): From c035df1c6eaebcad25604f8c298a7d9eaf86864b Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:16:27 -0800 Subject: [PATCH 03/18] add stanardized setup to main for easier updating of PRs and branches --- test/common_utils.py | 21 ++++++++++++++-- test/test_transforms_v2.py | 2 +- torchvision/transforms/v2/_utils.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 24 +++++++++++++++++-- 4 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e7bae60c41b..3b889e93d2e 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -20,13 +20,15 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import io, tv_tensors from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image +from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available from torchvision.utils import _Image_fromarray IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -275,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -287,6 +300,11 @@ def __init__( if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): actual, expected = (to_image(input) for input in [actual, expected]) + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -401,7 +419,6 @@ def make_image_pil(*args, **kwargs): def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): - # explicitly default batch_dims to (1,) since to_cvcuda_tensor requires a batch dimension (ndims == 4) return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..7eba65550da 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -21,6 +21,7 @@ import torchvision.transforms.v2 as transforms from common_utils import ( + assert_close, assert_equal, cache, cpu_and_cuda, @@ -41,7 +42,6 @@ ) from torch import nn -from torch.testing import assert_close from torch.utils._pytree import tree_flatten, tree_map from torch.utils.data import DataLoader, default_collate from torchvision import tv_tensors diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 765a772fe41..3fc33ce5964 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 6b8f19f12f4..ee562cb2aee 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) +def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: + # CV-CUDA tensor is always in NHWC layout + # get_dimensions is CHW + return [image.shape[3], image.shape[1], image.shape[2]] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + + def get_num_channels(inpt: torch.Tensor) -> int: if torch.jit.is_scripting(): return get_num_channels_image(inpt) @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels +def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: + # CV-CUDA tensor is always in NHWC layout + # get_num_channels is C + return image.shape[3] + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + + def get_size(inpt: torch.Tensor) -> list[int]: if torch.jit.is_scripting(): return get_size_image(inpt) @@ -114,7 +134,7 @@ def _get_size_image_pil(image: PIL.Image.Image) -> list[int]: return [height, width] -def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def _get_size_cvcuda(image: "cvcuda.Tensor") -> list[int]: """Get size of `cvcuda.Tensor` with NHWC layout.""" hw = list(image.shape[-3:-1]) ndims = len(hw) @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: if CVCUDA_AVAILABLE: - _get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda) + _register_kernel_internal(get_size, cvcuda.Tensor)(_get_size_cvcuda) @_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False) From 98d7dfb2059eaf2c10c3f549ea45f1d27875134c Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:25:09 -0800 Subject: [PATCH 04/18] update is_cvcuda_tensor --- torchvision/transforms/v2/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 73fafaf7425..44b2edeaf2d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -181,7 +181,8 @@ def is_cvcuda_tensor(inpt: Any) -> bool: Returns: True if the input is a CV-CUDA tensor, False otherwise. """ - if _is_cvcuda_available(): + try: cvcuda = _import_cvcuda() return isinstance(inpt, cvcuda.Tensor) - return False + except ImportError: + return False From ddc116d13febdae1d53507bcde9f103a4c14eba7 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:37:03 -0800 Subject: [PATCH 05/18] add cvcuda to pil compatible to transforms by default --- test/test_transforms_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7eba65550da..87166477669 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,6 +25,7 @@ assert_equal, cache, cpu_and_cuda, + cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, From e51dc7eabd254261347245f4492892fd0944aae5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Tue, 2 Dec 2025 12:46:23 -0800 Subject: [PATCH 06/18] remove cvcuda from transform class --- torchvision/transforms/v2/_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index bec9ffcf714..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,7 +11,7 @@ from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once -from .functional._utils import _get_kernel, is_cvcuda_tensor +from .functional._utils import _get_kernel class Transform(nn.Module): @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() From 4939355a2c7421eeba95d7f155fe7953066aec6d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:07:08 -0800 Subject: [PATCH 07/18] resolve more formatting naming --- torchvision/transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 52181e4624b..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index e8630f788ca..af03ad018d4 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -51,14 +51,14 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]: return get_dimensions_image(video) -def _get_dimensions_cvcuda(image: "cvcuda.Tensor") -> list[int]: +def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]: # CV-CUDA tensor is always in NHWC layout # get_dimensions is CHW return [image.shape[3], image.shape[1], image.shape[2]] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_dimensions, cvcuda.Tensor)(_get_dimensions_cvcuda) + _register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda) def get_num_channels(inpt: torch.Tensor) -> int: @@ -97,14 +97,14 @@ def get_num_channels_video(video: torch.Tensor) -> int: get_image_num_channels = get_num_channels -def _get_num_channels_cvcuda(image: "cvcuda.Tensor") -> int: +def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int: # CV-CUDA tensor is always in NHWC layout # get_num_channels is C return image.shape[3] if CVCUDA_AVAILABLE: - _register_kernel_internal(get_num_channels, cvcuda.Tensor)(_get_num_channels_cvcuda) + _register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda) def get_size(inpt: torch.Tensor) -> list[int]: From 1e864d86fbd8da71b554203373e738ff7c112bb4 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 17 Nov 2025 12:57:26 -0800 Subject: [PATCH 08/18] initial cvcuda normalize kernel implementation --- test/test_transforms_v2.py | 73 +++++++++++++++++++ .../transforms/v2/functional/__init__.py | 1 + torchvision/transforms/v2/functional/_misc.py | 35 +++++++++ 3 files changed, 109 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f767e211125..cee8333ee42 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5633,6 +5633,79 @@ def test_correctness_image(self, mean, std, dtype, fn): assert_equal(actual, expected) +@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") +@needs_cuda +class TestNormalizeCVCUDA: + MEANS_STDS = { + "RGB": TestNormalize.MEANS_STDS, + "GRAY": [([0.5], [2.0])], + } + MEAN_STD = { + "RGB": MEANS_STDS["RGB"][0], + "GRAY": MEANS_STDS["GRAY"][0], + } + + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + def test_functional(self, color_space, batch_dims, dtype): + means_stds = self.MEANS_STDS[color_space] + for mean, std in means_stds: + image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims) + check_functional(F.normalize, image, mean=mean, std=std) + + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + def test_functional_scalar(self, color_space, batch_dims, dtype): + image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims) + check_functional(F.normalize, image, mean=0.5, std=2.0) + + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) + @pytest.mark.parametrize("batch_dims", [(1,)]) + def test_functional_error(self, dtype, batch_dims): + rgb_mean, rgb_std = self.MEAN_STD["RGB"] + gray_mean, gray_std = self.MEAN_STD["GRAY"] + + with pytest.raises(ValueError, match="Inplace normalization is not supported for CVCUDA."): + F.normalize(make_image_cvcuda(batch_dims=batch_dims, dtype=dtype), mean=rgb_mean, std=rgb_std, inplace=True) + + with pytest.raises(ValueError, match="Mean should have 3 elements. Got 1."): + F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=gray_mean, std=rgb_std) + + with pytest.raises(ValueError, match="Std should have 3 elements. Got 1."): + F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=rgb_mean, std=gray_std) + + with pytest.raises(ValueError, match="Mean should have 1 elements. Got 3."): + F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=rgb_mean, std=gray_std) + + with pytest.raises(ValueError, match="Std should have 1 elements. Got 3."): + F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=gray_mean, std=rgb_std) + + @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) + @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + def test_transform(self, dtype, color_space, batch_dims): + means_stds = self.MEANS_STDS[color_space] + for mean, std in means_stds: + check_transform( + transforms.Normalize(mean=mean, std=std), + make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims), + ) + + @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) + def test_correctness_image(self, batch_dims): + mean, std = self.MEAN_STD["RGB"] + torch_image = make_image(batch_dims=batch_dims, dtype=torch.float32, device="cuda") + cvc_image = F.to_cvcuda_tensor(torch_image) + + gold = F.normalize(torch_image, mean=mean, std=std) + image = F.normalize(cvc_image, mean=mean, std=std) + image = F.cvcuda_to_tensor(image) + + assert_close(image, gold, rtol=1e-7, atol=1e-7) + + class TestClampBoundingBoxes: @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat)) @pytest.mark.parametrize("clamping_mode", ("soft", "hard", None)) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 032a993b1f0..617c87aeaf9 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -153,6 +153,7 @@ gaussian_noise_image, gaussian_noise_video, normalize, + normalize_cvcuda, normalize_image, normalize_video, sanitize_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 0fa05a2113c..e288105f469 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -79,6 +79,41 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in return normalize_image(video, mean, std, inplace=inplace) +def normalize_cvcuda( + image: "cvcuda.Tensor", + mean: Sequence[float | int] | float | int, + std: Sequence[float | int] | float | int, + inplace: bool = False, +) -> "cvcuda.Tensor": + if inplace: + raise ValueError("Inplace normalization is not supported for CVCUDA.") + + channels = image.shape[3] + if isinstance(mean, float | int): + mean = [mean] * channels + elif len(mean) != channels: + raise ValueError(f"Mean should have {channels} elements. Got {len(mean)}.") + if isinstance(std, float | int): + std = [std] * channels + elif len(std) != channels: + raise ValueError(f"Std should have {channels} elements. Got {len(std)}.") + + mean = torch.as_tensor(mean, dtype=torch.float32) + std = torch.as_tensor(std, dtype=torch.float32) + mean_tensor = mean.reshape(1, 1, 1, channels) + std_tensor = std.reshape(1, 1, 1, channels) + mean_tensor = mean_tensor.cuda() + std_tensor = std_tensor.cuda() + mean_cv = cvcuda.as_tensor(mean_tensor, cvcuda.TensorLayout.NHWC) + std_cv = cvcuda.as_tensor(std_tensor, cvcuda.TensorLayout.NHWC) + + return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV) + + +if CVCUDA_AVAILABLE: + _normalize_cvcuda = _register_kernel_internal(normalize, cvcuda.Tensor)(normalize_cvcuda) + + def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor: """See :class:`~torchvision.transforms.v2.GaussianBlur` for details.""" if torch.jit.is_scripting(): From 01efae7c0d928f176b59e6cb8819a3d1ce95107d Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 17 Nov 2025 13:13:02 -0800 Subject: [PATCH 09/18] add comment explaining mean/std behavior, one-line intermediate creation --- torchvision/transforms/v2/functional/_misc.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index e288105f469..60cf0e20026 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -98,14 +98,18 @@ def normalize_cvcuda( elif len(std) != channels: raise ValueError(f"Std should have {channels} elements. Got {len(std)}.") - mean = torch.as_tensor(mean, dtype=torch.float32) - std = torch.as_tensor(std, dtype=torch.float32) - mean_tensor = mean.reshape(1, 1, 1, channels) - std_tensor = std.reshape(1, 1, 1, channels) - mean_tensor = mean_tensor.cuda() - std_tensor = std_tensor.cuda() - mean_cv = cvcuda.as_tensor(mean_tensor, cvcuda.TensorLayout.NHWC) - std_cv = cvcuda.as_tensor(std_tensor, cvcuda.TensorLayout.NHWC) + # CV-CUDA requires float32 tensors for the mean/std parameters + # at small batchs, this is costly relative to normalize operation + # if CV-CUDA is known to be a backend, could optimize this + # For Normalize class: + # by creating tensors at class initialization time + # For functional API: + # by storing cached tensors in helper function with functools.lru_cache (would it even be worth it?) + # Since CV-CUDA is 1) not default backend, 2) only strictly faster at large batch size, ignore + mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda() + st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda() + mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC) + std_cv = cvcuda.as_tensor(st, cvcuda.TensorLayout.NHWC) return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV) From 79ea0da680fd0bbd71d33f4b2ab6eb0d0fc37dc3 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 20 Nov 2025 11:23:03 -0800 Subject: [PATCH 10/18] fix: normalize_cvcuda move to correct patterns for tests/exporting --- test/test_transforms_v2.py | 111 ++++++------------ .../transforms/v2/functional/__init__.py | 1 - torchvision/transforms/v2/functional/_misc.py | 15 ++- 3 files changed, 48 insertions(+), 79 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index cee8333ee42..0c0bfe079c5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5570,7 +5570,17 @@ def test_kernel_image_inplace(self, device): def test_kernel_video(self): check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD) - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) def test_functional(self, make_input): check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD) @@ -5580,6 +5590,11 @@ def test_functional(self, make_input): (F.normalize_image, torch.Tensor), (F.normalize_image, tv_tensors.Image), (F.normalize_video, tv_tensors.Video), + pytest.param( + F._misc._normalize_cvcuda, + _import_cvcuda().Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), + ), ], ) def test_functional_signature(self, kernel, input_type): @@ -5608,7 +5623,17 @@ def _sample_input_adapter(self, transform, input, device): adapted_input[key] = value return adapted_input - @pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image, + make_video, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) def test_transform(self, make_input): check_transform( transforms.Normalize(mean=self.MEAN, std=self.STD), @@ -5632,78 +5657,16 @@ def test_correctness_image(self, mean, std, dtype, fn): assert_equal(actual, expected) - -@pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") -@needs_cuda -class TestNormalizeCVCUDA: - MEANS_STDS = { - "RGB": TestNormalize.MEANS_STDS, - "GRAY": [([0.5], [2.0])], - } - MEAN_STD = { - "RGB": MEANS_STDS["RGB"][0], - "GRAY": MEANS_STDS["GRAY"][0], - } - - @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) - @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) - @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) - def test_functional(self, color_space, batch_dims, dtype): - means_stds = self.MEANS_STDS[color_space] - for mean, std in means_stds: - image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims) - check_functional(F.normalize, image, mean=mean, std=std) - - @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) - @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) - @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) - def test_functional_scalar(self, color_space, batch_dims, dtype): - image = make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims) - check_functional(F.normalize, image, mean=0.5, std=2.0) - - @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) - @pytest.mark.parametrize("batch_dims", [(1,)]) - def test_functional_error(self, dtype, batch_dims): - rgb_mean, rgb_std = self.MEAN_STD["RGB"] - gray_mean, gray_std = self.MEAN_STD["GRAY"] - - with pytest.raises(ValueError, match="Inplace normalization is not supported for CVCUDA."): - F.normalize(make_image_cvcuda(batch_dims=batch_dims, dtype=dtype), mean=rgb_mean, std=rgb_std, inplace=True) - - with pytest.raises(ValueError, match="Mean should have 3 elements. Got 1."): - F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=gray_mean, std=rgb_std) - - with pytest.raises(ValueError, match="Std should have 3 elements. Got 1."): - F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="RGB", dtype=dtype), mean=rgb_mean, std=gray_std) - - with pytest.raises(ValueError, match="Mean should have 1 elements. Got 3."): - F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=rgb_mean, std=gray_std) - - with pytest.raises(ValueError, match="Std should have 1 elements. Got 3."): - F.normalize(make_image_cvcuda(batch_dims=batch_dims, color_space="GRAY", dtype=dtype), mean=gray_mean, std=rgb_std) - - @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32]) - @pytest.mark.parametrize("color_space", ["RGB", "GRAY"]) - @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) - def test_transform(self, dtype, color_space, batch_dims): - means_stds = self.MEANS_STDS[color_space] - for mean, std in means_stds: - check_transform( - transforms.Normalize(mean=mean, std=std), - make_image_cvcuda(color_space=color_space, dtype=dtype, batch_dims=batch_dims), - ) - - @pytest.mark.parametrize("batch_dims", [(1,), (2,), (4,)]) - def test_correctness_image(self, batch_dims): - mean, std = self.MEAN_STD["RGB"] - torch_image = make_image(batch_dims=batch_dims, dtype=torch.float32, device="cuda") - cvc_image = F.to_cvcuda_tensor(torch_image) - - gold = F.normalize(torch_image, mean=mean, std=std) - image = F.normalize(cvc_image, mean=mean, std=std) - image = F.cvcuda_to_tensor(image) - - assert_close(image, gold, rtol=1e-7, atol=1e-7) + @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + @pytest.mark.parametrize(("mean", "std"), MEANS_STDS) + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)]) + def test_correctness_cvcuda(self, mean, std, dtype, fn): + image = make_image(batch_dims=(1,), dtype=dtype, device="cuda") + cvc_image = F.to_cvcuda_tensor(image) + actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std) + expected = fn(image, mean=mean, std=std) + torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7) class TestClampBoundingBoxes: diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 617c87aeaf9..032a993b1f0 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -153,7 +153,6 @@ gaussian_noise_image, gaussian_noise_video, normalize, - normalize_cvcuda, normalize_image, normalize_video, sanitize_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 60cf0e20026..8dcca5c3d49 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -79,15 +79,22 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in return normalize_image(video, mean, std, inplace=inplace) -def normalize_cvcuda( +def _normalize_cvcuda( image: "cvcuda.Tensor", - mean: Sequence[float | int] | float | int, - std: Sequence[float | int] | float | int, + mean: list[float], + std: list[float], inplace: bool = False, ) -> "cvcuda.Tensor": + cvcuda = _import_cvcuda() if inplace: raise ValueError("Inplace normalization is not supported for CVCUDA.") + # CV-CUDA supports signed int and float tensors + # torchvision only supports uint and float, right now CV-CUDA doesnt expose float16, so only check 32 + # in the future add float16 once exposed in CV-CUDA + if not (image.dtype == cvcuda.Type.F32): + raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.") + channels = image.shape[3] if isinstance(mean, float | int): mean = [mean] * channels @@ -115,7 +122,7 @@ def normalize_cvcuda( if CVCUDA_AVAILABLE: - _normalize_cvcuda = _register_kernel_internal(normalize, cvcuda.Tensor)(normalize_cvcuda) + _normalize_cvcuda_registered = _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda) def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor: From 429f77f9f16762874bcc8942cc4335f6f7e1496f Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 08:49:08 -0800 Subject: [PATCH 11/18] fix tests crashing before run without cvcuda --- test/test_transforms_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 0c0bfe079c5..72a3166cd5f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5592,12 +5592,14 @@ def test_functional(self, make_input): (F.normalize_video, tv_tensors.Video), pytest.param( F._misc._normalize_cvcuda, - _import_cvcuda().Tensor, + "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): + if input_type == "cvcuda.Tensor": + input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=input_type) def test_functional_error(self): From 8ed3b267b11fae818a0e8a613f284858cec81c51 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 10:35:38 -0800 Subject: [PATCH 12/18] resolve more review comments --- torchvision/transforms/v2/functional/_misc.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index 8dcca5c3d49..c64c379c8ce 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -105,14 +105,6 @@ def _normalize_cvcuda( elif len(std) != channels: raise ValueError(f"Std should have {channels} elements. Got {len(std)}.") - # CV-CUDA requires float32 tensors for the mean/std parameters - # at small batchs, this is costly relative to normalize operation - # if CV-CUDA is known to be a backend, could optimize this - # For Normalize class: - # by creating tensors at class initialization time - # For functional API: - # by storing cached tensors in helper function with functools.lru_cache (would it even be worth it?) - # Since CV-CUDA is 1) not default backend, 2) only strictly faster at large batch size, ignore mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda() st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda() mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC) @@ -122,7 +114,7 @@ def _normalize_cvcuda( if CVCUDA_AVAILABLE: - _normalize_cvcuda_registered = _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda) + _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda) def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor: From 57ca083d76071b309b349a57bf8d5ee8214614a9 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 24 Nov 2025 10:43:28 -0800 Subject: [PATCH 13/18] remove extra parameterize for dtype --- test/test_transforms_v2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 72a3166cd5f..6130f89df26 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5661,10 +5661,9 @@ def test_correctness_image(self, mean, std, dtype, fn): @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") @pytest.mark.parametrize(("mean", "std"), MEANS_STDS) - @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)]) - def test_correctness_cvcuda(self, mean, std, dtype, fn): - image = make_image(batch_dims=(1,), dtype=dtype, device="cuda") + def test_correctness_cvcuda(self, mean, std, fn): + image = make_image(batch_dims=(1,), dtype=torch.float32, device="cuda") cvc_image = F.to_cvcuda_tensor(image) actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std) expected = fn(image, mean=mean, std=std) From 184e37947d664ebb5fa861c0ecb460e0a03a0221 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Wed, 26 Nov 2025 10:26:55 -0800 Subject: [PATCH 14/18] simplify normalize testing into single test parameterize on input creation --- test/test_transforms_v2.py | 38 ++++++++++++++++--------- torchvision/transforms/v2/_transform.py | 4 +-- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 6130f89df26..a5317282df5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5650,24 +5650,36 @@ def _reference_normalize_image(self, image, *, mean, std): @pytest.mark.parametrize(("mean", "std"), MEANS_STDS) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") + ), + ], + ) @pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)]) - def test_correctness_image(self, mean, std, dtype, fn): - image = make_image(dtype=dtype) + def test_correctness_image(self, mean, std, dtype, make_input, fn): + if make_input == make_image_cvcuda and dtype != torch.float32: + pytest.skip("CVCUDA only supports float32 for normalize") + + image = make_input(dtype=dtype) actual = fn(image, mean=mean, std=std) - expected = self._reference_normalize_image(image, mean=mean, std=std) - assert_equal(actual, expected) + if make_input == make_image_cvcuda: + image = F.cvcuda_to_tensor(image).to(device="cpu") + image = image.squeeze(0) + actual = F.cvcuda_to_tensor(actual).to(device="cpu") + actual = actual.squeeze(0) - @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA") - @pytest.mark.parametrize(("mean", "std"), MEANS_STDS) - @pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)]) - def test_correctness_cvcuda(self, mean, std, fn): - image = make_image(batch_dims=(1,), dtype=torch.float32, device="cuda") - cvc_image = F.to_cvcuda_tensor(image) - actual = F._misc._normalize_cvcuda(cvc_image, mean=mean, std=std) - expected = fn(image, mean=mean, std=std) - torch.testing.assert_close(F.cvcuda_to_tensor(actual), expected, rtol=1e-7, atol=1e-7) + expected = self._reference_normalize_image(image, mean=mean, std=std) + + if make_input == make_image_cvcuda: + torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6) + else: + assert_equal(actual, expected) class TestClampBoundingBoxes: diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..091794b39a8 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors -from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor +from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) def __init__(self) -> None: super().__init__() From 995834a249ecd5ad3755028b394d23dd703707e4 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Mon, 1 Dec 2025 18:01:33 -0800 Subject: [PATCH 15/18] update normalize based on PR reviews --- test/common_utils.py | 17 +++++++++++++++++ test/test_transforms_v2.py | 27 ++++++++++++--------------- torchvision/transforms/v2/_misc.py | 3 +++ 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index e3fa464b5ea..0d1f68542bf 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -28,6 +28,7 @@ IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" +CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -276,6 +277,17 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] +def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: + tensor = cvcuda_to_tensor(tensor) + if tensor.ndim != 4: + raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") + if tensor.shape[0] != 1: + raise ValueError( + f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." + ) + return tensor.squeeze(0).cpu() + + class ImagePair(TensorLikePair): def __init__( self, @@ -304,6 +316,11 @@ def __init__( expected = expected[0] expected = expected.cpu() + # handle check for CV-CUDA Tensors + if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): + # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image + actual = cvcuda_to_pil_compatible_tensor(actual) + super().__init__(actual, expected, **other_parameters) self.mae = mae diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a5317282df5..af6ccd1d901 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5553,17 +5553,17 @@ def test_kernel_image(self, mean, std, device): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_inplace(self, device): - input = make_image_tensor(dtype=torch.float32, device=device) - input_version = input._version + inpt = make_image_tensor(dtype=torch.float32, device=device) + input_version = inpt._version - output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD) - assert output_out_of_place.data_ptr() != input.data_ptr() - assert output_out_of_place is not input + output_out_of_place = F.normalize_image(inpt, mean=self.MEAN, std=self.STD) + assert output_out_of_place.data_ptr() != inpt.data_ptr() + assert output_out_of_place is not inpt - output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True) - assert output_inplace.data_ptr() == input.data_ptr() + output_inplace = F.normalize_image(inpt, mean=self.MEAN, std=self.STD, inplace=True) + assert output_inplace.data_ptr() == inpt.data_ptr() assert output_inplace._version > input_version - assert output_inplace is input + assert output_inplace is inpt assert_equal(output_inplace, output_out_of_place) @@ -5613,9 +5613,9 @@ def test_functional_error(self): with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"): F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std) - def _sample_input_adapter(self, transform, input, device): + def _sample_input_adapter(self, transform, inpt, device): adapted_input = {} - for key, value in input.items(): + for key, value in inpt.items(): if isinstance(value, PIL.Image.Image): # normalize doesn't support PIL images continue @@ -5669,15 +5669,12 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn): actual = fn(image, mean=mean, std=std) if make_input == make_image_cvcuda: - image = F.cvcuda_to_tensor(image).to(device="cpu") - image = image.squeeze(0) - actual = F.cvcuda_to_tensor(actual).to(device="cpu") - actual = actual.squeeze(0) + image = cvcuda_to_pil_compatible_tensor(image) expected = self._reference_normalize_image(image, mean=mean, std=std) if make_input == make_image_cvcuda: - torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6) + assert_close(actual, expected, rtol=0, atol=1e-6) else: assert_equal(actual, expected) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 305149c87b1..bea5bdfa184 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -17,6 +17,7 @@ get_bounding_boxes, get_keypoints, has_any, + is_cvcuda_tensor, is_pure_tensor, ) @@ -160,6 +161,8 @@ class Normalize(Transform): _v1_transform_cls = _transforms.Normalize + _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): super().__init__() self.mean = list(mean) From 71053586581afe60bb91d25f4ca4466afd3f0ee5 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:03:25 -0800 Subject: [PATCH 16/18] update normalize with changes from main --- test/common_utils.py | 16 ---------------- test/test_transforms_v2.py | 5 ++--- torchvision/transforms/v2/_misc.py | 8 ++++++-- torchvision/transforms/v2/_transform.py | 4 ++-- torchvision/transforms/v2/_utils.py | 8 ++++---- torchvision/transforms/v2/functional/_augment.py | 11 +---------- torchvision/transforms/v2/functional/_color.py | 12 +----------- torchvision/transforms/v2/functional/_misc.py | 4 ++-- 8 files changed, 18 insertions(+), 50 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index 0d1f68542bf..61f6a82eacf 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -277,17 +277,6 @@ def combinations_grid(**kwargs): return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] -def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor: - tensor = cvcuda_to_tensor(tensor) - if tensor.ndim != 4: - raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.") - if tensor.shape[0] != 1: - raise ValueError( - f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}." - ) - return tensor.squeeze(0).cpu() - - class ImagePair(TensorLikePair): def __init__( self, @@ -316,11 +305,6 @@ def __init__( expected = expected[0] expected = expected.cpu() - # handle check for CV-CUDA Tensors - if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor): - # Use the PIL compatible tensor, so we can always compare with PIL.Image.Image - actual = cvcuda_to_pil_compatible_tensor(actual) - super().__init__(actual, expected, **other_parameters) self.mae = mae diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index af6ccd1d901..d98fdb6dad5 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -25,7 +25,6 @@ assert_equal, cache, cpu_and_cuda, - cvcuda_to_pil_compatible_tensor, freeze_rng_state, ignore_jit_no_profile_information_warning, make_bounding_boxes, @@ -5591,7 +5590,7 @@ def test_functional(self, make_input): (F.normalize_image, tv_tensors.Image), (F.normalize_video, tv_tensors.Video), pytest.param( - F._misc._normalize_cvcuda, + F._misc._normalize_image_cvcuda, "cvcuda.Tensor", marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), @@ -5669,7 +5668,7 @@ def test_correctness_image(self, mean, std, dtype, make_input, fn): actual = fn(image, mean=mean, std=std) if make_input == make_image_cvcuda: - image = cvcuda_to_pil_compatible_tensor(image) + image = F.cvcuda_to_tensor(image)[0].cpu() expected = self._reference_normalize_image(image, mean=mean, std=std) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index bea5bdfa184..f15a9e3c62a 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,6 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor from ._utils import ( _parse_labels_getter, @@ -17,11 +18,13 @@ get_bounding_boxes, get_keypoints, has_any, - is_cvcuda_tensor, is_pure_tensor, ) +CVCUDA_AVAILABLE = _is_cvcuda_available() + + # TODO: do we want/need to expose this? class Identity(Transform): def transform(self, inpt: Any, params: dict[str, Any]) -> Any: @@ -161,7 +164,8 @@ class Normalize(Transform): _v1_transform_cls = _transforms.Normalize - _transformed_types = Transform._transformed_types + (is_cvcuda_tensor,) + if CVCUDA_AVAILABLE: + _transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,) def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): super().__init__() diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 091794b39a8..ac84fcb6c82 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import tv_tensors -from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor +from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel @@ -23,7 +23,7 @@ class Transform(nn.Module): # Class attribute defining transformed types. Other types are passed-through without any transformation # We support both Types and callables that are able to do further checks on the type of the input. - _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor) + _transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image) def __init__(self) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3fc33ce5964..e803aa49c60 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -15,8 +15,8 @@ from torchvision._utils import sequence_to_str from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401 -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor -from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor +from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]: @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor)) + if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -207,7 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]: tv_tensors.Mask, tv_tensors.BoundingBoxes, tv_tensors.KeyPoints, - is_cvcuda_tensor, + _is_cvcuda_tensor, ), ) } diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 7ce5bdc7b7e..a904d8d7cbd 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -1,5 +1,4 @@ import io -from typing import TYPE_CHECKING import PIL.Image @@ -9,15 +8,7 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def erase( diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 5be9c62902a..be254c0d63a 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -1,5 +1,3 @@ -from typing import TYPE_CHECKING - import PIL.Image import torch from torch.nn.functional import conv2d @@ -11,15 +9,7 @@ from ._misc import _num_value_bits, to_dtype_image from ._type_conversion import pil_to_tensor, to_pil_image -from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal - - -CVCUDA_AVAILABLE = _is_cvcuda_available() - -if TYPE_CHECKING: - import cvcuda # type: ignore[import-not-found] -if CVCUDA_AVAILABLE: - cvcuda = _import_cvcuda() # noqa: F811 +from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index c64c379c8ce..b55dc465456 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -79,7 +79,7 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in return normalize_image(video, mean, std, inplace=inplace) -def _normalize_cvcuda( +def _normalize_image_cvcuda( image: "cvcuda.Tensor", mean: list[float], std: list[float], @@ -114,7 +114,7 @@ def _normalize_cvcuda( if CVCUDA_AVAILABLE: - _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda) + _register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_image_cvcuda) def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor: From 0f8910e52420416aa9d2769b137d28c33d72c3d2 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 11:04:38 -0800 Subject: [PATCH 17/18] remove extra cvcuda_available add --- test/common_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/common_utils.py b/test/common_utils.py index 61f6a82eacf..e3fa464b5ea 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -28,7 +28,6 @@ IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" -CVCUDA_AVAILABLE = _is_cvcuda_available() CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." From 969dd3f157e31a380da39d572a633b01d583c4c3 Mon Sep 17 00:00:00 2001 From: Justin Davis Date: Thu, 4 Dec 2025 13:43:52 -0800 Subject: [PATCH 18/18] check input type on kernel for signature test --- test/test_transforms_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index d98fdb6dad5..de821b70469 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5591,13 +5591,13 @@ def test_functional(self, make_input): (F.normalize_video, tv_tensors.Video), pytest.param( F._misc._normalize_image_cvcuda, - "cvcuda.Tensor", + None, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"), ), ], ) def test_functional_signature(self, kernel, input_type): - if input_type == "cvcuda.Tensor": + if kernel is F._misc._normalize_image_cvcuda: input_type = _import_cvcuda().Tensor check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=input_type)