diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py
index ccb9c42b92..daa4154b45 100644
--- a/src/huggingface_hub/serialization/_torch.py
+++ b/src/huggingface_hub/serialization/_torch.py
@@ -706,6 +706,15 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     if the input is a wrapper tensor subclass Tensor
     """

+    try:
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            local_tensor = tensor.to_local()
+            return local_tensor.storage().data_ptr()
+    except ImportError:
+        pass
+
     try:
         # for torch 2.1 and above we can also handle tensor subclasses
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -753,6 +762,15 @@ def get_torch_storage_size(tensor: "torch.Tensor") -> int:
     """
     Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
     """
+    try:
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            # this returns the size of the FULL tensor in bytes
+            return tensor.nbytes
+    except ImportError:
+        pass
+
     try:
         # for torch 2.1 and above we can also handle tensor subclasses
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 9aef755a70..dad7065de6 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -46,6 +46,16 @@ def is_wrapper_tensor_subclass_available():
         return False


+def is_dtensor_available():
+    try:
+        from torch.distributed.device_mesh import init_device_mesh  # type: ignore[import] # noqa: F401
+        from torch.distributed.tensor import DTensor  # type: ignore[import] # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
 @pytest.fixture
 def dummy_state_dict() -> Dict[str, List[int]]:
     return {
@@ -250,6 +260,33 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


+@requires("torch")
+@pytest.mark.skipif(not is_dtensor_available(), reason="requires torch with dtensor available")
+def test_get_torch_storage_size_dtensor():
+    # testing distributed sharded tensors isn't very easy, would need to subprocess call torchrun, so this should be good enough
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
 @requires("torch")
 @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
 def test_get_torch_storage_size_wrapper_tensor_subclass():
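
For local experimentation, the DTensor branches can also be exercised outside pytest. The sketch below is not part of the diff: it mirrors the single-rank gloo/HashStore setup from the test above and additionally calls the private `_get_unique_id` helper (which the test does not cover); importing that helper directly and the printed values are assumptions for illustration only.

# Minimal sketch (not part of the diff): exercise the DTensor code paths
# on a single-rank CPU mesh, mirroring the test's gloo + HashStore setup.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# _get_unique_id is a private helper; importing it here is an assumption for illustration.
from huggingface_hub.serialization._torch import _get_unique_id, get_torch_storage_size

if dist.is_available() and not dist.is_initialized():
    # single-process "distributed" group so a DTensor can be built without torchrun
    dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
dt = DTensor.from_local(local, mesh, [Replicate()])

# DTensor branch of get_torch_storage_size: size of the FULL tensor in bytes (5 elements * 2 bytes)
print(get_torch_storage_size(dt))  # expected: 10

# DTensor branch of _get_unique_id: data pointer of the local shard's storage
print(_get_unique_id(dt))

if dist.is_initialized():
    dist.destroy_process_group()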