Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
44db71c
implement additional cvcuda infra for all branches to avoid duplicate…
justincdavis Nov 25, 2025
e3dd700
update make_image_cvcuda to have default batch dim
justincdavis Nov 25, 2025
c035df1
add stanardized setup to main for easier updating of PRs and branches
justincdavis Dec 2, 2025
98d7dfb
update is_cvcuda_tensor
justincdavis Dec 2, 2025
ddc116d
add cvcuda to pil compatible to transforms by default
justincdavis Dec 2, 2025
e51dc7e
remove cvcuda from transform class
justincdavis Dec 2, 2025
e14e210
merge with main
justincdavis Dec 4, 2025
4939355
resolve more formatting naming
justincdavis Dec 4, 2025
ec76196
initial draft of to_dtype_cvcuda
justincdavis Nov 18, 2025
bd823cf
fix: to_dtype_cvcuda conventions
justincdavis Nov 20, 2025
f7aa94a
remove staticmethod from reference todtype
justincdavis Nov 24, 2025
b21d9f0
add docstring for explain scaling setup, combine correctness checks
justincdavis Nov 24, 2025
973e058
resolve more review comments
justincdavis Nov 24, 2025
d871331
simplify todtype testing
justincdavis Nov 26, 2025
736a2e6
add int -> int scaling setup for cvcuda, use bit diff for scale
justincdavis Nov 26, 2025
7a231b1
further simplify todtype test
justincdavis Nov 26, 2025
d3e4573
update todtype based on PR reviews
justincdavis Dec 2, 2025
ec93ba3
cleanup commnet, variable names
justincdavis Dec 2, 2025
89122db
update to_dtype_cvcuda name
justincdavis Dec 4, 2025
1b0d295
update to standards from flip PR
justincdavis Dec 4, 2025
009f925
remove cvcuda updates to augment
justincdavis Dec 4, 2025
41af724
remove cvcuda refs from color
justincdavis Dec 4, 2025
d12e4df
refactor dtype converters to be in utils
justincdavis Dec 4, 2025
c198cf0
add type checking for cvcuda
justincdavis Dec 4, 2025
18df67f
provide better error for todtype
justincdavis Dec 4, 2025
c5a2a5a
refactor to simplify setup for dtype conversions
justincdavis Dec 5, 2025
915ffb1
Merge branch 'main' into feat/dtype_cvcuda
justincdavis Dec 5, 2025
7f41c95
fix: not testing transform class correctness in ToDtype, resolved
justincdavis Dec 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
Expand All @@ -41,7 +42,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -2627,7 +2627,17 @@ def test_kernel(self, kernel, make_input, input_dtype, output_dtype, device, sca
scale=scale,
)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
Expand All @@ -2642,18 +2652,27 @@ def test_functional(self, make_input, input_dtype, output_dtype, device, scale):

@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
[
make_image_tensor,
make_image,
make_bounding_boxes,
make_segmentation_mask,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
@pytest.mark.parametrize("as_dict", (True, False))
def test_transform(self, make_input, input_dtype, output_dtype, device, scale, as_dict):
input = make_input(dtype=input_dtype, device=device)
inpt = make_input(dtype=input_dtype, device=device)
if as_dict:
output_dtype = {type(input): output_dtype}
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), input, check_sample_input=not as_dict)
output_dtype = {type(inpt): output_dtype}
check_transform(transforms.ToDtype(dtype=output_dtype, scale=scale), inpt, check_sample_input=not as_dict)
Comment on lines +2672 to +2675
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I ask what the reason is for changing "input" to "inpt"?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input is a reserved keyword in Python, while updating the function I went ahead and changed this to not overwrite it


def reference_convert_dtype_image_tensor(self, image, dtype=torch.float, scale=False):
input_dtype = image.dtype
Expand Down Expand Up @@ -2688,25 +2707,59 @@ def fn(value):

return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

def _get_dtype_conversion_atol(self, input_dtype, output_dtype, scale):
is_uint16_to_uint8 = input_dtype == torch.uint16 and output_dtype == torch.uint8
is_uint8_to_uint16 = input_dtype == torch.uint8 and output_dtype == torch.uint16
changes_type_class = output_dtype.is_floating_point != input_dtype.is_floating_point

