@@ -249,6 +249,31 @@ def test_get_torch_storage_size():
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
     assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
+@requires("torch")
+def test_get_torch_storage_size_dtensor():
+    # Testing a truly sharded distributed tensor would require spawning a torchrun subprocess, so a single-rank replicated DTensor is a good-enough proxy.
+    import torch
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor import DTensor, Replicate
+
+    if dist.is_available() and not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            store=dist.HashStore(),
+            rank=0,
+            world_size=1,
+        )
+
+    mesh = init_device_mesh("cpu", (1,))
+    local = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
+    dt = DTensor.from_local(local, mesh, [Replicate()])
+
+    assert get_torch_storage_size(dt) == 5 * 2
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
 @requires("torch")
 @pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
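For context on what the new assertion checks: a DTensor materializes only its local shard on each rank, so the storage size of a replicated 5-element float16 DTensor is the same 10 bytes as the plain local tensor. Below is a minimal sketch of that invariant; the helper name storage_size_bytes is hypothetical and is not necessarily how get_torch_storage_size is implemented internally.

import torch
from torch.distributed.tensor import DTensor


def storage_size_bytes(tensor: torch.Tensor) -> int:
    # Hypothetical sketch: a DTensor only backs its local shard with real
    # memory on this rank, so measure the local tensor's storage rather
    # than the global logical shape.
    if isinstance(tensor, DTensor):
        tensor = tensor.to_local()
    return tensor.untyped_storage().nbytes()


# A plain float16 tensor with 5 elements occupies 5 * 2 = 10 bytes,
# matching the assertion in the test above.
assert storage_size_bytes(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 10

Under this reading, the single-rank gloo process group in the test exists only so DTensor.from_local() has a process group to attach to; the size assertion itself involves no communication.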