@@ -24,24 +24,24 @@ class SafeLoad:
2424 In case of failure, it will attempt to find out as precisely as possible where the problem comes from.
2525 """
2626
27- def __init__ (self , model : "FastLLMModel" , * , num_shards : int , timeout : float | None = None ):
27+ def __init__ (self , model : "FastLLMModel" , * , shard_names : tuple [ str , ...] , timeout : float | None = None ):
2828 self ._model = model
2929 self ._distributed = self ._model .distributed
30- self ._num_shards = num_shards
31- self ._self_shard = self ._model .state_shard [: self . _num_shards ]
30+ # self._num_shards = num_shards
31+ self ._self_shards = { shard_name : self ._model .get_shard ( shard_name ) for shard_name in shard_names }
3232 self ._timeout = timeout
3333
3434 def __enter__ (self ) -> "SafeLoad" :
3535 self ._loaded = 0
3636 self ._loaded_parameters = {}
3737 # Track the number of loaded entries.
3838 # Use nan to mark non-loaded entries.
39- triton_fill (self ._self_shard , math .nan )
39+ for self_shard in self ._self_shards .values ():
40+ triton_fill (self_shard , math .nan )
4041 # Reset and count shard pads
41- for shard in self ._model .state_shard [: self ._num_shards ]:
42- shard_split = shard .split (self ._model .stage_shard_sizes , 0 )
43- for stage , stage_shard in zip (self ._model .stages_on_device .values (), shard_split ):
44- self ._loaded += stage .reset_shard_pad (stage_shard )
42+ for _ , fsdp , fsdp_shards in self ._model .split_shards_by_fsdp (self ._self_shards ):
43+ for fsdp_shard in fsdp_shards .values ():
44+ self ._loaded += fsdp .reset_shard_pad (fsdp_shard )
4545 return self
4646
4747 def __exit__ (self , exc_type , exc_val , exc_tb ):
@@ -70,18 +70,19 @@ def _validate(self) -> None:
7070 logger .info (f"{ self ._loaded :,} state entries loaded successfully" )
7171
7272 def _check_counter (self , errors : list [str ]) -> None :
73- to_load = self ._self_shard . numel ( )
73+ to_load = sum ( self_shard . numel () for self_shard in self ._self_shards . values () )
7474 if self ._loaded != to_load :
7575 # Ensure the right amount of weights is loaded.
7676 errors .append (f"Loaded a total of { self ._loaded :,} , state entries, expected { to_load :,} " )
7777
7878 def _check_missing (self , errors : list [str ]) -> None :
7979 # Ensure the loaded weights have a 1-1 mapping by looking for nans.
80- missing = self . _self_shard . new_zeros ([], dtype = torch .int64 )
80+ missing = torch . zeros ([], dtype = torch .int64 , device = self . _distributed . device )
8181 # Count nans in slices of 100M parameters to limit memory usage.
8282 # TODO: Find better solution (triton kernel?)
83- for shard_slice in self ._self_shard .flatten ().split (100000000 ):
84- missing += shard_slice .isnan ().sum ()
83+ for shard in self ._self_shards .values ():
84+ for shard_slice in shard .flatten ().split (100000000 ):
85+ missing += shard_slice .isnan ().sum ()
8586 local_missing = missing .item ()
8687 if self ._distributed .world_group is not None :
8788 all_reduce (missing , group = self ._distributed .world_group )
@@ -90,32 +91,32 @@ def _check_missing(self, errors: list[str]) -> None:
9091 errors .append (f"{ global_missing :,} state entries failed to load or corrupted (local={ local_missing :,} )." )
9192 # Determine where the missing values are coming from.
9293 global_total , local_total = 0 , 0
93- for shard_name , shard_ in zip (self ._model .state_shard_names [: self ._num_shards ], self ._self_shard ):
94- shard_split = shard_ .split (self ._model .stage_shard_sizes , 0 )
95- for stage , shard in zip (self ._model .stages_on_device .values (), shard_split ):
96- buffer = stage ._reconstruct_from_shard (shard )
97- for i , parameter in enumerate (stage ._split_buffer (buffer )):
94+ for stage , fsdp , fsdp_shards in self ._model .split_shards_by_fsdp (self ._self_shards ):
95+ for shard_name , fsdp_shard in fsdp_shards .items ():
96+ buffer = fsdp .reconstruct_from_shard (fsdp_shard )
97+ for parameter_name , parameter in fsdp .split_buffer (buffer ).items ():
9898 missing_for_param = parameter .isnan ().sum ().item ()
9999 if missing_for_param > 0 :
100100 global_total += missing_for_param
101- local_values = stage . _split_shard ( shard )[ i ]
101+ local_values = fsdp . split_shard ( fsdp_shard )[ parameter_name ]
102102 local_missing_for_param = local_values .isnan ().sum ().item ()
103103 local_total += local_missing_for_param
104104 errors .append (
105- f"{ missing_for_param :,} values missing out of { parameter .numel ():,} for parameter { stage . parameter_names [ i ] } in stage { stage .index } , shard { shard_name } "
105+ f"{ missing_for_param :,} values missing out of { parameter .numel ():,} for parameter { parameter_name } in stage { stage .index } , shard { shard_name } "
106106 f" (locally { local_missing_for_param :,} out of { local_values .numel ():,} )"
107107 )
108- missing_for_pad = buffer [- stage ._global_pad :].isnan ().sum ().item ()
108+ missing_for_pad = buffer [- fsdp ._global_pad :].isnan ().sum ().item ()
109109 if missing_for_pad > 0 :
110110 global_total += missing_for_pad
111111 local_missing_for_pad = (
112- shard [ - stage ._shard_pad :].isnan ().sum ().item () if stage ._shard_pad > 0 else 0
112+ fsdp_shard [ - fsdp ._shard_pad :].isnan ().sum ().item () if fsdp ._shard_pad > 0 else 0
113113 )
114114 local_total += local_missing_for_pad
115115 errors .append (
116- f"{ missing_for_pad :,} values missing out of { stage ._global_pad :,} for padding in stage { stage .index } , shard { shard_name } "
117- f" (locally { local_missing_for_pad :,} out of { stage ._shard_pad :,} )"
116+ f"{ missing_for_pad :,} values missing out of { fsdp ._global_pad :,} for padding in stage { stage .index } , shard { shard_name } "
117+ f" (locally { local_missing_for_pad :,} out of { fsdp ._shard_pad :,} )"
118118 )
119+
119120 if global_total != global_missing :
120121 errors .append (
121122 f"Incorrect global breakdown of missing state entries (expected { global_missing :,} , got { global_total :,} )"
@@ -127,7 +128,7 @@ def _check_missing(self, errors: list[str]) -> None:
127128
128129 def _check_parameters (self , errors : list [str ]) -> None :
129130 loaded_shard_names = set (self ._loaded_parameters )
130- shard_names = set (self ._model . state_shard_names [: self . _num_shards ] )
131+ shard_names = set (self ._self_shards )
131132 if loaded_shard_names != shard_names :
132133 errors .append (f"Incorrect loaded shards: { loaded_shard_names } !={ shard_names } " )
133134 for shard_name in shard_names & loaded_shard_names :
0 commit comments