
Commit 2768611

Simple tests
1 parent 410ccf8 commit 2768611

File tree

1 file changed: +25 -0 lines changed


tests/test_serialization.py (+25)
@@ -249,6 +249,31 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
+@requires("torch")
+def test_get_torch_storage_size_dtensor():
+    # Testing distributed sharded tensors isn't easy (it would need a subprocess call to torchrun), so this should be good enough.
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
 @requires("torch")
 @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
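For context, a minimal sketch of how a storage-size helper can handle a DTensor by measuring only the shard held by the current rank, which is what the 5 * 2 byte assertion above relies on. This is an assumption for illustration, not the actual huggingface_hub implementation of get_torch_storage_size, and storage_size_bytes is a hypothetical name:

# Hypothetical sketch, not the huggingface_hub implementation: when given a
# DTensor, measure the storage of the local shard, otherwise of the tensor itself.
import torch


def storage_size_bytes(tensor: torch.Tensor) -> int:
    try:
        from torch.distributed.tensor import DTensor

        if isinstance(tensor, DTensor):
            # Only the shard materialized on this rank occupies local storage;
            # a replicated float16 tensor of 5 elements thus reports 5 * 2 bytes.
            tensor = tensor.to_local()
    except ImportError:
        # Older torch builds without torch.distributed.tensor: treat as a plain tensor.
        pass
    return tensor.untyped_storage().nbytes()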
