From 0281626c189b937740569ca5d5b4cef039d32c88 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 2 Mar 2026 16:58:06 -0600 Subject: [PATCH 1/3] Fixed bug where stale joint QKV is being used instead of the correct split weights --- .../test_refactor_factored_attn_matrices.py | 103 ++++++++++++++++++ .../joint_qkv_attention.py | 13 +++ .../supported_architectures/gpt2.py | 50 +++++++-- 3 files changed, 159 insertions(+), 7 deletions(-) create mode 100644 tests/integration/model_bridge/test_refactor_factored_attn_matrices.py diff --git a/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py new file mode 100644 index 000000000..522920d64 --- /dev/null +++ b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py @@ -0,0 +1,103 @@ +"""Test refactor_factored_attn_matrices with TransformerBridge. + +Verifies that the refactored attention matrices produce correct results when +used via TransformerBridge, matching HookedTransformer output. +""" + +import pytest +import torch + +from transformer_lens import HookedTransformer +from transformer_lens.model_bridge import TransformerBridge + + +@pytest.fixture(scope="module") +def model_name(): + return "distilgpt2" + + +@pytest.fixture(scope="module") +def device(): + return "cpu" + + +@pytest.fixture(scope="module") +def test_text(): + return "Natural language processing" + + +@pytest.fixture(scope="module") +def reference_ht(model_name, device): + """HookedTransformer with refactor_factored_attn_matrices=True.""" + return HookedTransformer.from_pretrained( + model_name, + device=device, + refactor_factored_attn_matrices=True, + ) + + +def test_refactor_factored_attn_matrices_loss_matches( + model_name, device, test_text, reference_ht +): + """Bridge with refactor_factored_attn_matrices should match HookedTransformer.""" + ref_loss = reference_ht(test_text, return_type="loss") + + bridge = TransformerBridge.boot_transformers(model_name, device=device) + bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True) + bridge_loss = bridge(test_text, return_type="loss") + + assert not torch.isnan(bridge_loss), "Bridge produced NaN loss" + assert not torch.isinf(bridge_loss), "Bridge produced infinite loss" + + loss_diff = abs(bridge_loss.item() - ref_loss.item()) + assert loss_diff < 1.0, ( + f"Loss difference too large: {loss_diff:.6f} " + f"(bridge={bridge_loss.item():.4f}, reference={ref_loss.item():.4f})" + ) + + +def test_refactor_factored_attn_matrices_logits_match( + model_name, device, test_text, reference_ht +): + """Bridge logits should closely match HookedTransformer logits after refactoring.""" + tokens = reference_ht.to_tokens(test_text) + ref_logits = reference_ht(tokens) + + bridge = TransformerBridge.boot_transformers(model_name, device=device) + bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True) + bridge_logits = bridge(tokens) + + # Check shapes match + assert ref_logits.shape == bridge_logits.shape, ( + f"Shape mismatch: ref={ref_logits.shape}, bridge={bridge_logits.shape}" + ) + + # Check values are close + max_diff = (ref_logits - bridge_logits).abs().max().item() + assert max_diff < 1.0, f"Max logit difference too large: {max_diff:.6f}" + + +def test_refactor_preserves_fold_ln(model_name, device, test_text): + """Refactoring should not undo fold_ln — both should be applied together.""" + # Reference: fold_ln=True + refactor=True + ref = HookedTransformer.from_pretrained( + model_name, + device=device, + fold_ln=True, + refactor_factored_attn_matrices=True, + ) + ref_loss = ref(test_text, return_type="loss") + + # Bridge: same settings + bridge = TransformerBridge.boot_transformers(model_name, device=device) + bridge.enable_compatibility_mode( + fold_ln=True, + refactor_factored_attn_matrices=True, + ) + bridge_loss = bridge(test_text, return_type="loss") + + loss_diff = abs(bridge_loss.item() - ref_loss.item()) + assert loss_diff < 1.0, ( + f"fold_ln + refactor mismatch: {loss_diff:.6f} " + f"(bridge={bridge_loss.item():.4f}, ref={ref_loss.item():.4f})" + ) diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index ce2eff8d8..f88c2dc2b 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -105,6 +105,19 @@ def __init__( self._reference_model: Optional[Any] = None self._layer_idx: Optional[int] = None + def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Return state dict excluding stale combined QKV entries. + + After splitting, the q/k/v LinearBridges hold the authoritative weights. + The original qkv LinearBridge remains registered in _modules (so + self.qkv is still accessible) but its parameters are stale copies of + the pre-split combined weight. Excluding them prevents weight processing + steps from accidentally reading unprocessed combined weights. + """ + sd = super().state_dict(*args, **kwargs) + qkv_prefix = "qkv." + return {k: v for k, v in sd.items() if not k.startswith(qkv_prefix)} + def _create_qkv_conversion_rule(self) -> BaseTensorConversion: """Create the appropriate conversion rule for the individual q, k, and v matrices. diff --git a/transformer_lens/model_bridge/supported_architectures/gpt2.py b/transformer_lens/model_bridge/supported_architectures/gpt2.py index 88b629ac7..a6faa7fcc 100644 --- a/transformer_lens/model_bridge/supported_architectures/gpt2.py +++ b/transformer_lens/model_bridge/supported_architectures/gpt2.py @@ -27,14 +27,21 @@ class QKVSplitRearrangeConversion(BaseTensorConversion): - """Custom conversion that splits QKV tensor and then rearranges.""" + """Custom conversion that splits QKV tensor and then rearranges. + + Handles two input formats: + - Combined QKV tensor (from HuggingFace): one dimension is ~3x the other. + Splits into Q/K/V parts, then rearranges to TL format. + - Already-split tensor (from bridge state dict): nn.Linear format + [n_heads*d_head, d_model]. Rearranges directly to TL format. + """ def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths): """Initialize the conversion. Args: qkv_index: Index of Q (0), K (1), or V (2) in the QKV tensor - rearrange_pattern: Einops pattern for rearrangement + rearrange_pattern: Einops pattern for rearrangement (Conv1D format) **axes_lengths: Additional axes lengths for einops """ super().__init__() @@ -42,25 +49,54 @@ def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths): self.rearrange_pattern = rearrange_pattern self.axes_lengths = axes_lengths + def _is_combined_qkv(self, tensor: torch.Tensor) -> bool: + """Check if a tensor is a combined QKV tensor vs already-split.""" + if tensor.ndim == 2: + d0, d1 = tensor.shape + return d1 > d0 * 2 or d0 > d1 * 2 + if tensor.ndim == 1: + n = self.axes_lengths.get("n", 1) + # Combined bias has 3x the expected individual size + return tensor.shape[0] % 3 == 0 and tensor.shape[0] > n * 3 + return False + def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor: """Split QKV tensor and rearrange the selected part.""" - # Determine the split dimension based on tensor shape + if not self._is_combined_qkv(input_value): + # Already-split tensor in nn.Linear format [n_heads*d_head, d_model]. + # The original rearrange_pattern is "d_model (n h) -> n d_model h" + # (Conv1D format). For nn.Linear format, the dims are transposed: + return einops.rearrange( + input_value, "(n h) d_model -> n d_model h", **self.axes_lengths + ) + + # Combined QKV tensor — split then rearrange if len(input_value.shape) == 2: # Weight tensor: [d_model, 3*d_model] -> split along dim=1 - split_dim = 1 + split_dim = 1 if input_value.shape[1] > input_value.shape[0] else 0 elif len(input_value.shape) == 1: # Bias tensor: [3*n_heads*d_head] -> split along dim=0 split_dim = 0 else: raise ValueError(f"Unexpected tensor shape: {input_value.shape}") - # Split the QKV tensor qkv_parts = torch.tensor_split(input_value, 3, dim=split_dim) selected_part = qkv_parts[self.qkv_index] - - # Apply rearrangement return einops.rearrange(selected_part, self.rearrange_pattern, **self.axes_lengths) + def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor: + """Revert from TL format [n_heads, d_model, d_head] to nn.Linear format.""" + if input_value.ndim == 3: + return einops.rearrange( + input_value, "n d_model h -> (n h) d_model", **self.axes_lengths + ) + if input_value.ndim == 2: + # Bias in TL format [n_heads, d_head] -> [n_heads*d_head] + return einops.rearrange( + input_value, "n h -> (n h)", **self.axes_lengths + ) + return input_value + class GPT2ArchitectureAdapter(ArchitectureAdapter): """Architecture adapter for GPT2 models. From 2f5987e1490270779b52d921d6597e55982d74bf Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 2 Mar 2026 17:35:22 -0600 Subject: [PATCH 2/3] Format fixes --- .../test_refactor_factored_attn_matrices.py | 14 +++++--------- .../generalized_components/joint_qkv_attention.py | 11 +++++++++-- .../model_bridge/supported_architectures/gpt2.py | 4 +--- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py index 522920d64..4963ea650 100644 --- a/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py +++ b/tests/integration/model_bridge/test_refactor_factored_attn_matrices.py @@ -36,9 +36,7 @@ def reference_ht(model_name, device): ) -def test_refactor_factored_attn_matrices_loss_matches( - model_name, device, test_text, reference_ht -): +def test_refactor_factored_attn_matrices_loss_matches(model_name, device, test_text, reference_ht): """Bridge with refactor_factored_attn_matrices should match HookedTransformer.""" ref_loss = reference_ht(test_text, return_type="loss") @@ -56,9 +54,7 @@ def test_refactor_factored_attn_matrices_loss_matches( ) -def test_refactor_factored_attn_matrices_logits_match( - model_name, device, test_text, reference_ht -): +def test_refactor_factored_attn_matrices_logits_match(model_name, device, test_text, reference_ht): """Bridge logits should closely match HookedTransformer logits after refactoring.""" tokens = reference_ht.to_tokens(test_text) ref_logits = reference_ht(tokens) @@ -68,9 +64,9 @@ def test_refactor_factored_attn_matrices_logits_match( bridge_logits = bridge(tokens) # Check shapes match - assert ref_logits.shape == bridge_logits.shape, ( - f"Shape mismatch: ref={ref_logits.shape}, bridge={bridge_logits.shape}" - ) + assert ( + ref_logits.shape == bridge_logits.shape + ), f"Shape mismatch: ref={ref_logits.shape}, bridge={bridge_logits.shape}" # Check values are close max_diff = (ref_logits - bridge_logits).abs().max().item() diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index f88c2dc2b..ab062ca04 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -115,8 +115,15 @@ def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: steps from accidentally reading unprocessed combined weights. """ sd = super().state_dict(*args, **kwargs) - qkv_prefix = "qkv." - return {k: v for k, v in sd.items() if not k.startswith(qkv_prefix)} + # PyTorch passes a `prefix` kwarg (e.g. "blocks.0.attn.") when + # collecting state dicts from child modules into a shared destination. + # Use it to build the fully-qualified qkv prefix to delete. + prefix = kwargs.get("prefix", "") + qkv_prefix = prefix + "qkv." + keys_to_remove = [k for k in sd if k.startswith(qkv_prefix)] + for k in keys_to_remove: + del sd[k] + return sd def _create_qkv_conversion_rule(self) -> BaseTensorConversion: """Create the appropriate conversion rule for the individual q, k, and v matrices. diff --git a/transformer_lens/model_bridge/supported_architectures/gpt2.py b/transformer_lens/model_bridge/supported_architectures/gpt2.py index a6faa7fcc..2fb5acc41 100644 --- a/transformer_lens/model_bridge/supported_architectures/gpt2.py +++ b/transformer_lens/model_bridge/supported_architectures/gpt2.py @@ -92,9 +92,7 @@ def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor: ) if input_value.ndim == 2: # Bias in TL format [n_heads, d_head] -> [n_heads*d_head] - return einops.rearrange( - input_value, "n h -> (n h)", **self.axes_lengths - ) + return einops.rearrange(input_value, "n h -> (n h)", **self.axes_lengths) return input_value From bc36c853b4dffdb84d1a16a9724176a86fbc8ec2 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Mon, 2 Mar 2026 19:24:15 -0600 Subject: [PATCH 3/3] Fixing typing issues --- .../joint_qkv_attention.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index ab062ca04..3266920aa 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -105,25 +105,25 @@ def __init__( self._reference_model: Optional[Any] = None self._layer_idx: Optional[int] = None - def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: - """Return state dict excluding stale combined QKV entries. - - After splitting, the q/k/v LinearBridges hold the authoritative weights. - The original qkv LinearBridge remains registered in _modules (so - self.qkv is still accessible) but its parameters are stale copies of - the pre-split combined weight. Excluding them prevents weight processing - steps from accidentally reading unprocessed combined weights. - """ - sd = super().state_dict(*args, **kwargs) - # PyTorch passes a `prefix` kwarg (e.g. "blocks.0.attn.") when - # collecting state dicts from child modules into a shared destination. - # Use it to build the fully-qualified qkv prefix to delete. - prefix = kwargs.get("prefix", "") + # After splitting, the q/k/v LinearBridges hold the authoritative weights. + # The original qkv LinearBridge remains registered in _modules (so + # self.qkv is still accessible) but its parameters are stale copies of + # the pre-split combined weight. This hook excludes them from state_dict + # so weight processing steps never read unprocessed combined weights. + self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict) + + @staticmethod + def _filter_qkv_state_dict( + module: torch.nn.Module, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + ) -> None: + """State dict hook that removes stale combined QKV entries.""" qkv_prefix = prefix + "qkv." - keys_to_remove = [k for k in sd if k.startswith(qkv_prefix)] + keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)] for k in keys_to_remove: - del sd[k] - return sd + del state_dict[k] def _create_qkv_conversion_rule(self) -> BaseTensorConversion: """Create the appropriate conversion rule for the individual q, k, and v matrices.