33 changes: 27 additions & 6 deletions test/common_utils.py
@@ -20,7 +20,8 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
from torchvision.utils import _Image_fromarray


@@ -284,8 +285,24 @@ def __init__(
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = (to_image(input) for input in [actual, expected])
# Convert PIL images to tv_tensors.Image (regardless of what the other is)
if isinstance(actual, PIL.Image.Image):
actual = to_image(actual)
if isinstance(expected, PIL.Image.Image):
expected = to_image(expected)

if _is_cvcuda_available():
if _is_cvcuda_tensor(actual):
actual = cvcuda_to_tensor(actual)
# Remove batch dimension if it's 1 for easier comparison against 3D PIL images
if actual.shape[0] == 1:
actual = actual[0]
actual = actual.cpu()
if _is_cvcuda_tensor(expected):
expected = cvcuda_to_tensor(expected)
if expected.shape[0] == 1:
expected = expected[0]
expected = expected.cpu()

super().__init__(actual, expected, **other_parameters)
self.mae = mae
@@ -400,8 +417,8 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
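
As a side note, a minimal usage sketch of the updated helper (not part of the diff; device="cuda" is an assumption here, and exact device handling inside to_cvcuda_tensor may differ):

img = make_image_cvcuda(dtype=torch.uint8, device="cuda")   # CV-CUDA tensor with a leading batch dim of 1
as_torch = cvcuda_to_tensor(img)                            # back to a torch.Tensor, batch dim preserved
assert as_torch.shape[0] == 1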


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
@@ -541,5 +558,9 @@ def ignore_jit_no_profile_information_warning():
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
warnings.filterwarnings(
"ignore",
message=re.escape("operator() profile_node %"),
category=UserWarning,
)
yield
88 changes: 70 additions & 18 deletions test/test_transforms_v2.py
@@ -1240,6 +1240,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1255,13 +1259,20 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._horizontal_flip_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._horizontal_flip_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
Expand All @@ -1270,6 +1281,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1283,13 +1298,23 @@ def test_transform(self, make_input, device):
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()
expected = F.horizontal_flip(F.to_pil_image(image))
assert_equal(actual, expected)

def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
@@ -1345,6 +1370,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1354,11 +1383,8 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomHorizontalFlip(p=0)

output = transform(input)

assert_equal(output, input)


@@ -1856,6 +1882,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1871,13 +1901,20 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._vertical_flip_image_cvcuda,
None,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
],
)
def test_functional_signature(self, kernel, input_type):
if kernel is F._geometry._vertical_flip_image_cvcuda:
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
@@ -1886,6 +1923,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1897,13 +1938,23 @@ def test_transform(self, make_input, device):
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))

@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0].cpu()
expected = F.vertical_flip(F.to_pil_image(image))
assert_equal(actual, expected)

def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
@@ -1955,6 +2006,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
@@ -1964,11 +2019,8 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomVerticalFlip(p=0)

output = transform(input)

assert_equal(output, input)


10 changes: 9 additions & 1 deletion torchvision/transforms/v2/_geometry.py
@@ -11,7 +11,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._utils import _FillType
from torchvision.transforms.v2.functional._utils import _FillType, _is_cvcuda_available, _is_cvcuda_tensor

from ._transform import _RandomApplyTransform
from ._utils import (
@@ -30,6 +30,8 @@
query_size,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()


class RandomHorizontalFlip(_RandomApplyTransform):
"""Horizontally flip the input with a given probability.
@@ -45,6 +47,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomHorizontalFlip

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.horizontal_flip, inpt)

@@ -63,6 +68,9 @@ class RandomVerticalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomVerticalFlip

if CVCUDA_AVAILABLE:
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.vertical_flip, inpt)

31 changes: 29 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
@@ -2,7 +2,7 @@
import numbers
import warnings
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
@@ -26,7 +26,18 @@

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
from ._utils import (
_FillTypeJIT,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
_register_five_ten_crop_kernel_internal,
_register_kernel_internal,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()
if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -62,6 +73,14 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image)


def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":

Review thread on the function name:

Reviewer: Maybe a bit of a nitpick, but could we rename the function to _horizontal_flip_cvcuda? CV-CUDA only operates on one data type, so the extra "image" in the function name does not add value IMO. Removing it also mirrors the cvcuda_to_tensor and to_cvcuda_tensor functions.

NicolasHug (Member, Dec 3, 2025): The cvcuda_to_tensor and to_cvcuda_tensor functions are a bit of an outlier in that sense, but most other kernels specify the nature of the input they work on. We have e.g.

  • horizontal_flip_image for tensors and tv_tensors.Image
  • _horizontal_flip_image_pil
  • horizontal_flip_mask
  • horizontal_flip_bounding_boxes
  • etc.

The CVCUDA backend is basically of the same nature as the PIL backend, so it makes sense to keep it named _horizontal_flip_cvcuda (EDIT: meant _horizontal_flip_image_cvcuda!) IMO, like we have _horizontal_flip_image_pil.

Reviewer: @NicolasHug just to be sure, you are saying it makes sense to keep it named _horizontal_flip_cvcuda — I guess you mean it makes sense to keep it named _horizontal_flip_image_cvcuda?

NicolasHug (Member): Yes, thanks for catching! I'll edit above to avoid further confusion.

return _import_cvcuda().flip(image, flipCode=1)


if CVCUDA_AVAILABLE:
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(mask)
@@ -150,6 +169,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.vflip(image)


def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
return _import_cvcuda().flip(image, flipCode=0)


if CVCUDA_AVAILABLE:
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(mask)
8 changes: 8 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
@@ -169,3 +169,11 @@ def _is_cvcuda_available():
return True
except ImportError:
return False


def _is_cvcuda_tensor(inpt: Any) -> bool:
try:
cvcuda = _import_cvcuda()
return isinstance(inpt, cvcuda.Tensor)
except ImportError:
return False
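
Taken together, a minimal end-to-end sketch of the new CV-CUDA path (illustrative only: it assumes CV-CUDA is installed, a CUDA device is available, and that to_cvcuda_tensor accepts a uint8 CHW image tensor; the cvcuda_to_tensor(...)[0].cpu() pattern mirrors the test changes above):

import torch
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor

if _is_cvcuda_available():
    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8, device="cuda")
    cv_image = F.to_cvcuda_tensor(image)          # torch.Tensor -> cvcuda.Tensor
    assert _is_cvcuda_tensor(cv_image)
    flipped = F.horizontal_flip(cv_image)         # dispatches to _horizontal_flip_image_cvcuda
    back = F.cvcuda_to_tensor(flipped)[0].cpu()   # cvcuda.Tensor -> torch.Tensor, drop the batch dim of 1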