in_bits = torch.iinfo(input_dtype).bits if not input_dtype.is_floating_point else None
out_bits = torch.iinfo(output_dtype).bits if not output_dtype.is_floating_point else None
expands_bits = in_bits is not None and out_bits is not None and out_bits > in_bits

if is_uint16_to_uint8:
atol = 255
elif is_uint8_to_uint16 and not scale:
atol = 255
elif expands_bits and not scale:
atol = 1
elif changes_type_class:
atol = 1
else:
atol = 0

return atol

@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.to_dtype, transform_cls_to_functional(transforms.ToDtype)])
def test_image_correctness(self, input_dtype, output_dtype, device, scale, make_input, fn):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

input = make_image(dtype=input_dtype, device=device)
inpt = make_input(dtype=input_dtype, device=device)
out = fn(inpt, dtype=output_dtype, scale=scale)

out = F.to_dtype(input, dtype=output_dtype, scale=scale)
expected = self.reference_convert_dtype_image_tensor(input, dtype=output_dtype, scale=scale)
if make_input == make_image_cvcuda:
inpt = F.cvcuda_to_tensor(inpt)
out = F.cvcuda_to_tensor(out)

if input_dtype.is_floating_point and not output_dtype.is_floating_point and scale:
torch.testing.assert_close(out, expected, atol=1, rtol=0)
else:
torch.testing.assert_close(out, expected)
expected = self.reference_convert_dtype_image_tensor(inpt, dtype=output_dtype, scale=scale)

atol = self._get_dtype_conversion_atol(input_dtype, output_dtype, scale)
torch.testing.assert_close(out, expected, rtol=0, atol=atol)

def was_scaled(self, inpt):
# this assumes the target dtype is float
Expand Down
19 changes: 16 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor

from ._utils import (
_parse_labels_getter,
Expand All @@ -21,6 +22,9 @@
)


CVCUDA_AVAILABLE = _is_cvcuda_available()


# TODO: do we want/need to expose this?
class Identity(Transform):
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
Expand Down Expand Up @@ -267,7 +271,8 @@ class ToDtype(Transform):
Default: ``False``.
"""

_transformed_types = (torch.Tensor,)
if CVCUDA_AVAILABLE:
_transformed_types = Transform._transformed_types + (_is_cvcuda_tensor,)

def __init__(
self, dtype: Union[torch.dtype, dict[Union[type, str], Optional[torch.dtype]]], scale: bool = False
Expand All @@ -294,7 +299,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
if isinstance(self.dtype, torch.dtype):
# For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype
# is a simple torch.dtype
if not is_pure_tensor(inpt) and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
if (
not is_pure_tensor(inpt)
and not isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
and (CVCUDA_AVAILABLE and not _is_cvcuda_tensor(inpt))
):
return inpt

dtype: Optional[torch.dtype] = self.dtype
Expand All @@ -311,7 +320,11 @@ def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
'e.g. dtype={tv_tensors.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.'
)

supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
supports_scaling = (
is_pure_tensor(inpt)
or isinstance(inpt, (tv_tensors.Image, tv_tensors.Video))
or (CVCUDA_AVAILABLE and _is_cvcuda_tensor(inpt))
)
if dtype is None:
if self.scale and supports_scaling:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT, _is_cvcuda_tensor


def _setup_number_or_seq(arg: int | float | Sequence[int | float], name: str) -> Sequence[float]:
Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, _is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
_is_cvcuda_tensor,
),
)
}
Expand Down
22 changes: 21 additions & 1 deletion torchvision/transforms/v2/functional/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def get_dimensions_video(video: torch.Tensor) -> list[int]:
return get_dimensions_image(video)


def get_dimensions_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:
# CV-CUDA tensor is always in NHWC layout
# get_dimensions is CHW
return [image.shape[3], image.shape[1], image.shape[2]]


if CVCUDA_AVAILABLE:
_register_kernel_internal(get_dimensions, cvcuda.Tensor)(get_dimensions_image_cvcuda)


def get_num_channels(inpt: torch.Tensor) -> int:
if torch.jit.is_scripting():
return get_num_channels_image(inpt)
Expand Down Expand Up @@ -87,6 +97,16 @@ def get_num_channels_video(video: torch.Tensor) -> int:
get_image_num_channels = get_num_channels


def get_num_channels_image_cvcuda(image: "cvcuda.Tensor") -> int:
# CV-CUDA tensor is always in NHWC layout
# get_num_channels is C
return image.shape[3]


if CVCUDA_AVAILABLE:
_register_kernel_internal(get_num_channels, cvcuda.Tensor)(get_num_channels_image_cvcuda)


def get_size(inpt: torch.Tensor) -> list[int]:
if torch.jit.is_scripting():
return get_size_image(inpt)
Expand Down Expand Up @@ -125,7 +145,7 @@ def get_size_image_cvcuda(image: "cvcuda.Tensor") -> list[int]:


if CVCUDA_AVAILABLE:
_get_size_image_cvcuda = _register_kernel_internal(get_size, cvcuda.Tensor)(get_size_image_cvcuda)
_register_kernel_internal(get_size, _import_cvcuda().Tensor)(get_size_image_cvcuda)


@_register_kernel_internal(get_size, tv_tensors.Video, tv_tensor_wrapper=False)
Expand Down
81 changes: 79 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, TYPE_CHECKING

import PIL.Image
import torch
Expand All @@ -13,7 +13,22 @@

from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import (
_get_cvcuda_type_from_torch_dtype,
_get_kernel,
_get_torch_dtype_from_cvcuda_type,
_import_cvcuda,
_is_cvcuda_available,
_register_kernel_internal,
is_pure_tensor,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def normalize(
Expand Down Expand Up @@ -340,6 +355,68 @@ def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: boo
return inpt.to(dtype)


def _to_dtype_image_cvcuda(
inpt: "cvcuda.Tensor",
dtype: torch.dtype,
scale: bool = False,
) -> "cvcuda.Tensor":
"""
Convert the dtype of a CV-CUDA tensor, based on a torch.dtype.

