
Commit 2f9ad8a

Revert "Feature: MoE.Routing control (Bypass or Override) (#2235)"
This reverts commit fc0de11.
1 parent 9683c41 · commit 2f9ad8a

39 files changed

Lines changed: 74 additions & 1759 deletions

gptqmodel/looper/loop_processor.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def __init__(
 
         self.inputs_cache: InputCache = InputCache(None, None, None, None)
         self.tasks = {}
-
+
         self.pb = None
         self.fwd_time = None
         self.layer_count = None

gptqmodel/looper/module_looper.py

Lines changed: 11 additions & 197 deletions
@@ -21,10 +21,9 @@
 import logging
 from concurrent.futures import as_completed
 from contextlib import nullcontext
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING, Union, Tuple
+from typing import Dict, List, NamedTuple, Optional, TYPE_CHECKING
 
 import torch
-import torch.nn as nn
 
 from ..looper.dequantize_processor import DequantizeProcessor
 from ..looper.eora_processor import EoraProcessor
@@ -121,6 +120,10 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
         if not quant_devices:
             quant_devices = [CPU]
 
+        self._quant_devices = quant_devices
+        self._quant_device_rr = 0
+        self._module_device_map: Dict[str, torch.device] = {}
+        self._quant_device_lock = threading.Lock()
         vram_strategy = getattr(self.gptq_model.quantize_config, "vram_strategy", VramStrategy.EXCLUSIVE)
         if isinstance(vram_strategy, str):
             try:
@@ -139,27 +142,8 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]):
             vram_strategy = VramStrategy.EXCLUSIVE
         self._vram_strategy = vram_strategy
 
-        # Apply compute device filter if provided to determine which devices to use for quantization
-        compute_device_filter = getattr(self.gptq_model.quantize_config, "compute_device_filter", None)
-        if compute_device_filter is not None:
-            quant_devices_filtered = compute_device_filter(quant_devices)
-            if len(quant_devices_filtered) >= 1:
-                quant_devices = quant_devices_filtered
-            else:
-                log.warn(
-                    "compute_device_filter returned empty device list. "
-                    "Using all devices for quantization."
-                )
-
-        self._quant_devices = quant_devices
-        self._quant_device_rr = 0
-        self._module_device_map: Dict[str, torch.device] = {}
-        self._quant_device_lock = threading.Lock()
         self._moe_subset_threshold = 16
         self._subset_callback = getattr(self.gptq_model, "subset_callback", None)
-
-        # Track current subset for MoE lifecycle hooks
-        self._current_subset: Optional[Dict[str, Any]] = None
 
         # moe_routing_override is only required for MoE models (i.e., models with dynamic_expert_index).
         if getattr(self.gptq_model, "dynamic_expert_index", None):
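Note: the fields restored above (`_quant_devices`, `_quant_device_rr`, `_module_device_map`, `_quant_device_lock`) support sticky, round-robin assignment of modules to quantization devices. A minimal sketch of that pattern, assuming only the semantics implied by the field names; `RoundRobinAssigner` and the device strings are illustrative, not the library's API:

    import threading

    class RoundRobinAssigner:
        def __init__(self, devices):
            self._devices = list(devices)
            self._rr = 0        # index of the next device to hand out
            self._map = {}      # module name -> device (sticky mapping)
            self._lock = threading.Lock()

        def assign(self, module_name: str) -> str:
            # Thread-safe: multiple quantization workers may call this concurrently.
            with self._lock:
                if module_name not in self._map:
                    self._map[module_name] = self._devices[self._rr % len(self._devices)]
                    self._rr += 1
                return self._map[module_name]

    assigner = RoundRobinAssigner(["cuda:0", "cuda:1"])
    print(assigner.assign("layers.0.mlp"))   # cuda:0
    print(assigner.assign("layers.1.mlp"))   # cuda:1
    print(assigner.assign("layers.0.mlp"))   # cuda:0 again (sticky)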
@@ -199,57 +183,6 @@ def __exit__(self, exc_type, exc, tb):
             restore_moe_topk(self._state)
             return False  # Do not suppress exceptions
 
-    class MoELifecycleContext:
-        """Context manager for MoE lifecycle hooks integration."""
-
-        def __init__(self, module_looper, module, processor, current_subset):
-            self.module_looper = module_looper
-            self.module = module
-            self.processor = processor
-            self.current_subset = current_subset
-            self.moe_hooks_active = False
-            self.moe_block = None
-            self.moe_forward_original = None
-
-        def __enter__(self):
-            """Set up MoE lifecycle hooks if applicable."""
-            if self.module_looper._should_use_moe_lifecycle(self.module, self.processor):
-                hooks = self.module_looper.gptq_model.moe_lifecycle_hooks
-                self.moe_block = hooks.get_moe_block(self.module, self.module_looper.gptq_model.__class__)
-
-                if self.moe_block is not None:
-                    # Save original forward method
-                    self.moe_forward_original = self.moe_block.forward
-
-                    # Create wrapper that forwards to all experts
-                    moe_block_prefix = hooks._extract_moe_block_prefix(self.current_subset, self.moe_block)
-
-                    def moe_forward_wrapper(hidden_states, **kwargs):
-                        return hooks.forward_to_all_experts(
-                            moe_block=self.moe_block,
-                            hidden_states=hidden_states,
-                            processor=self.processor,
-                            subset=self.current_subset,
-                            original_forward=self.moe_forward_original,
-                            model_class=self.module_looper.gptq_model.__class__,
-                            module_looper=self.module_looper,  # Pass for TLS-based hooks pausing
-                            moe_block_prefix=moe_block_prefix,
-                            replica_module=self.module,  # Pass replica for device-correct module resolution
-                            **kwargs
-                        )
-
-                    # Temporarily replace forward method
-                    self.moe_block.forward = moe_forward_wrapper
-                    self.moe_hooks_active = True
-
-            return self
-
-        def __exit__(self, exc_type, exc_val, exc_tb):
-            """Restore original MoE forward method if it was patched."""
-            if self.moe_hooks_active and self.moe_forward_original is not None and self.moe_block is not None:
-                self.moe_block.forward = self.moe_forward_original
-            return False  # Don't suppress exceptions
-
     def register_layer_callback(self, callback) -> None:
         """Register or replace the layer-complete callback target."""
         self._layer_callback = callback
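Note: the removed MoELifecycleContext is an instance of a common pattern: temporarily swapping a module's bound `forward` for a wrapper inside a `with` block and restoring it on exit, even on exceptions. A minimal self-contained sketch of that pattern; `PatchForward` and `announce` are hypothetical stand-ins for the reverted class and its `forward_to_all_experts` wrapper:

    import torch
    import torch.nn as nn

    class PatchForward:
        def __init__(self, module: nn.Module, wrapper_factory):
            self.module = module
            self.wrapper_factory = wrapper_factory
            self._original = None

        def __enter__(self):
            # Save the bound forward, then shadow it with an instance attribute.
            self._original = self.module.forward
            self.module.forward = self.wrapper_factory(self._original)
            return self

        def __exit__(self, exc_type, exc, tb):
            # Always restore the original forward, even if the body raised.
            if self._original is not None:
                self.module.forward = self._original
            return False  # never suppress exceptions

    def announce(original):
        def wrapped(*args, **kwargs):
            print("patched forward called")
            return original(*args, **kwargs)
        return wrapped

    layer = nn.Linear(4, 4)
    with PatchForward(layer, announce):
        layer(torch.randn(1, 4))   # prints, then runs the real forward
    layer(torch.randn(1, 4))       # original behavior restored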
@@ -422,22 +355,6 @@ def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local:
             setattr(processor, "_mask_tls", tls)
         return tls
 
-    def _processor_hooks_paused_tls(self, processor):
-        """Get or create thread-local storage for hooks_paused flag."""
-        if not hasattr(processor, "_hooks_paused_tls"):
-            processor._hooks_paused_tls = threading.local()
-        return processor._hooks_paused_tls
-
-    def _set_processor_hooks_paused(self, processor: LoopProcessor, paused: bool):
-        """Set hooks paused state for current thread."""
-        tls = self._processor_hooks_paused_tls(processor)
-        tls.value = paused
-
-    def _get_processor_hooks_paused(self, processor: LoopProcessor) -> bool:
-        """Get hooks paused state for current thread (thread-safe)."""
-        tls = getattr(processor, "_hooks_paused_tls", None)
-        return getattr(tls, "value", False) if tls else False
-
     def _set_processor_mask(self, processor: LoopProcessor, mask):
         tls = self._processor_mask_tls(processor)
         tls.value = mask
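Note: the three removed helpers implement a per-thread pause flag with `threading.local`, so one forward-replica thread can suppress its hooks without affecting the others. A self-contained sketch of the same idiom (function names are illustrative):

    import threading

    _paused_tls = threading.local()

    def set_hooks_paused(paused: bool) -> None:
        _paused_tls.value = paused

    def hooks_paused() -> bool:
        # getattr default covers threads that never set the flag
        return getattr(_paused_tls, "value", False)

    def worker(name: str, pause: bool):
        set_hooks_paused(pause)
        print(name, hooks_paused())   # each thread sees only its own value

    t1 = threading.Thread(target=worker, args=("a", True))
    t2 = threading.Thread(target=worker, args=("b", False))
    t1.start(); t2.start(); t1.join(); t2.join()
    print("main", hooks_paused())     # False: main thread never set it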
@@ -543,10 +460,6 @@ def _extract_moe_group_key(self, module_name: Optional[str]) -> Optional[str]:
             prefix, _ = module_name.split(".shared_experts.", 1)
             return f"{prefix}.shared_experts"
 
-        if ".shared_expert." in module_name:
-            prefix, _ = module_name.split(".shared_expert.", 1)
-            return f"{prefix}.shared_expert"
-
         return None
 
     def _is_attention_module_name(self, module_name: str) -> bool:
@@ -561,43 +474,6 @@ def _is_attention_module_name(self, module_name: str) -> bool:
         if lowered.endswith("attn") or lowered.endswith("attention"):
             return True
         return False
-
-    def _should_use_moe_lifecycle(self, module: nn.Module, processor: LoopProcessor) -> bool:
-        """
-        Check if MoE lifecycle hooks should be used for this module.
-
-        Returns True if:
-        - pass_whole_dataset_to_each_expert flag is enabled
-        - Model has lifecycle hooks configured
-        - Module contains an MoE block
-        """
-        # Check if feature is enabled
-        moe_routing_bypass = getattr(self.gptq_model.quantize_config, "moe_routing_bypass", None)
-        flag_enabled = moe_routing_bypass() if callable(moe_routing_bypass) else False
-        if not flag_enabled:
-            return False
-
-        # Check if model has lifecycle hooks
-        hooks = getattr(self.gptq_model, 'moe_lifecycle_hooks', None)
-        if hooks is None:
-            log.warn(
-                f"pass_whole_dataset_to_each_expert is enabled but {self.gptq_model.__class__.__name__} "
-                f"model does not have 'moe_lifecycle_hooks' configured. MoE optimization will be disabled. "
-                f"Please ensure your model definition has proper MoE lifecycle hooks configured."
-            )
-            return False
-
-        # Check if this module contains an MoE block
-        moe_block = hooks.get_moe_block(module, self.gptq_model.__class__)
-        if moe_block is None:
-            log.warn(
-                f"pass_whole_dataset_to_each_expert is enabled but no MoE block found in module "
-                f"{module.__class__.__name__}. MoE optimization will be disabled for this module. "
-                f"This may indicate an issue with the model's MoE configuration or module structure."
-            )
-            return False
-
-        return True
 
     def _assign_quant_device_for_module(
         self,
@@ -910,8 +786,6 @@ def _run_forward_batches_single(
             if attn_tensor is not None:
                 seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None
                 keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len)
-
-            # Set mask using TLS (thread-safe)
             self._set_processor_mask(processor, keep_mask)
 
             additional_inputs: Dict[str, torch.Tensor] = {}
@@ -938,8 +812,8 @@
             if not preserve_module_devices:
                 rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)
 
-            # MoE lifecycle hooks integration - using context manager
-            with self.MoERoutingOverrideContext(module, self.moe_routing_override) if self.moe_routing_override else self.MoELifecycleContext(self, module, processor, self._current_subset):
+            # MoE lifecycle hooks integration - using context manager
+            with self.MoERoutingOverrideContext(module, self.moe_routing_override) if self.moe_routing_override else nullcontext:
                 module_output = None
                 try:
                     if is_lm_head_module:
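Note: the restored `with` line selects between a real routing-override context and a no-op. `contextlib.nullcontext` is ordinarily entered as an instance, `nullcontext()`, rather than the bare class used above. A short sketch of the conditional-context idiom; `routing_override` is a hypothetical stand-in for MoERoutingOverrideContext:

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def routing_override():
        # Stand-in for the real override context manager.
        print("override active")
        try:
            yield
        finally:
            print("override restored")

    def run_step(step, use_override: bool):
        # Pick a real context or a no-op; both support the `with` protocol.
        ctx = routing_override() if use_override else nullcontext()
        with ctx:
            step()

    run_step(lambda: print("forward"), use_override=True)
    run_step(lambda: print("forward"), use_override=False)   # no-op context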
@@ -951,12 +825,6 @@
                 finally:
                     self._set_processor_mask(processor, None)
 
-                # Release intermediate tensors promptly after they are no longer needed
-                del layer_input
-                del attn_tensor
-                del keep_mask
-                del additional_inputs
-
                 if (
                     reuse_kv
                     and module_output is not None
@@ -971,10 +839,6 @@
                 primary = move_to(primary, device=cur_layer_device)
                 outputs.append([primary])
 
-                # Release module_output promptly after extracting what we need
-                if module_output is not None:
-                    del module_output
-
                 rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0
                 if rows_for_batch <= 0:
                     rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1
@@ -1077,8 +941,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
 
         # Ensure any async replication/memcpy ops are complete before threads start fanning out.
         torch_sync()
-
-        # Clone modules FIRST, then apply MoE lifecycle hooks to all replicas
+
        try:
            module_replicas = clone_module_for_devices(
                module,
@@ -1101,8 +964,8 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
             ctx = None
             if self.moe_routing_override:
                 ctx = self.MoERoutingOverrideContext(replica, self.moe_routing_override)
-            elif self._should_use_moe_lifecycle(module, processor):
-                ctx = self.MoELifecycleContext(self, replica, processor, self._current_subset)
+            else:
+                ctx = nullcontext
 
             if ctx:
                 ctx.__enter__()
@@ -1221,7 +1084,7 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
         # ensure replicas release promptly and free GPU memory
         for dev in list(module_replicas.keys()):
             del module_replicas[dev]
-
+
         if not need_outputs:
             return []
 
@@ -1243,10 +1106,6 @@ def _replica_progress(idx: int, total: int, device: torch.device, step: str) ->
 
     def _masked_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_source: str):
         def hook(module, inputs, output):
-            # Thread-safe check if hooks are paused (TLS-based, per-thread)
-            if self._get_processor_hooks_paused(processor):
-                return
-
             keep = self._get_processor_mask(processor)
 
             timer = getattr(self.gptq_model, "quant_region_timer", None)
@@ -1292,51 +1151,6 @@ def hook(module, inputs, output):
                         source=hook_source,
                     )
         return hook
-
-    def _masked_pre_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_source: str):
-        """
-        Pre-forward hook wrapper for MoE expert modules.
-        This is called BEFORE forward executes (when used with HookedLinear.forward_hook).
-        Respects hooks_paused state to avoid double-counting during intermediate calculations.
-        """
-        def pre_hook(module, inputs, output):
-            # Thread-safe check if hooks are paused (TLS-based, per-thread)
-            if self._get_processor_hooks_paused(processor):
-                return
-
-            # Get mask using TLS (thread-safe)
-            keep = self._get_processor_mask(processor)
-
-            timer = getattr(self.gptq_model, "quant_region_timer", None)
-            start = time.perf_counter() if timer else None
-
-            # Mask first tensor-like input if it's [B, S, ...]
-            new_inputs = inputs
-            try:
-                if isinstance(inputs, (tuple, list)) and len(inputs) > 0 and torch.is_tensor(inputs[0]):
-                    x = inputs[0]
-                    if keep is not None and x.dim() >= 3:
-                        xk = apply_keep_mask_bt(x, keep)
-                        if isinstance(inputs, tuple):
-                            new_inputs = (xk,) + tuple(inputs[1:])
-                        else:
-                            new_inputs = [xk] + list(inputs[1:])
-            except Exception:
-                # Never break the forward due to masking
-                new_inputs = inputs
-
-            # Call inner hook with inputs and output (GPTQ ignores output anyway)
-            try:
-                inner_hook(module, new_inputs, output)
-            finally:
-                if timer is not None and start is not None:
-                    timer.record(
-                        "forward_pre_hook",
-                        time.perf_counter() - start,
-                        source=hook_source,
-                    )
-
-        return pre_hook
 
     def cache_inputs(self, layers, calibration_data, use_cache):
         capture_stage = StageInputsCapture(self, logger=log)
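Note: the removed `_masked_pre_hook_wrapper` trims padded positions from the first `[batch, seq, ...]` tensor input before delegating to the captured hook. A minimal sketch of that masking step, approximating `apply_keep_mask_bt` with boolean indexing (the real helper and its exact shape handling live elsewhere in the library):

    import torch

    def apply_keep_mask_bt(x: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
        # x: [B, S, H]; keep: [B, S] bool -> surviving positions, flattened to [N, H]
        return x[keep]

    def masked_pre_hook(inner_hook, keep_mask):
        def pre_hook(module, inputs, output):
            new_inputs = inputs
            try:
                if isinstance(inputs, tuple) and inputs and torch.is_tensor(inputs[0]):
                    x = inputs[0]
                    if keep_mask is not None and x.dim() >= 3:
                        new_inputs = (apply_keep_mask_bt(x, keep_mask),) + inputs[1:]
            except Exception:
                new_inputs = inputs  # never break the forward due to masking
            inner_hook(module, new_inputs, output)
        return pre_hook

    keep = torch.tensor([[True, True, False]])       # drop one padded token
    hook = masked_pre_hook(lambda m, i, o: print(i[0].shape), keep)
    hook(None, (torch.randn(1, 3, 8),), None)        # prints torch.Size([2, 8])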

gptqmodel/looper/named_module.py

Lines changed: 0 additions & 37 deletions
@@ -34,10 +34,6 @@ def __init__(self, module: torch.nn.Module, name: str, full_name: str, layer_index
         # persistent work state for named module (used by some LoopProcessors)
         # store all `processed()` work state/data/result here
         self.state = {}
-
-        # Forward hook mechanism (compatible with HookedLinear)
-        self.forward_hook = None
-        self.forward_hook_last = False
 
         # print(f"NamedModule init: name: `{name}, full-name: `{full_name}`")
 
@@ -132,18 +128,6 @@ def __getattr__(self, name: str):
 
    # setattr is always called by python even if attr exists in `self`
    def __setattr__(self, name: str, value: Any) -> None:
-        # Proxy forward_hook to inner module if it supports it (e.g. HookedLinear)
-        if name in ["forward_hook", "forward_hook_last"]:
-            try:
-                module = object.__getattribute__(self, "module")
-                if hasattr(module, name):
-                    setattr(module, name, value)
-            except AttributeError:
-                pass  # module not set yet during __init__
-            # Also set on self for consistency
-            object.__setattr__(self, name, value)
-            return
-
        if name in [
            "module",
            "module_dtype",
@@ -153,7 +137,6 @@ def __setattr__(self, name: str, value: Any) -> None:
            "state",
            "_parent_lock",
            "target_device",
-            "target_device_stream",
            "register_buffer",
            "unregister_buffer",
            "register_parameter",
@@ -172,26 +155,6 @@ def __setattr__(self, name: str, value: Any) -> None:
        else:
            with lock:
                setattr(module, name, value)
-
-    def forward(self, *args, **kwargs):
-        """Forward pass with optional hook support (compatible with HookedLinear)."""
-        output = self.module(*args, **kwargs)
-
-        # Call forward_hook if it exists and wasn't proxied to inner module
-        if self.forward_hook:
-            # Check if inner module has forward_hook (meaning we proxied it)
-            if not hasattr(self.module, 'forward_hook') or self.module.forward_hook is None:
-                # Extract first positional arg as input for hook
-                input_tensor = args[0] if args else None
-                self.forward_hook(self, (input_tensor,), output)
-
-            # If forward_hook_last is True, this should stop execution (like HookedLinear)
-            # The hook may raise StopForward, which should propagate
-            if self.forward_hook_last:
-                from ..nn_modules.hooked_linear import StopForward  # Local import to avoid circular dependency
-                raise StopForward()
-
-        return output
 
    def stream_state_payload_to_cpu(
        self,
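Note: the removed `NamedModule.forward` mirrors the HookedLinear protocol: run the wrapped module, fire an optional `forward_hook(module, inputs, output)`, and raise `StopForward` to cut the forward pass short once the hook has captured what it needs. A self-contained sketch of that protocol; `StopForward` and `Hooked` here are local stand-ins, not the library's classes:

    import torch
    import torch.nn as nn

    class StopForward(Exception):
        """Sentinel: the calibration hook has what it needs; stop the forward."""

    class Hooked(nn.Module):
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner
            self.forward_hook = None        # callable(module, inputs, output)
            self.forward_hook_last = False  # if True, abort after the hook fires

        def forward(self, x):
            out = self.inner(x)
            if self.forward_hook:
                self.forward_hook(self, (x,), out)
                if self.forward_hook_last:
                    raise StopForward()
            return out

    m = Hooked(nn.Linear(4, 4))
    m.forward_hook = lambda mod, inp, out: print("captured", inp[0].shape)
    m.forward_hook_last = True
    try:
        m(torch.randn(2, 4))
    except StopForward:
        print("stopped early after capture")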
