Skip to content

Commit 2525677

Browse files
committed
Feat: support DTensor for storage size and id
1 parent 557576d commit 2525677

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/huggingface_hub/serialization/_torch.py

+18
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,15 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
706706
if the input is a wrapper tensor subclass Tensor
707707
"""
708708

709+
try:
710+
from torch.distributed.tensor import DTensor
711+
712+
if isinstance(tensor, DTensor):
713+
local_tensor = tensor.to_local()
714+
return local_tensor.storage().data_ptr()
715+
except ImportError:
716+
pass
717+
709718
try:
710719
# for torch 2.1 and above we can also handle tensor subclasses
711720
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -753,6 +762,15 @@ def get_torch_storage_size(tensor: "torch.Tensor") -> int:
753762
"""
754763
Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
755764
"""
765+
try:
766+
from torch.distributed.tensor import DTensor
767+
768+
if isinstance(tensor, DTensor):
769+
# this returns the size of the FULL tensor in bytes
770+
return tensor.nbytes
771+
except ImportError:
772+
pass
773+
756774
try:
757775
# for torch 2.1 and above we can also handle tensor subclasses
758776
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

0 commit comments

Comments
 (0)