 import logging
 from concurrent.futures import as_completed
 from contextlib import nullcontext
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Union, Tuple
+from typing import Dict, List, NamedTuple, Optional, TYPE_CHECKING

 import torch
-import torch.nn as nn

 from ..looper.dequantize_processor import DequantizeProcessor
 from ..looper.eora_processor import EoraProcessor
@@ -121,6 +120,10 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
         if not quant_devices:
             quant_devices = [CPU]

+        self._quant_devices = quant_devices
+        self._quant_device_rr = 0
+        self._module_device_map: Dict[str, torch.device] = {}
+        self._quant_device_lock = threading.Lock()
         vram_strategy = getattr(self.gptq_model.quantize_config, "vram_strategy", VramStrategy.EXCLUSIVE)
         if isinstance(vram_strategy, str):
             try:
@@ -139,27 +142,8 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
                 vram_strategy = VramStrategy.EXCLUSIVE
         self._vram_strategy = vram_strategy

-        # Apply compute device filter if provided to determine which devices to use for quantization
-        compute_device_filter = getattr(self.gptq_model.quantize_config, "compute_device_filter", None)
-        if compute_device_filter is not None:
-            quant_devices_filtered = compute_device_filter(quant_devices)
-            if len(quant_devices_filtered) >= 1:
-                quant_devices = quant_devices_filtered
-            else:
-                log.warn(
-                    "compute_device_filter returned empty device list. "
-                    "Using all devices for quantization."
-                )
-
-        self._quant_devices = quant_devices
-        self._quant_device_rr = 0
-        self._module_device_map: Dict[str, torch.device] = {}
-        self._quant_device_lock = threading.Lock()
         self._moe_subset_threshold = 16
         self._subset_callback = getattr(self.gptq_model, "subset_callback", None)
-
-        # Track current subset for MoE lifecycle hooks
-        self._current_subset: Optional[Dict[str, Any]] = None

         # moe_routing_override is only required for MoE models (i.e., models with dynamic_expert_index).
         if getattr(self.gptq_model, "dynamic_expert_index", None):
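Note: the four fields introduced above (`_quant_devices`, `_quant_device_rr`, `_module_device_map`, `_quant_device_lock`) back the per-module device assignment performed later in `_assign_quant_device_for_module`, whose body is not part of this diff. A minimal, illustrative sketch of the sticky round-robin pattern those names suggest (an assumption, not the project's actual implementation):

import threading
from typing import Dict, List

import torch

class RoundRobinDeviceAssigner:
    """Illustrative only: lock-protected, sticky round-robin device assignment."""

    def __init__(self, quant_devices: List[torch.device]):
        self._quant_devices = quant_devices
        self._quant_device_rr = 0                               # next round-robin slot
        self._module_device_map: Dict[str, torch.device] = {}   # module name -> chosen device
        self._quant_device_lock = threading.Lock()              # guards the two fields above

    def assign(self, module_name: str) -> torch.device:
        with self._quant_device_lock:
            device = self._module_device_map.get(module_name)
            if device is None:
                # Hand out devices in a cycle; remember the choice so the same
                # module always lands on the same device.
                device = self._quant_devices[self._quant_device_rr % len(self._quant_devices)]
                self._quant_device_rr += 1
                self._module_device_map[module_name] = device
            return device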
@@ -199,57 +183,6 @@ def __exit__(self, exc_type, exc, tb):
             restore_moe_topk(self._state)
             return False  # Do not suppress exceptions

-    class MoELifecycleContext:
-        """Context manager for MoE lifecycle hooks integration."""
-
-        def __init__(self, module_looper, module, processor, current_subset):
-            self.module_looper = module_looper
-            self.module = module
-            self.processor = processor
-            self.current_subset = current_subset
-            self.moe_hooks_active = False
-            self.moe_block = None
-            self.moe_forward_original = None
-
-        def __enter__(self):
-            """Set up MoE lifecycle hooks if applicable."""
-            if self.module_looper._should_use_moe_lifecycle(self.module, self.processor):
-                hooks = self.module_looper.gptq_model.moe_lifecycle_hooks
-                self.moe_block = hooks.get_moe_block(self.module, self.module_looper.gptq_model.__class__)
-
-                if self.moe_block is not None:
-                    # Save original forward method
-                    self.moe_forward_original = self.moe_block.forward
-
-                    # Create wrapper that forwards to all experts
-                    moe_block_prefix = hooks._extract_moe_block_prefix(self.current_subset, self.moe_block)
-
-                    def moe_forward_wrapper(hidden_states, **kwargs):
-                        return hooks.forward_to_all_experts(
-                            moe_block=self.moe_block,
-                            hidden_states=hidden_states,
-                            processor=self.processor,
-                            subset=self.current_subset,
-                            original_forward=self.moe_forward_original,
-                            model_class=self.module_looper.gptq_model.__class__,
-                            module_looper=self.module_looper,  # Pass for TLS-based hooks pausing
-                            moe_block_prefix=moe_block_prefix,
-                            replica_module=self.module,  # Pass replica for device-correct module resolution
-                            **kwargs
-                        )
-
-                    # Temporarily replace forward method
-                    self.moe_block.forward = moe_forward_wrapper
-                    self.moe_hooks_active = True
-
-            return self
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            """Restore original MoE forward method if it was patched."""
-            if self.moe_hooks_active and self.moe_forward_original is not None and self.moe_block is not None:
-                self.moe_block.forward = self.moe_forward_original
-            return False  # Don't suppress exceptions
-
     def register_layer_callback(self, callback) -> None:
         """Register or replace the layer-complete callback target."""
         self._layer_callback = callback
@@ -422,22 +355,6 @@ def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local:
         setattr(processor, "_mask_tls", tls)
         return tls

-    def _processor_hooks_paused_tls(self, processor):
-        """Get or create thread-local storage for hooks_paused flag."""
-        if not hasattr(processor, "_hooks_paused_tls"):
-            processor._hooks_paused_tls = threading.local()
-        return processor._hooks_paused_tls
-
-    def _set_processor_hooks_paused(self, processor: LoopProcessor, paused: bool):
-        """Set hooks paused state for current thread."""
-        tls = self._processor_hooks_paused_tls(processor)
-        tls.value = paused
-
-    def _get_processor_hooks_paused(self, processor: LoopProcessor) -> bool:
-        """Get hooks paused state for current thread (thread-safe)."""
-        tls = getattr(processor, "_hooks_paused_tls", None)
-        return getattr(tls, "value", False) if tls else False
-
     def _set_processor_mask(self, processor: LoopProcessor, mask):
         tls = self._processor_mask_tls(processor)
         tls.value = mask
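The surviving mask helpers keep the keep-mask in a `threading.local` slot attached to the processor, so each forward thread only ever sees the mask it set itself. A standalone sketch of that pattern (simplified, not the project's exact code):

import threading

def set_mask(processor, mask):
    # Lazily attach one threading.local per processor; .value is per-thread.
    if not hasattr(processor, "_mask_tls"):
        processor._mask_tls = threading.local()
    processor._mask_tls.value = mask

def get_mask(processor):
    # Returns None if this thread never set a mask on this processor.
    tls = getattr(processor, "_mask_tls", None)
    return getattr(tls, "value", None) if tls is not None else None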
@@ -543,10 +460,6 @@ def _extract_moe_group_key(self, module_name: Optional[str]) -> Optional[str]:
             prefix, _ = module_name.split(".shared_experts.", 1)
             return f"{prefix}.shared_experts"

-        if ".shared_expert." in module_name:
-            prefix, _ = module_name.split(".shared_expert.", 1)
-            return f"{prefix}.shared_expert"
-
         return None

     def _is_attention_module_name(self, module_name: str) -> bool:
@@ -561,43 +474,6 @@ def _is_attention_module_name(self, module_name: str) -> bool:
         if lowered.endswith("attn") or lowered.endswith("attention"):
             return True
         return False
-
-    def _should_use_moe_lifecycle(self, module: nn.Module, processor: LoopProcessor) -> bool:
-        """
-        Check if MoE lifecycle hooks should be used for this module.
-
-        Returns True if:
-        - pass_whole_dataset_to_each_expert flag is enabled
-        - Model has lifecycle hooks configured
-        - Module contains an MoE block
-        """
-        # Check if feature is enabled
-        moe_routing_bypass = getattr(self.gptq_model.quantize_config, "moe_routing_bypass", None)
-        flag_enabled = moe_routing_bypass() if callable(moe_routing_bypass) else False
-        if not flag_enabled:
-            return False
-
-        # Check if model has lifecycle hooks
-        hooks = getattr(self.gptq_model, 'moe_lifecycle_hooks', None)
-        if hooks is None:
-            log.warn(
-                f"pass_whole_dataset_to_each_expert is enabled but {self.gptq_model.__class__.__name__} "
-                f"model does not have 'moe_lifecycle_hooks' configured. MoE optimization will be disabled. "
-                f"Please ensure your model definition has proper MoE lifecycle hooks configured."
-            )
-            return False
-
-        # Check if this module contains an MoE block
-        moe_block = hooks.get_moe_block(module, self.gptq_model.__class__)
-        if moe_block is None:
-            log.warn(
-                f"pass_whole_dataset_to_each_expert is enabled but no MoE block found in module "
-                f"{module.__class__.__name__}. MoE optimization will be disabled for this module. "
-                f"This may indicate an issue with the model's MoE configuration or module structure."
-            )
-            return False
-
-        return True

     def _assign_quant_device_for_module(
         self,
@@ -910,8 +786,6 @@ def _run_forward_batches_single(
             if attn_tensor is not None:
                 seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None
                 keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len)
-
-            # Set mask using TLS (thread-safe)
             self._set_processor_mask(processor, keep_mask)

             additional_inputs: Dict[str, torch.Tensor] = {}
@@ -938,8 +812,8 @@ def _run_forward_batches_single(
             if not preserve_module_devices:
                 rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)

-            # MoE lifecycle hooks integration - using context manager
-            with self.MoERoutingOverrideContext(module, self.moe_routing_override) if self.moe_routing_override else self.MoELifecycleContext(self, module, processor, self._current_subset):
+            # MoE routing override integration - no-op context when no override is configured
+            with self.MoERoutingOverrideContext(module, self.moe_routing_override) if self.moe_routing_override else nullcontext():
                 module_output = None
                 try:
                     if is_lm_head_module:
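With the MoE lifecycle path gone, the non-override branch only needs a do-nothing context manager, which is what `contextlib.nullcontext()` provides; note that it has to be called (an instance, not the bare class) to satisfy the context-manager protocol. A minimal sketch of the pattern, with hypothetical names:

from contextlib import nullcontext

def run_with_optional_override(module_fn, override_ctx=None):
    # Use the real context manager when an override is supplied, a no-op otherwise.
    with override_ctx if override_ctx is not None else nullcontext():
        return module_fn()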
@@ -951,12 +825,6 @@ def _run_forward_batches_single(
                 finally:
                     self._set_processor_mask(processor, None)

-            # Release intermediate tensors promptly after they are no longer needed
-            del layer_input
-            del attn_tensor
-            del keep_mask
-            del additional_inputs
-
             if (
                 reuse_kv
                 and module_output is not None
@@ -971,10 +839,6 @@ def _run_forward_batches_single(
                 primary = move_to(primary, device=cur_layer_device)
             outputs.append([primary])

-            # Release module_output promptly after extracting what we need
-            if module_output is not None:
-                del module_output
-
             rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
             if rows_for_batch <= 0:
                 rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1
@@ -1077,8 +941,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->

         # Ensure any async replication/memcpy ops are complete before threads start fanning out.
         torch_sync()
-
-        # Clone modules FIRST, then apply MoE lifecycle hooks to all replicas
+
         try:
             module_replicas = clone_module_for_devices(
                 module,
@@ -1101,8 +964,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
             ctx = None
             if self.moe_routing_override:
                 ctx = self.MoERoutingOverrideContext(replica, self.moe_routing_override)
-            elif self._should_use_moe_lifecycle(module, processor):
-                ctx = self.MoELifecycleContext(self, replica, processor, self._current_subset)
+            else:
+                ctx = nullcontext()

             if ctx:
                 ctx.__enter__()
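Because the replica path enters the context manually (`ctx.__enter__()` above, with the matching `__exit__` presumably in a later cleanup block not shown in this diff), the `else` branch also needs an instance: assigning the bare `nullcontext` class would make `ctx.__enter__()` fail. A small sketch of the manual protocol under that assumption:

from contextlib import nullcontext

ctx = nullcontext()              # instance, so __enter__/__exit__ can be called on it
ctx.__enter__()
try:
    pass                         # the replica forward would run here
finally:
    ctx.__exit__(None, None, None)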
@@ -1221,7 +1084,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
         # ensure replicas release promptly and free GPU memory
         for dev in list(module_replicas.keys()):
             del module_replicas[dev]
-
+
         if not need_outputs:
             return []

@@ -1243,10 +1106,6 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->

     def _masked_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_source: str):
         def hook(module, inputs, output):
-            # Thread-safe check if hooks are paused (TLS-based, per-thread)
-            if self._get_processor_hooks_paused(processor):
-                return
-
             keep = self._get_processor_mask(processor)

             timer = getattr(self.gptq_model, "quant_region_timer", None)
@@ -1292,51 +1151,6 @@ def hook(module, inputs, output):
                         source=hook_source,
                     )
         return hook
-
-    def _masked_pre_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_source: str):
-        """
-        Pre-forward hook wrapper for MoE expert modules.
-        This is called BEFORE forward executes (when used with HookedLinear.forward_hook).
-        Respects hooks_paused state to avoid double-counting during intermediate calculations.
-        """
-        def pre_hook(module, inputs, output):
-            # Thread-safe check if hooks are paused (TLS-based, per-thread)
-            if self._get_processor_hooks_paused(processor):
-                return
-
-            # Get mask using TLS (thread-safe)
-            keep = self._get_processor_mask(processor)
-
-            timer = getattr(self.gptq_model, "quant_region_timer", None)
-            start = time.perf_counter() if timer else None
-
-            # Mask first tensor-like input if it's [B, S, ...]
-            new_inputs = inputs
-            try:
-                if isinstance(inputs, (tuple, list)) and len(inputs) > 0 and torch.is_tensor(inputs[0]):
-                    x = inputs[0]
-                    if keep is not None and x.dim() >= 3:
-                        xk = apply_keep_mask_bt(x, keep)
-                        if isinstance(inputs, tuple):
-                            new_inputs = (xk,) + tuple(inputs[1:])
-                        else:
-                            new_inputs = [xk] + list(inputs[1:])
-            except Exception:
-                # Never break the forward due to masking
-                new_inputs = inputs
-
-            # Call inner hook with inputs and output (GPTQ ignores output anyway)
-            try:
-                inner_hook(module, new_inputs, output)
-            finally:
-                if timer is not None and start is not None:
-                    timer.record(
-                        "forward_pre_hook",
-                        time.perf_counter() - start,
-                        source=hook_source,
-                    )
-
-        return pre_hook

     def cache_inputs(self, layers, calibration_data, use_cache):
         capture_stage = StageInputsCapture(self, logger=log)