@@ -0,0 +1,99 @@
"""Test refactor_factored_attn_matrices with TransformerBridge.

Verifies that the refactored attention matrices produce correct results when
used via TransformerBridge, matching HookedTransformer output.
"""

import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge


@pytest.fixture(scope="module")
def model_name():
return "distilgpt2"


@pytest.fixture(scope="module")
def device():
return "cpu"


@pytest.fixture(scope="module")
def test_text():
return "Natural language processing"


@pytest.fixture(scope="module")
def reference_ht(model_name, device):
"""HookedTransformer with refactor_factored_attn_matrices=True."""
return HookedTransformer.from_pretrained(
model_name,
device=device,
refactor_factored_attn_matrices=True,
)


def test_refactor_factored_attn_matrices_loss_matches(model_name, device, test_text, reference_ht):
"""Bridge with refactor_factored_attn_matrices should match HookedTransformer."""
ref_loss = reference_ht(test_text, return_type="loss")

bridge = TransformerBridge.boot_transformers(model_name, device=device)
bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True)
bridge_loss = bridge(test_text, return_type="loss")

assert not torch.isnan(bridge_loss), "Bridge produced NaN loss"
assert not torch.isinf(bridge_loss), "Bridge produced infinite loss"

loss_diff = abs(bridge_loss.item() - ref_loss.item())
assert loss_diff < 1.0, (
f"Loss difference too large: {loss_diff:.6f} "
f"(bridge={bridge_loss.item():.4f}, reference={ref_loss.item():.4f})"
)


def test_refactor_factored_attn_matrices_logits_match(model_name, device, test_text, reference_ht):
"""Bridge logits should closely match HookedTransformer logits after refactoring."""
tokens = reference_ht.to_tokens(test_text)
ref_logits = reference_ht(tokens)

bridge = TransformerBridge.boot_transformers(model_name, device=device)
bridge.enable_compatibility_mode(refactor_factored_attn_matrices=True)
bridge_logits = bridge(tokens)

# Check shapes match
assert (
ref_logits.shape == bridge_logits.shape
), f"Shape mismatch: ref={ref_logits.shape}, bridge={bridge_logits.shape}"

# Check values are close
max_diff = (ref_logits - bridge_logits).abs().max().item()
assert max_diff < 1.0, f"Max logit difference too large: {max_diff:.6f}"


def test_refactor_preserves_fold_ln(model_name, device, test_text):
"""Refactoring should not undo fold_ln — both should be applied together."""
# Reference: fold_ln=True + refactor=True
ref = HookedTransformer.from_pretrained(
model_name,
device=device,
fold_ln=True,
refactor_factored_attn_matrices=True,
)
ref_loss = ref(test_text, return_type="loss")

# Bridge: same settings
bridge = TransformerBridge.boot_transformers(model_name, device=device)
bridge.enable_compatibility_mode(
fold_ln=True,
refactor_factored_attn_matrices=True,
)
bridge_loss = bridge(test_text, return_type="loss")

loss_diff = abs(bridge_loss.item() - ref_loss.item())
assert loss_diff < 1.0, (
f"fold_ln + refactor mismatch: {loss_diff:.6f} "
f"(bridge={bridge_loss.item():.4f}, ref={ref_loss.item():.4f})"
)
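
# Design note: the fixtures above are module-scoped, so the reference
# HookedTransformer loads only once for this module, while each test boots a
# fresh TransformerBridge, since enable_compatibility_mode configures the
# bridge in place (as it is used here) and the tests exercise different settings.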
@@ -105,6 +105,26 @@ def __init__(
self._reference_model: Optional[Any] = None
self._layer_idx: Optional[int] = None

# After splitting, the q/k/v LinearBridges hold the authoritative weights.
# The original qkv LinearBridge remains registered in _modules (so
# self.qkv is still accessible) but its parameters are stale copies of
# the pre-split combined weight. This hook excludes them from state_dict
# so weight processing steps never read unprocessed combined weights.
self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict)

@staticmethod
def _filter_qkv_state_dict(
module: torch.nn.Module,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
) -> None:
"""State dict hook that removes stale combined QKV entries."""
qkv_prefix = prefix + "qkv."
keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)]
for k in keys_to_remove:
del state_dict[k]
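
    # Effect, for illustration (hypothetical key names): with the hook in place,
    # state_dict() yields keys like "...attn.q.weight", "...attn.k.weight", and
    # "...attn.v.weight", but no "...attn.qkv.*" entries, so downstream weight
    # processing can only ever read the post-split, authoritative weights.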

def _create_qkv_conversion_rule(self) -> BaseTensorConversion:
"""Create the appropriate conversion rule for the individual q, k, and v matrices.

transformer_lens/model_bridge/supported_architectures/gpt2.py (41 additions, 7 deletions)
@@ -27,40 +27,74 @@


class QKVSplitRearrangeConversion(BaseTensorConversion):
"""Custom conversion that splits QKV tensor and then rearranges."""
"""Custom conversion that splits QKV tensor and then rearranges.

Handles two input formats:
- Combined QKV tensor (from HuggingFace): one dimension is ~3x the other.
Splits into Q/K/V parts, then rearranges to TL format.
- Already-split tensor (from bridge state dict): nn.Linear format
[n_heads*d_head, d_model]. Rearranges directly to TL format.
"""

def __init__(self, qkv_index: int, rearrange_pattern: str, **axes_lengths):
"""Initialize the conversion.

Args:
qkv_index: Index of Q (0), K (1), or V (2) in the QKV tensor
            rearrange_pattern: Einops pattern for rearrangement (Conv1D format)
**axes_lengths: Additional axes lengths for einops
"""
super().__init__()
self.qkv_index = qkv_index
self.rearrange_pattern = rearrange_pattern
self.axes_lengths = axes_lengths

def _is_combined_qkv(self, tensor: torch.Tensor) -> bool:
"""Check if a tensor is a combined QKV tensor vs already-split."""
if tensor.ndim == 2:
d0, d1 = tensor.shape
return d1 > d0 * 2 or d0 > d1 * 2
if tensor.ndim == 1:
n = self.axes_lengths.get("n", 1)
            # Best-effort check: a combined bias is divisible by 3 and well above
            # n alone; an exact test would need d_head, which is not known here
return tensor.shape[0] % 3 == 0 and tensor.shape[0] > n * 3
return False

def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
"""Split QKV tensor and rearrange the selected part."""
        # Dispatch on input format: already-split tensors skip the QKV split
if not self._is_combined_qkv(input_value):
# Already-split tensor in nn.Linear format [n_heads*d_head, d_model].
# The original rearrange_pattern is "d_model (n h) -> n d_model h"
# (Conv1D format). For nn.Linear format, the dims are transposed:
return einops.rearrange(
input_value, "(n h) d_model -> n d_model h", **self.axes_lengths
)

        # Combined QKV tensor: split into Q/K/V, then rearrange the selected part
        if len(input_value.shape) == 2:
            # Weight tensor: split along whichever dimension carries the 3x stack
            # (HF Conv1D stores [d_model, 3*d_model], i.e. dim 1)
            split_dim = 1 if input_value.shape[1] > input_value.shape[0] else 0
elif len(input_value.shape) == 1:
# Bias tensor: [3*n_heads*d_head] -> split along dim=0
split_dim = 0
else:
raise ValueError(f"Unexpected tensor shape: {input_value.shape}")

# Split the QKV tensor
qkv_parts = torch.tensor_split(input_value, 3, dim=split_dim)
selected_part = qkv_parts[self.qkv_index]

# Apply rearrangement
return einops.rearrange(selected_part, self.rearrange_pattern, **self.axes_lengths)

def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
"""Revert from TL format [n_heads, d_model, d_head] to nn.Linear format."""
if input_value.ndim == 3:
return einops.rearrange(
input_value, "n d_model h -> (n h) d_model", **self.axes_lengths
)
if input_value.ndim == 2:
# Bias in TL format [n_heads, d_head] -> [n_heads*d_head]
return einops.rearrange(input_value, "n h -> (n h)", **self.axes_lengths)
return input_value
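
# A minimal usage sketch (not part of the module), assuming GPT-2 small shapes:
# d_model=768, n_heads=12, d_head=64, and the Conv1D pattern referenced in the
# comments above.
#
#     conv = QKVSplitRearrangeConversion(0, "d_model (n h) -> n d_model h", n=12)
#     combined = torch.randn(768, 3 * 768)  # HF Conv1D combined QKV weight
#     split = torch.randn(12 * 64, 768)     # already-split nn.Linear weight
#     conv.handle_conversion(combined).shape  # (12, 768, 64): split, then rearrange
#     conv.handle_conversion(split).shape     # (12, 768, 64): rearrange only
#     conv.revert(torch.randn(12, 768, 64)).shape  # (768, 768), nn.Linear format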


class GPT2ArchitectureAdapter(ArchitectureAdapter):
"""Architecture adapter for GPT2 models.