Merged
18 changes: 18 additions & 0 deletions src/huggingface_hub/serialization/_torch.py
@@ -706,6 +706,15 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
    if the input is a wrapper tensor subclass Tensor
    """

    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            local_tensor = tensor.to_local()
            return local_tensor.storage().data_ptr()
    except ImportError:
        pass

    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
@@ -753,6 +762,15 @@ def get_torch_storage_size(tensor: "torch.Tensor") -> int:
    """
    Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
    """
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            # this returns the size of the FULL tensor in bytes
            return tensor.nbytes
    except ImportError:
        pass
Comment on lines +765 to +772
Contributor

I'm not familiar with DTensor, but if the tensor is indeed a DTensor and the import on line 766 fails, would it be okay to fall back to tensor.untyped_storage().nbytes()?

Member Author

I'm not 100% sure, I'll have to test locally, but I'm pretty sure that would fail with a "has no method untyped_storage" error. But this import shouldn't ever fail if the tensor is a DTensor. It's wrapped in try/except to avoid version checking, since DTensor is torch >= 2.1 (ish).

Contributor

> DTensor is torch >= 2.1 (ish)

okay then all good!


    try:
        # for torch 2.1 and above we can also handle tensor subclasses
        from torch.utils._python_dispatch import is_traceable_wrapper_subclass
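To make the discussion above concrete, here is a minimal standalone sketch (not part of the diff; the helper name _dtensor_storage_size is made up for illustration) of why the except ImportError branch acts as a version guard rather than a real fallback: if tensor actually is a DTensor, the DTensor class must be importable in the running torch, so the except branch can only be reached on older torch versions where no DTensor could have been created in the first place.

# Minimal sketch, assuming only the behavior shown in the diff above; not part of the PR.
from typing import Optional


def _dtensor_storage_size(tensor) -> Optional[int]:
    try:
        from torch.distributed.tensor import DTensor
    except ImportError:
        # torch is too old to ship DTensor, so `tensor` cannot be one;
        # the caller falls through to the regular storage-size logic.
        return None
    if isinstance(tensor, DTensor):
        # Per the comment in the diff, DTensor.nbytes is the size of the FULL tensor in bytes.
        return tensor.nbytes
    return None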
37 changes: 37 additions & 0 deletions tests/test_serialization.py
@@ -46,6 +46,16 @@ def is_wrapper_tensor_subclass_available():
        return False


def is_dtensor_available():
    try:
        from torch.distributed.device_mesh import init_device_mesh  # type: ignore[import] # noqa: F401
        from torch.distributed.tensor import DTensor  # type: ignore[import] # noqa: F401

        return True
    except ImportError:
        return False


@pytest.fixture
def dummy_state_dict() -> Dict[str, List[int]]:
    return {
@@ -250,6 +260,33 @@ def test_get_torch_storage_size():
    assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


@requires("torch")
@pytest.mark.skipif(not is_dtensor_available(), reason="requires torch with dtensor available")
def test_get_torch_storage_size_dtensor():
# testing distributed sharded tensors isn't very easy, would need to subprocess call torchrun, so this should be good enough
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

if dist.is_available() and not dist.is_initialized():
dist.init_process_group(
backend="gloo",
store=dist.HashStore(),
rank=0,
world_size=1,
)

mesh = init_device_mesh("cpu", (1,))
local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
dt = DTensor.from_local(local, mesh, [Replicate()])

assert get_torch_storage_size(dt) == 5 * 2

if dist.is_initialized():
dist.destroy_process_group()


@requires("torch")
@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_get_torch_storage_size_wrapper_tensor_subclass():
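As the comment in the new test notes, exercising an actually sharded DTensor would require launching multiple processes. For reference, below is a hedged sketch of what such a check could look like when launched with torchrun (e.g. torchrun --nproc_per_node=2 script.py); the import path, placements, and launch command are assumptions and are not part of this PR.

# Hypothetical multi-process sketch, not included in this PR.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

from huggingface_hub.serialization._torch import get_torch_storage_size  # assumed import path


def main():
    dist.init_process_group(backend="gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    full = torch.ones(8, dtype=torch.float16)
    # Shard dim 0 across ranks: each rank only materializes its local slice.
    dt = distribute_tensor(full, mesh, [Shard(0)])
    # Per the diff, get_torch_storage_size reports the size of the FULL tensor in bytes,
    # regardless of how it is sharded.
    assert get_torch_storage_size(dt) == 8 * 2
    dist.destroy_process_group()


if __name__ == "__main__":
    main()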