Args:
inpt: The CV-CUDA tensor to convert the dtype of.
dtype: The torch.dtype to convert the dtype to.
scale: Whether to scale the values to the new dtype.
There are four cases for the scaling setup:
1. float -> float
2. int -> int
3. float -> int
4. int -> float
If scale is True, the values will be scaled to the new dtype.
If scale is False, the values will not be scaled.
The scale values for float -> float are 1.0 and 0.0 respectively.
The scale values for int -> int are 2^(bit_diff) of the new dtype.
Where bit_diff is the difference in the number of bits of the new dtype and the input dtype.
The scale values for float -> int and int -> float are the maximum value of the new dtype.

Returns:
out (cvcuda.Tensor): The CV-CUDA tensor with the converted dtype.

"""
cvcuda = _import_cvcuda()

dtype_in = _get_torch_dtype_from_cvcuda_type(inpt.dtype)
cvc_dtype = _get_cvcuda_type_from_torch_dtype(dtype)

scale_val, offset = 1.0, 0.0
if scale:
in_dtype_float = dtype_in.is_floating_point
out_dtype_float = dtype.is_floating_point

if in_dtype_float and out_dtype_float:
scale_val, offset = 1.0, 0.0
elif not in_dtype_float and not out_dtype_float:
in_bits = torch.iinfo(dtype_in).bits
out_bits = torch.iinfo(dtype).bits
scale_val = float(2 ** (out_bits - in_bits))
offset = 0.0
elif in_dtype_float and not out_dtype_float:
scale_val, offset = float(_max_value(dtype)), 0.0
else:
scale_val, offset = 1.0 / float(_max_value(dtype_in)), 0.0

return cvcuda.convertto(
inpt,
dtype=cvc_dtype,
scale=scale_val,
offset=offset,
)


if CVCUDA_AVAILABLE:
_register_kernel_internal(to_dtype, _import_cvcuda().Tensor)(_to_dtype_image_cvcuda)


def sanitize_bounding_boxes(
bounding_boxes: torch.Tensor,
format: Optional[tv_tensors.BoundingBoxFormat] = None,
Expand Down
Loading