From 07fd8c97431ec85ffbce2afd2b7cfc503d1257b7 Mon Sep 17 00:00:00 2001 From: degenfabian Date: Mon, 18 Aug 2025 17:59:45 +0200 Subject: [PATCH 1/7] updated loading in attribution patching demo to use transformer bridge --- .github/workflows/checks.yml | 2 +- demos/Attribution_Patching_Demo.ipynb | 3761 ++++++++++++++++++++++++- 2 files changed, 3761 insertions(+), 2 deletions(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 4de51026c..8820352c5 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -145,7 +145,7 @@ jobs: matrix: notebook: # - "Activation_Patching_in_TL_Demo" - # - "Attribution_Patching_Demo" + - "Attribution_Patching_Demo" - "ARENA_Content" - "Colab_Compatibility" - "BERT" diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index 2862fb9c8..722e18db4 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -1 +1,3760 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls 
down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]},{"name":"stderr","output_type":"stream","text":["/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"load_ext autoreload\")\n","/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"autoreload 2\")\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","import os\n","\n","DEBUG_MODE = False\n","IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n","try:\n"," import google.colab\n","\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")\n","\n","if IN_COLAB or 
IN_GITHUB:\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," %pip install circuitsvis\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'torchtyping'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType 
\u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'"]}],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import (\n"," HookedTransformer,\n"," HookedTransformerConfig,\n"," FactoredMatrix,\n"," ActivationCache,\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. 
The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = [\n"," \"When John and Mary went to the shops, John gave the bag to\",\n"," \"When John and Mary went to the shops, Mary gave the bag to\",\n"," \"When Tom and James went to the park, James gave the ball to\",\n"," \"When Tom and James went to the park, Tom gave the ball to\",\n"," \"When Dan and Sid went to the shops, Sid gave an apple to\",\n"," \"When Dan and Sid went to the shops, Dan gave an apple to\",\n"," \"After Martin and Amy went to the park, Amy gave a drink to\",\n"," \"After Martin 
and Amy went to the park, Martin gave a drink to\",\n","]\n","answers = [\n"," (\" Mary\", \" John\"),\n"," (\" John\", \" Mary\"),\n"," (\" Tom\", \" James\"),\n"," (\" James\", \" Tom\"),\n"," (\" Dan\", \" Sid\"),\n"," (\" Sid\", \" Dan\"),\n"," (\" Martin\", \" Amy\"),\n"," (\" Amy\", \" Martin\"),\n","]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n","]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor(\n"," [\n"," [model.to_single_token(answers[i][j]) for j in range(2)]\n"," for i in range(len(answers))\n"," ],\n"," device=model.cfg.device,\n",")\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape) == 3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: 
{corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","\n","\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n"," CLEAN_BASELINE - CORRUPTED_BASELINE\n"," )\n","\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","\n","\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n","\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n","\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return (\n"," value.item(),\n"," ActivationCache(cache, model),\n"," 
ActivationCache(grad_cache, model),\n"," )\n","\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n"," model, clean_tokens, ioi_metric\n",")\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n"," model, corrupted_tokens, ioi_metric\n",")\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. 
We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(\n"," clean_cache, clean_grad_cache\n",") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack(\n"," [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_grad_stack = torch.stack(\n"," [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(\n"," attention_attr,\n"," \"layer batch head_index dest src -> batch layer head_index dest src\",\n"," )\n"," return attention_attr\n","\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [\n"," f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n","]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [\n"," f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n","]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution 
patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape) == 2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape) == 5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = -attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(\n"," attention_attr_signed,\n"," \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n"," )\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = (\n"," attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," )\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title:\n"," display(Markdown(\"### \" + title))\n"," display(\n"," pysvelte.AttentionMulti(\n"," tokens=model.to_str_tokens(tokens),\n"," attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n"," head_labels=head_labels[:top_k],\n"," )\n"," )\n","\n","\n","plot_attention_attr(\n"," attention_attr,\n"," clean_tokens,\n"," index=0,\n"," title=\"Attention Attribution for first sequence\",\n",")\n","\n","plot_attention_attr(\n"," attention_attr.sum(0),\n"," clean_tokens[0],\n"," title=\"Summed Attention Attribution for all sequences\",\n",")\n","print(\n"," \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different 
positions.\"\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=True\n"," )\n"," corrupted_residual = corrupted_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return residual_attr, residual_labels\n","\n","\n","residual_attr, residual_labels = attr_patch_residual(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," residual_attr,\n"," y=residual_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Residual Attribution Patching\",\n",")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n"," -1, return_labels=False\n"," )\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return layer_out_attr, labels\n","\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," layer_out_attr,\n"," y=layer_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Layer Output Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n"," -1, return_labels=False\n"," )\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return head_out_attr, labels\n","\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_out_attr,\n"," y=head_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Head Output Attribution Patching\",\n",")\n","sum_head_out_attr = einops.reduce(\n"," head_out_attr,\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n",")\n","imshow(\n"," sum_head_out_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=\"Head Output Attribution Patching Sum Over Pos\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_vector_from_cache(\n"," cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n",") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack(\n"," [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n"," )\n"," return stacked_head_vectors\n","\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(\n"," corrupted_cache, activation_name\n"," )\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(\n"," corrupted_grad_cache, activation_name\n"," )\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\",\n"," )\n"," return head_vector_attr, labels\n","\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [\n"," (\"k\", \"Key\"),\n"," (\"q\", \"Query\"),\n"," (\"v\", \"Value\"),\n"," (\"z\", \"Mixed Value\"),\n","]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n"," clean_cache, corrupted_cache, 
corrupted_grad_cache, activation_name\n"," )\n"," imshow(\n"," head_vector_attr_dict[activation_name],\n"," y=head_vector_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=f\"{activation_name_full} Attribution Patching\",\n"," )\n"," sum_head_vector_attr = einops.reduce(\n"," head_vector_attr_dict[activation_name],\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," sum_head_vector_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n"," )"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," 
start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," )"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. 
Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", 
CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " # Attribution Patching Demo\n", + " **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n", + " 
This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n", + "\n", + " The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n", + "\n", + " I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n", + "\n", + " **Tips for reading this Colab:**\n", + " * You can run all this code for yourself!\n", + " * The graphs are interactive!\n", + " * Use the table of contents pane in the sidebar to navigate\n", + " * Collapse irrelevant sections with the dropdown arrows\n", + " * Search the page using the search in the sidebar, not CTRL+F" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Setup (Ignore)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"load_ext autoreload\")\n", + 
"/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"autoreload 2\")\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "DEBUG_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install torchtyping\n", + " # Install my janky personal plotting utils\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + " # Install another version of node that makes PySvelte work way faster\n", + " %pip install circuitsvis\n", + " # Needed for PySvelte to work, v3 came out and broke things...\n", + " %pip install typeguard==2.13.3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "\n", + "if IN_COLAB or not DEBUG_MODE:\n", + " # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! 
Thus creating a debug mode.\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torchtyping'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType \u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'" + ] + } + ], + "source": [ + "# Import stuff\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import numpy as np\n", + "import einops\n", + "from fancy_einsum import einsum\n", + "import tqdm.notebook as tqdm\n", + 
"import random\n", + "from pathlib import Path\n", + "import plotly.express as px\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchtyping import TensorType as TT\n", + "from typing import List, Union, Optional, Callable\n", + "from functools import partial\n", + "import copy\n", + "import itertools\n", + "import json\n", + "\n", + "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n", + "import dataclasses\n", + "import datasets\n", + "from IPython.display import HTML, Markdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import transformer_lens\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens import (\n", + " ActivationCache,\n", + ")\n", + "from transformer_lens.model_bridge import TransformerBridge" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from neel_plotly import line, imshow, scatter" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import transformer_lens.patching as patching" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## IOI Patching Setup\n", + " This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-small into HookedTransformer\n" + ] + } + ], + "source": [ + "model = TransformerBridge.boot_transformers(\"gpt2\")\n", + "model.enable_compatibility_mode()\n", + "model.set_use_attn_result(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n", + "Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n", + "Answer token indices tensor([[ 5335, 1757],\n", + " [ 1757, 5335],\n", + " [ 4186, 3700],\n", + " [ 3700, 4186],\n", + " [ 6035, 15686],\n", + " [15686, 6035],\n", + " [ 5780, 14235],\n", + " [14235, 5780]], device='cuda:0')\n" + ] + } + ], + "source": [ + "prompts = [\n", + " \"When John and Mary went to the shops, John gave the bag to\",\n", + " \"When John and Mary went to the shops, Mary gave the bag to\",\n", + " \"When Tom and James went to the park, James gave the ball to\",\n", + " \"When Tom and James went to the park, Tom gave the ball to\",\n", + " \"When Dan and Sid went to the shops, Sid gave an apple to\",\n", + " \"When Dan and Sid went to the shops, Dan gave an apple to\",\n", + " \"After Martin and Amy went to the park, Amy gave a drink to\",\n", + " \"After Martin and Amy went to the park, Martin gave a drink to\",\n", + "]\n", + "answers = [\n", + " (\" Mary\", \" John\"),\n", + " (\" John\", \" Mary\"),\n", + " (\" Tom\", \" James\"),\n", + " (\" James\", \" Tom\"),\n", + " (\" Dan\", \" Sid\"),\n", + " (\" Sid\", \" Dan\"),\n", + " (\" Martin\", \" Amy\"),\n", + " (\" Amy\", \" Martin\"),\n", + "]\n", + 
"\n", + "clean_tokens = model.to_tokens(prompts)\n", + "# Swap each adjacent pair, with a hacky list comprehension\n", + "corrupted_tokens = clean_tokens[\n", + " [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n", + "]\n", + "print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n", + "print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n", + "\n", + "answer_token_indices = torch.tensor(\n", + " [\n", + " [model.to_single_token(answers[i][j]) for j in range(2)]\n", + " for i in range(len(answers))\n", + " ],\n", + " device=model.cfg.device,\n", + ")\n", + "print(\"Answer token indices\", answer_token_indices)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean logit diff: 3.5519\n", + "Corrupted logit diff: -3.5519\n" + ] + } + ], + "source": [ + "def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n", + " if len(logits.shape) == 3:\n", + " # Get final logits only\n", + " logits = logits[:, -1, :]\n", + " correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n", + " incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n", + " return (correct_logits - incorrect_logits).mean()\n", + "\n", + "\n", + "clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", + "corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n", + "\n", + "clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n", + "print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n", + "\n", + "corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n", + "print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean Baseline is 1: 1.0000\n", + 
"Corrupted Baseline is 0: 0.0000\n" + ] + } + ], + "source": [ + "CLEAN_BASELINE = clean_logit_diff\n", + "CORRUPTED_BASELINE = corrupted_logit_diff\n", + "\n", + "\n", + "def ioi_metric(logits, answer_token_indices=answer_token_indices):\n", + " return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n", + " CLEAN_BASELINE - CORRUPTED_BASELINE\n", + " )\n", + "\n", + "\n", + "print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n", + "print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Patching\n", + " In the following cells, we define attribution patching and use it in various ways on the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean Value: 1.0\n", + "Clean Activations Cached: 220\n", + "Clean Gradients Cached: 220\n", + "Corrupted Value: 0.0\n", + "Corrupted Activations Cached: 220\n", + "Corrupted Gradients Cached: 220\n" + ] + } + ], + "source": [ + "filter_not_qkv_input = lambda name: \"_input\" not in name\n", + "\n", + "\n", + "def get_cache_fwd_and_bwd(model, tokens, metric):\n", + " model.reset_hooks()\n", + " cache = {}\n", + "\n", + " def forward_cache_hook(act, hook):\n", + " cache[hook.name] = act.detach()\n", + "\n", + " model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n", + "\n", + " grad_cache = {}\n", + "\n", + " def backward_cache_hook(act, hook):\n", + " grad_cache[hook.name] = act.detach()\n", + "\n", + " model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n", + "\n", + " value = metric(model(tokens))\n", + " value.backward()\n", + " model.reset_hooks()\n", + " 
return (\n", + " value.item(),\n", + " ActivationCache(cache, model),\n", + " ActivationCache(grad_cache, model),\n", + " )\n", + "\n", + "\n", + "clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n", + " model, clean_tokens, ioi_metric\n", + ")\n", + "print(\"Clean Value:\", clean_value)\n", + "print(\"Clean Activations Cached:\", len(clean_cache))\n", + "print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n", + "corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n", + " model, corrupted_tokens, ioi_metric\n", + ")\n", + "print(\"Corrupted Value:\", corrupted_value)\n", + "print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n", + "print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### Attention Attribution\n", + " The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n", + " Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n", + " Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. 
We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n", + " We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def create_attention_attr(\n", + " clean_cache, clean_grad_cache\n", + ") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n", + " attention_stack = torch.stack(\n", + " [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n", + " )\n", + " attention_grad_stack = torch.stack(\n", + " [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n", + " )\n", + " attention_attr = attention_grad_stack * attention_stack\n", + " attention_attr = einops.rearrange(\n", + " attention_attr,\n", + " \"layer batch head_index dest src -> batch layer head_index dest src\",\n", + " )\n", + " return attention_attr\n", + "\n", + "\n", + "attention_attr = create_attention_attr(clean_cache, clean_grad_cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n", + "['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n", + "['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n" + ] + } + ], + "source": [ + "HEAD_NAMES = [\n", + " f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n", + "]\n", + "HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n", + "HEAD_NAMES_QKV = [\n", + " f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n", + "]\n", + "print(HEAD_NAMES[:5])\n", + 
"print(HEAD_NAMES_SIGNED[:5])\n", + "print(HEAD_NAMES_QKV[:5])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "### Attention Attribution for first sequence" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "### Summed Attention Attribution for all sequences" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n" + ] + } + ], + "source": [ + "def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n", + " if len(tokens.shape) == 2:\n", + " tokens = tokens[index]\n", + " if len(attention_attr.shape) == 5:\n", + " attention_attr = attention_attr[index]\n", + " attention_attr_pos = attention_attr.clamp(min=-1e-5)\n", + " attention_attr_neg = -attention_attr.clamp(max=1e-5)\n", + " attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n", + " attention_attr_signed = einops.rearrange(\n", + " attention_attr_signed,\n", + " \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n", + " )\n", + " attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n", + " attention_attr_indices = (\n", + " attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n", + " )\n", + " # print(attention_attr_indices.shape)\n", + " # print(attention_attr_indices)\n", + " attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n", + " head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n", + "\n", + " if title:\n", + " display(Markdown(\"### \" + title))\n", + " display(\n", + " pysvelte.AttentionMulti(\n", + " tokens=model.to_str_tokens(tokens),\n", + " attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n", + " head_labels=head_labels[:top_k],\n", + " )\n", + " )\n", + "\n", + "\n", + "plot_attention_attr(\n", + " attention_attr,\n", + " clean_tokens,\n", + " index=0,\n", + " title=\"Attention Attribution for first sequence\",\n", + ")\n", + "\n", + "plot_attention_attr(\n", + " attention_attr.sum(0),\n", + " clean_tokens[0],\n", + " title=\"Summed 
Attention Attribution for all sequences\",\n", + ")\n", + "print(\n", + " \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Attribution Patching\n", + " In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n", + " ### Residual Stream Patching\n", + "
Note: We add up across both d_model and batch (Explanation).\n",
+    " We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalently, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n",
+    " We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_residual(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + ") -> TT[\"component\", \"pos\"]:\n", + " clean_residual, residual_labels = clean_cache.accumulated_resid(\n", + " -1, incl_mid=True, return_labels=True\n", + " )\n", + " corrupted_residual = corrupted_cache.accumulated_resid(\n", + " -1, incl_mid=True, return_labels=False\n", + " )\n", + " corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n", + " -1, incl_mid=True, return_labels=False\n", + " )\n", + " residual_attr = einops.reduce(\n", + " corrupted_grad_residual * (clean_residual - corrupted_residual),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\",\n", + " )\n", + " return residual_attr, residual_labels\n", + "\n", + "\n", + "residual_attr, residual_labels = attr_patch_residual(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache\n", + ")\n", + "imshow(\n", + " residual_attr,\n", + " y=residual_labels,\n", + " yaxis=\"Component\",\n", + " xaxis=\"Position\",\n", + " title=\"Residual Attribution Patching\",\n", + ")\n", + "\n", + "# ### Layer Output Patching" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_layer_out(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + ") -> TT[\"component\", \"pos\"]:\n", + " clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n", + " corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n", + " corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n", + " -1, return_labels=False\n", + " )\n", + " layer_out_attr = einops.reduce(\n", + " corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\",\n", + " )\n", + " return layer_out_attr, labels\n", + "\n", + "\n", + "layer_out_attr, layer_out_labels = attr_patch_layer_out(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache\n", + ")\n", + "imshow(\n", + " layer_out_attr,\n", + " y=layer_out_labels,\n", + " yaxis=\"Component\",\n", + " xaxis=\"Position\",\n", + " title=\"Layer Output Attribution Patching\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def attr_patch_head_out(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + ") -> TT[\"component\", \"pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n", + " corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n", + " corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n", + " -1, return_labels=False\n", + " )\n", + " head_out_attr = einops.reduce(\n", + " corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n", + " \"component batch pos d_model -> component pos\",\n", + " \"sum\",\n", + " )\n", + " return head_out_attr, labels\n", + "\n", + "\n", + "head_out_attr, head_out_labels = attr_patch_head_out(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache\n", + ")\n", + "imshow(\n", + " head_out_attr,\n", + " y=head_out_labels,\n", + " yaxis=\"Component\",\n", + " xaxis=\"Position\",\n", + " title=\"Head Output Attribution Patching\",\n", + ")\n", + "sum_head_out_attr = einops.reduce(\n", + " head_out_attr,\n", + " \"(layer head) pos -> layer head\",\n", + " \"sum\",\n", + " layer=model.cfg.n_layers,\n", + " head=model.cfg.n_heads,\n", + ")\n", + "imshow(\n", + " sum_head_out_attr,\n", + " yaxis=\"Layer\",\n", + " xaxis=\"Head Index\",\n", + " title=\"Head Output Attribution Patching Sum Over Pos\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### Head Activation Patching\n", + " Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n", + " As a sanity check, we also do this for the mixed value. 
The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n", + " We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "#### Key Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Query Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Value Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "#### Mixed Value Head Vector Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from typing_extensions import Literal\n", + "\n", + "\n", + "def stack_head_vector_from_cache(\n", + " cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n", + ") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n", + " \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n", + " stacked_head_vectors = torch.stack(\n", + " [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n", + " )\n", + " stacked_head_vectors = einops.rearrange(\n", + " stacked_head_vectors,\n", + " \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n", + " )\n", + " return stacked_head_vectors\n", + "\n", + "\n", + "def attr_patch_head_vector(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + " activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n", + ") -> TT[\"component\", \"pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n", + " corrupted_head_vector = stack_head_vector_from_cache(\n", + " corrupted_cache, activation_name\n", + " )\n", + " corrupted_grad_head_vector = stack_head_vector_from_cache(\n", + " corrupted_grad_cache, activation_name\n", + " )\n", + " head_vector_attr = einops.reduce(\n", + " corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n", + " \"component batch pos d_head -> component pos\",\n", + " \"sum\",\n", + " )\n", + " return head_vector_attr, labels\n", + "\n", + "\n", + "head_vector_attr_dict = {}\n", + "for activation_name, activation_name_full in [\n", + " (\"k\", \"Key\"),\n", + " (\"q\", \"Query\"),\n", + " (\"v\", \"Value\"),\n", + " (\"z\", \"Mixed Value\"),\n", + "]:\n", + " display(Markdown(f\"#### 
{activation_name_full} Head Vector Attribution Patching\"))\n", + " head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache, activation_name\n", + " )\n", + " imshow(\n", + " head_vector_attr_dict[activation_name],\n", + " y=head_vector_labels,\n", + " yaxis=\"Component\",\n", + " xaxis=\"Position\",\n", + " title=f\"{activation_name_full} Attribution Patching\",\n", + " )\n", + " sum_head_vector_attr = einops.reduce(\n", + " head_vector_attr_dict[activation_name],\n", + " \"(layer head) pos -> layer head\",\n", + " \"sum\",\n", + " layer=model.cfg.n_layers,\n", + " head=model.cfg.n_heads,\n", + " )\n", + " imshow(\n", + " sum_head_vector_attr,\n", + " yaxis=\"Layer\",\n", + " xaxis=\"Head Index\",\n", + " title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "### Head Pattern Attribution Patching" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + "\n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from typing_extensions import Literal\n", + "\n", + "\n", + "def stack_head_pattern_from_cache(\n", + " cache,\n", + ") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n", + " \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n", + " stacked_head_pattern = torch.stack(\n", + " [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n", + " )\n", + " stacked_head_pattern = einops.rearrange(\n", + " stacked_head_pattern,\n", + " \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n", + " )\n", + " return stacked_head_pattern\n", + "\n", + "\n", + "def attr_patch_head_pattern(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + ") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n", + " labels = HEAD_NAMES\n", + "\n", + " clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n", + " corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n", + " corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n", + " head_pattern_attr = einops.reduce(\n", + " corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n", + " \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n", + " \"sum\",\n", + " )\n", + " return head_pattern_attr, labels\n", + "\n", + "\n", + "head_pattern_attr, labels = attr_patch_head_pattern(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache\n", + ")\n", + "\n", + "plot_attention_attr(\n", + " einops.rearrange(\n", + " head_pattern_attr,\n", + " \"(layer head) dest src -> layer head dest src\",\n", + " layer=model.cfg.n_layers,\n", + " head=model.cfg.n_heads,\n", + " ),\n", + " clean_tokens,\n", + " index=0,\n", + " title=\"Head Pattern 
Attribution Patching\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_head_vector_grad_input_from_grad_cache(\n", + " grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n", + ") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " vector_grad = grad_cache[activation_name, layer]\n", + " ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n", + " attn_layer_object = model.blocks[layer].attn\n", + " if activation_name == \"q\":\n", + " W = attn_layer_object.W_Q\n", + " elif activation_name == \"k\":\n", + " W = attn_layer_object.W_K\n", + " elif activation_name == \"v\":\n", + " W = attn_layer_object.W_V\n", + " else:\n", + " raise ValueError(\"Invalid activation name\")\n", + "\n", + " return einsum(\n", + " \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n", + " vector_grad,\n", + " ln_scales.squeeze(-1),\n", + " W,\n", + " )\n", + "\n", + "\n", + "def get_stacked_head_vector_grad_input(\n", + " grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n", + ") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " return torch.stack(\n", + " [\n", + " get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n", + " for l in range(model.cfg.n_layers)\n", + " ],\n", + " dim=0,\n", + " )\n", + "\n", + "\n", + "def get_full_vector_grad_input(\n", + " grad_cache,\n", + ") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n", + " return torch.stack(\n", + " [\n", + " get_stacked_head_vector_grad_input(grad_cache, activation_name)\n", + " for activation_name in [\"q\", \"k\", \"v\"]\n", + " ],\n", + " dim=0,\n", + " )\n", + "\n", + "\n", + "def attr_patch_head_path(\n", + " clean_cache: ActivationCache,\n", + " corrupted_cache: ActivationCache,\n", + " corrupted_grad_cache: ActivationCache,\n", + ") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n", + " 
\"\"\"\n", + " Computes the attribution patch along the path between each pair of heads.\n", + "\n", + " Sets this to zero for the path from any late head to any early head\n", + "\n", + " \"\"\"\n", + " start_labels = HEAD_NAMES\n", + " end_labels = HEAD_NAMES_QKV\n", + " full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n", + " clean_head_result_stack = clean_cache.stack_head_results(-1)\n", + " corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n", + " diff_head_result = einops.rearrange(\n", + " clean_head_result_stack - corrupted_head_result_stack,\n", + " \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n", + " layer=model.cfg.n_layers,\n", + " head_index=model.cfg.n_heads,\n", + " )\n", + " path_attr = einsum(\n", + " \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n", + " full_vector_grad_input,\n", + " diff_head_result,\n", + " )\n", + " correct_layer_order_mask = (\n", + " torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n", + " > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n", + " ).to(path_attr.device)\n", + " zero = torch.zeros(1, device=path_attr.device)\n", + " path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n", + "\n", + " path_attr = einops.rearrange(\n", + " path_attr,\n", + " \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n", + " )\n", + " return path_attr, end_labels, start_labels\n", + "\n", + "\n", + "head_path_attr, end_labels, start_labels = attr_patch_head_path(\n", + " clean_cache, corrupted_cache, corrupted_grad_cache\n", + ")\n", + "imshow(\n", + " head_path_attr.sum(-1),\n", + " y=end_labels,\n", + " yaxis=\"Path End (Head Input)\",\n", + " x=start_labels,\n", + " xaxis=\"Path Start (Head Output)\",\n", + " title=\"Head Path Attribution Patching\",\n", 
+ ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n", + "line(head_out_values)\n", + "top_head_indices = head_out_indices[:22].sort().values\n", + "top_end_indices = []\n", + "top_end_labels = []\n", + "top_start_indices = []\n", + "top_start_labels = []\n", + "for i in top_head_indices:\n", + " i = i.item()\n", + " top_start_indices.append(i)\n", + " top_start_labels.append(start_labels[i])\n", + " for j in range(3):\n", + " top_end_indices.append(3 * i + j)\n", + " top_end_labels.append(end_labels[3 * i + j])\n", + "\n", + "imshow(\n", + " head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n", + " y=top_end_labels,\n", + " yaxis=\"Path End (Head Input)\",\n", + " x=top_start_labels,\n", + " xaxis=\"Path Start (Head Output)\",\n", + " title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n", + " imshow(\n", + " head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n", + " y=top_end_labels[j::3],\n", + " yaxis=\"Path End (Head Input)\",\n", + " x=top_start_labels,\n", + " xaxis=\"Path Start (Head Output)\",\n", + " title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "top_head_path_attr = einops.rearrange(\n", + " head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n", + " \"(head_end qkv) head_start -> qkv head_end head_start\",\n", + " qkv=3,\n", + ")\n", + "imshow(\n", + " top_head_path_attr,\n", + " y=[i[:-1] for i in top_end_labels[::3]],\n", + " yaxis=\"Path End (Head Input)\",\n", + " x=top_start_labels,\n", + " xaxis=\"Path Start (Head Output)\",\n", + " title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n", + " facet_col=0,\n", + " facet_labels=[\"Query\", \"Key\", \"Value\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "interesting_heads = [\n", + " 5 * model.cfg.n_heads + 5,\n", + " 8 * model.cfg.n_heads + 6,\n", + " 9 * model.cfg.n_heads + 9,\n", + "]\n", + "interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n", + "for head_index, label in zip(interesting_heads, interesting_head_labels):\n", + " in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n", + " out_paths = head_path_attr[:, head_index].sum(-1)\n", + " out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n", + " all_paths = torch.cat([in_paths, out_paths], dim=0)\n", + " all_paths = einops.rearrange(\n", + " all_paths,\n", + " \"path_type (layer head) -> path_type layer head\",\n", + " layer=model.cfg.n_layers,\n", + " head=model.cfg.n_heads,\n", + " )\n", + " imshow(\n", + " all_paths,\n", + " facet_col=0,\n", + " facet_labels=[\n", + " \"Query (In)\",\n", + " \"Key (In)\",\n", + " \"Value (In)\",\n", + " \"Query (Out)\",\n", + " \"Key (Out)\",\n", + " \"Value (Out)\",\n", + " ],\n", + " title=f\"Input and Output Paths for head {label}\",\n", + " yaxis=\"Layer\",\n", + " xaxis=\"Head\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Validating Attribution vs Activation Patching\n", + " Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n", + " My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. 
Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n", + " See more discussion in the accompanying blog post!\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "attribution_cache_dict = {}\n", + "for key in corrupted_grad_cache.cache_dict.keys():\n", + " attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n", + " clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n", + " )\n", + "attr_cache = ActivationCache(attribution_cache_dict, model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " By block: For each head we patch the starting residual stream, attention output + MLP output" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "str_tokens = model.to_str_tokens(clean_tokens[0])\n", + "context_length = len(str_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "95a5290e11b64b6a95ef5dd37d027c7a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/180 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "every_block_act_patch_result = patching.get_act_patch_block_every(\n", + " model, corrupted_tokens, clean_cache, ioi_metric\n", + ")\n", + "imshow(\n", + " every_block_act_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n", + " title=\"Activation Patching Per Block\",\n", + " xaxis=\"Position\",\n", + " yaxis=\"Layer\",\n", + " zmax=1,\n", + " zmin=-1,\n", + " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_block_every(attr_cache):\n", + " resid_pre_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"resid_pre\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + " attn_out_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"attn_out\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + " mlp_out_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"mlp_out\"),\n", + " \"layer batch pos d_model -> layer pos\",\n", + " \"sum\",\n", + " )\n", + "\n", + " every_block_attr_patch_result = torch.stack(\n", + " [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n", + " )\n", + " return every_block_attr_patch_result\n", + "\n", + "\n", + "every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n", + "imshow(\n", + " every_block_attr_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n", + " title=\"Attribution Patching Per Block\",\n", + " xaxis=\"Position\",\n", + " yaxis=\"Layer\",\n", + " zmax=1,\n", + " zmin=-1,\n", + " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "scatter(\n", + " y=every_block_attr_patch_result.reshape(3, -1),\n", + " x=every_block_act_patch_result.reshape(3, -1),\n", + " facet_col=0,\n", + " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n", + " title=\"Attribution vs Activation Patching Per Block\",\n", + " xaxis=\"Activation Patch\",\n", + " yaxis=\"Attribution Patch\",\n", + " hover=[\n", + " f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n", + " for l in range(model.cfg.n_layers)\n", + " for p in range(context_length)\n", + " ],\n", + " color=einops.repeat(\n", + " torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n", + " ),\n", + " color_continuous_scale=\"Portland\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "18b2e6b0985b40cd8c0cd1a16ba62975", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n", + " model, corrupted_tokens, clean_cache, ioi_metric\n", + ")\n", + "imshow(\n", + " every_head_all_pos_act_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Activation Patching Per Head (All Pos)\",\n", + " xaxis=\"Head\",\n", + " yaxis=\"Layer\",\n", + " zmax=1,\n", + " zmin=-1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_attn_head_all_pos_every(attr_cache):\n", + " head_out_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"z\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_q_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"q\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_k_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"k\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_v_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"v\"),\n", + " \"layer batch pos head_index d_head -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + " head_pattern_all_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"pattern\"),\n", + " \"layer batch head_index dest_pos src_pos -> layer head_index\",\n", + " \"sum\",\n", + " )\n", + "\n", + " return torch.stack(\n", + " [\n", + " head_out_all_pos_attr,\n", + " head_q_all_pos_attr,\n", + " head_k_all_pos_attr,\n", + " head_v_all_pos_attr,\n", + " head_pattern_all_pos_attr,\n", + " ]\n", + " )\n", + "\n", + "\n", + "every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n", + " attr_cache\n", + ")\n", + "imshow(\n", + " every_head_all_pos_attr_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Attribution Patching Per Head (All Pos)\",\n", + " xaxis=\"Head\",\n", + " yaxis=\"Layer\",\n", + " zmax=1,\n", + " zmin=-1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "scatter(\n", + " y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n", + " x=every_head_all_pos_act_patch_result.reshape(5, -1),\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n", + " xaxis=\"Activation Patch\",\n", + " yaxis=\"Attribution Patch\",\n", + " include_diag=True,\n", + " hover=head_out_labels,\n", + " color=einops.repeat(\n", + " torch.arange(model.cfg.n_layers),\n", + " \"layer -> (layer head)\",\n", + " head=model.cfg.n_heads,\n", + " ),\n", + " color_continuous_scale=\"Portland\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n", + " Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "graph_tok_labels = [\n", + " f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n", + "]\n", + "imshow(\n", + " clean_cache[\"pattern\", 5][:, 5],\n", + " x=graph_tok_labels,\n", + " y=graph_tok_labels,\n", + " facet_col=0,\n", + " title=\"Attention for Head L5H5\",\n", + " facet_name=\"Prompt\",\n", + ")\n", + "imshow(\n", + " clean_cache[\"pattern\", 10][:, 7],\n", + " x=graph_tok_labels,\n", + " y=graph_tok_labels,\n", + " facet_col=0,\n", + " title=\"Attention for Head L10H7\",\n", + " facet_name=\"Prompt\",\n", + ")\n", + "imshow(\n", + " clean_cache[\"pattern\", 11][:, 10],\n", + " x=graph_tok_labels,\n", + " y=graph_tok_labels,\n", + " facet_col=0,\n", + " title=\"Attention for Head L11H10\",\n", + " facet_name=\"Prompt\",\n", + ")\n", + "\n", + "\n", + "# [markdown]" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "06f39489001845849fbc7446a07066f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/2160 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n", + " model, corrupted_tokens, clean_cache, ioi_metric\n", + ")\n", + "every_head_by_pos_act_patch_result = einops.rearrange(\n", + " every_head_by_pos_act_patch_result,\n", + " \"act_type layer pos head -> act_type (layer head) pos\",\n", + ")\n", + "imshow(\n", + " every_head_by_pos_act_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Activation Patching Per Head (By Pos)\",\n", + " xaxis=\"Position\",\n", + " yaxis=\"Layer & Head\",\n", + " zmax=1,\n", + " zmin=-1,\n", + " x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n", + " y=head_out_labels,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def get_attr_patch_attn_head_by_pos_every(attr_cache):\n", + " head_out_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"z\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_q_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"q\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_k_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"k\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_v_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"v\"),\n", + " \"layer batch pos head_index d_head -> layer pos head_index\",\n", + " \"sum\",\n", + " )\n", + " head_pattern_by_pos_attr = einops.reduce(\n", + " attr_cache.stack_activation(\"pattern\"),\n", + " \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n", + " \"sum\",\n", + " )\n", + "\n", + " return torch.stack(\n", + " [\n", + " head_out_by_pos_attr,\n", + " head_q_by_pos_attr,\n", + " head_k_by_pos_attr,\n", + " head_v_by_pos_attr,\n", + " head_pattern_by_pos_attr,\n", + " ]\n", + " )\n", + "\n", + "\n", + "every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n", + "every_head_by_pos_attr_patch_result = einops.rearrange(\n", + " every_head_by_pos_attr_patch_result,\n", + " \"act_type layer pos head -> act_type (layer head) pos\",\n", + ")\n", + "imshow(\n", + " every_head_by_pos_attr_patch_result,\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Attribution Patching Per Head (By Pos)\",\n", + " xaxis=\"Position\",\n", + " yaxis=\"Layer & Head\",\n", + " zmax=1,\n", + " zmin=-1,\n", + " x=[f\"{tok}_{i}\" for i, tok in 
enumerate(model.to_str_tokens(clean_tokens[0]))],\n", + " y=head_out_labels,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "scatter(\n", + " y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n", + " x=every_head_by_pos_act_patch_result.reshape(5, -1),\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n", + " xaxis=\"Activation Patch\",\n", + " yaxis=\"Attribution Patch\",\n", + " include_diag=True,\n", + " hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n", + " color=einops.repeat(\n", + " torch.arange(model.cfg.n_layers),\n", + " \"layer -> (layer head pos)\",\n", + " head=model.cfg.n_heads,\n", + " pos=15,\n", + " ),\n", + " color_continuous_scale=\"Portland\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## Factual Knowledge Patching Example\n", + " Incomplete, but maybe of interest!\n", + " Note that I have better results with the corrupted prompt as having random words rather than Colosseum." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using pad_token, but it is not set yet.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-xl into HookedTransformer\n", + "Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n", + "Tokenized answer: [' Paris']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n", + "Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n", + "Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n", + "Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n", + "Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n", + "Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n", + "Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n", + "Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n", + "Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n", + "Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Paris', 0)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n", + "Tokenized answer: [' Rome']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n", + "Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n", + "Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n", + "Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n", + "Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n", + "Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n", + "Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n", + "Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n", + "Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n", + "Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Rome', 0)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "gpt2_xl = TransformerBridge.boot_transformers(\"gpt2-xl\")\n", + "gpt2_xl.enable_compatibility_mode()\n", + "clean_prompt = \"The Eiffel Tower is located in the city of\"\n", + "clean_answer = \" Paris\"\n", + "# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n", + "corrupted_prompt = \"The Colosseum is located in the city of\"\n", + "corrupted_answer = \" Rome\"\n", + "utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n", + "utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n", + "corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n", + "\n", + "\n", + "def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n", + " return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean logit diff: 10.634519577026367\n", + "Corrupted logit diff: -8.988396644592285\n", + "Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n", + "Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n", + "CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n", + "corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n", + "CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n", + "\n", + "\n", + "def factual_metric(logits: 
TT[\"batch\", \"position\", \"d_vocab\"]):\n", + " return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n", + " CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n", + " )\n", + "\n", + "\n", + "print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n", + "print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n", + "print(\"Clean Metric:\", factual_metric(clean_logits))\n", + "print(\"Corrupted Metric:\", factual_metric(corrupted_logits))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n", + "Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n" + ] + } + ], + "source": [ + "clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n", + "clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n", + "corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n", + "corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n", + "print(\"Clean:\", clean_str_tokens)\n", + "print(\"Corrupted:\", corrupted_str_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b767eef7a3cd49b9b3cb6e5301463f08", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/48 [00:00\n", + "\n", + "\n", + "
\n", + "
\n", + "\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def act_patch_residual(clean_cache, corrupted_tokens, model: TransformerBridge, metric):\n", + " if len(corrupted_tokens.shape) == 2:\n", + " corrupted_tokens = corrupted_tokens[0]\n", + " residual_patches = torch.zeros(\n", + " (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n", + " )\n", + "\n", + " def residual_hook(resid_pre, hook, layer, pos):\n", + " resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n", + " return resid_pre\n", + "\n", + " for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n", + " for pos in range(len(corrupted_tokens)):\n", + " patched_logits = model.run_with_hooks(\n", + " corrupted_tokens,\n", + " fwd_hooks=[\n", + " (\n", + " f\"blocks.{layer}.hook_resid_pre\",\n", + " partial(residual_hook, layer=layer, pos=pos),\n", + " )\n", + " ],\n", + " )\n", + " residual_patches[layer, pos] = metric(patched_logits).item()\n", + " return residual_patches\n", + "\n", + "\n", + "residual_act_patch = act_patch_residual(\n", + " clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n", + ")\n", + "\n", + "imshow(\n", + " residual_act_patch,\n", + " title=\"Factual Recall Patching (Residual)\",\n", + " xaxis=\"Position\",\n", + " yaxis=\"Layer\",\n", + " x=clean_str_tokens,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 6b0831dba9e42786efbf8146f97305c4f1ae3556 Mon Sep 17 00:00:00 2001 From: 
degenfabian Date: Mon, 18 Aug 2025 19:08:27 +0200 Subject: [PATCH 2/7] updated loading in bert demo to use transformer bridge --- demos/BERT.ipynb | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index e420b5e0d..192469ba9 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -157,7 +157,7 @@ "\n", "from transformers import AutoTokenizer\n", "\n", - "from transformer_lens import HookedEncoder, BertNextSentencePrediction" + "from transformer_lens.model_bridge import TransformerBridge" ] }, { @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -214,7 +214,8 @@ ], "source": [ "# NBVAL_IGNORE_OUTPUT\n", - "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n", + "bert = TransformerBridge.boot_transformers(\"bert-base-cased\")\n", + "bert.enable_compatibility_mode()\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")" ] }, @@ -287,14 +288,13 @@ "metadata": {}, "source": [ "## Next Sentence Prediction\n", - "To carry out Next Sentence Prediction, you have to use the class BertNextSentencePrediction, and pass a HookedEncoder in its constructor. \n", - "Then, create a list with the two sentences you want to perform NSP on as elements and use that as input to the forward function. \n", + "To carry out Next Sentence Prediction create a list with the two sentences you want to perform NSP on as elements and use that as input to the forward function. \n", "The model will then predict the probability of the sentence at position 1 following (i.e. being the next sentence) to the sentence at position 0." 
] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -308,13 +308,12 @@ } ], "source": [ - "nsp = BertNextSentencePrediction(bert)\n", "sentence_a = \"A man walked into a grocery store.\"\n", "sentence_b = \"He bought an apple.\"\n", "\n", "input = [sentence_a, sentence_b]\n", "\n", - "predictions = nsp(input, return_type=\"predictions\")\n", + "predictions = bert(input, return_type=\"predictions\")\n", "\n", "print(f\"Sentence A: {sentence_a}\")\n", "print(f\"Sentence B: {sentence_b}\")\n", From 93b73f73b26a647ae94be6ab89675c9d23baf802 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 11:38:00 -0600 Subject: [PATCH 3/7] Update to allow NSP via bridge --- demos/BERT.ipynb | 160 +++++++++++++----- transformer_lens/model_bridge/bridge.py | 31 +++- .../model_bridge/sources/transformers.py | 7 +- .../supported_architectures/bert.py | 19 ++- 4 files changed, 170 insertions(+), 47 deletions(-) diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 0108a5154..9338fd30e 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -28,16 +28,66 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, - "outputs": [], - "source": "# NBVAL_IGNORE_OUTPUT\nimport os\n\n# Janky code to do different setup when run in a Colab notebook vs VSCode\nDEVELOPMENT_MODE = False\nIN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\ntry:\n import google.colab\n\n IN_COLAB = True\n print(\"Running as a Colab notebook\")\n\n # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n # # Install another version of node that makes PySvelte work way faster\n # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n # %pip install git+https://github.com/neelnanda-io/PySvelte.git\nexcept:\n IN_COLAB = False\n\nif not IN_GITHUB and not IN_COLAB:\n print(\"Running as a Jupyter notebook - intended for 
development only!\")\n from IPython import get_ipython\n\n ipython = get_ipython()\n # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n ipython.run_line_magic(\"load_ext\", \"autoreload\")\n ipython.run_line_magic(\"autoreload\", \"2\")\n\nif IN_COLAB:\n %pip install transformer_lens\n %pip install circuitsvis" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n", + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + "\n", + "if not IN_GITHUB and not IN_COLAB:\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" + ] }, { "cell_type": "code", - "execution_count": null, + 
"execution_count": 40, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", @@ -51,27 +101,27 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -85,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -99,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -119,31 +169,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", - "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Moving model to device: mps\n", - "Loaded pretrained model bert-base-cased into HookedTransformer\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1d4b75dcfcbf488da7196992cde5c9bb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/202 [00:00 "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -160,6 +161,8 @@ def boot_transformers( tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. 
trust_remote_code: Whether to trust remote code for custom model architectures. + model_class: Optional HuggingFace model class to use instead of the default + auto-detected class (e.g., BertForNextSentencePrediction). Returns: The bridge to the loaded model. @@ -174,6 +177,7 @@ def boot_transformers( tokenizer=tokenizer, load_weights=load_weights, trust_remote_code=trust_remote_code, + model_class=model_class, ) @property @@ -1206,7 +1210,7 @@ def forward( Args: input: Input to the model - return_type: Type of output to return ('logits', 'loss', 'both', None) + return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None) loss_per_token: Whether to return loss per token prepend_bos: Whether to prepend BOS token padding_side: Which side to pad on @@ -1341,6 +1345,31 @@ def forward( ), f"Expected logits tensor, got {type(logits)}" loss = self.loss_fn(logits, input_ids, per_token=loss_per_token) return (logits, loss) + elif return_type == "predictions": + assert ( + self.tokenizer is not None + ), "Must have a tokenizer to use return_type='predictions'" + if logits.shape[-1] == 2: + # Next Sentence Prediction — 2-class output + logprobs = logits.log_softmax(dim=-1) + predictions = [ + "The sentences are sequential", + "The sentences are NOT sequential", + ] + return predictions[logprobs.argmax(dim=-1).item()] + else: + # Masked Language Modeling — decode [MASK] tokens + logprobs = logits[ + input_ids == self.tokenizer.mask_token_id + ].log_softmax(dim=-1) + predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) + if " " in predictions: + predictions = predictions.split(" ") + predictions = [ + f"Prediction {i}: {p}" + for i, p in enumerate(predictions) + ] + return predictions elif return_type is None: return None else: diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 3449d0e1e..b9a4cfd28 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ 
b/transformer_lens/model_bridge/sources/transformers.py @@ -246,6 +246,7 @@ def boot( tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, + model_class: type | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. @@ -256,6 +257,9 @@ def boot( dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. + model_class: Optional HuggingFace model class to use instead of the default auto-detected + class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding + adapter is selected automatically (e.g., BertForNextSentencePrediction). Returns: The bridge to the loaded model. @@ -301,7 +305,8 @@ def boot( if device is None: device = get_device() adapter.cfg.device = str(device) - model_class = get_hf_model_class_for_architecture(architecture) + if model_class is None: + model_class = get_hf_model_class_for_architecture(architecture) # Ensure pad_token_id exists on HF config. Transformers v5 raises AttributeError # for missing config attributes (instead of returning None), which crashes models # like Phi-1 that access config.pad_token_id during __init__. diff --git a/transformer_lens/model_bridge/supported_architectures/bert.py b/transformer_lens/model_bridge/supported_architectures/bert.py index 634e2975b..0c92c6993 100644 --- a/transformer_lens/model_bridge/supported_architectures/bert.py +++ b/transformer_lens/model_bridge/supported_architectures/bert.py @@ -82,8 +82,9 @@ def __init__(self, cfg: Any) -> None: } # Set up component mapping - # The bridge loads BertForMaskedLM, so core model paths need the 'bert.' prefix. - # The MLM head (cls.predictions) is at the top level of BertForMaskedLM. + # Core model paths use the 'bert.' prefix. 
The head components (unembed, + # ln_final) are set to MLM defaults here and adjusted in prepare_model() + # if the actual HF model is a different task variant (e.g., NSP). self.component_mapping = { "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"), "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"), @@ -125,3 +126,17 @@ def __init__(self, cfg: Any) -> None: name="cls.predictions.transform.LayerNorm", config=self.cfg ), } + + def prepare_model(self, hf_model: Any) -> None: + """Adjust component mapping based on the actual HF model variant. + + BertForMaskedLM has cls.predictions (MLM head). + BertForNextSentencePrediction has cls.seq_relationship (NSP head) + and no MLM-specific LayerNorm. + """ + if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"): + # NSP model — swap head components + self.component_mapping["unembed"] = UnembeddingBridge( + name="cls.seq_relationship" + ) + self.component_mapping.pop("ln_final", None) From 1adb40ec54c4f509ac037345dec7f33644ca26ab Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 11:42:42 -0600 Subject: [PATCH 4/7] Format and type fixes --- transformer_lens/model_bridge/bridge.py | 9 ++------- transformer_lens/model_bridge/sources/transformers.py | 2 +- .../model_bridge/supported_architectures/bert.py | 5 ++--- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 5ed5dd383..95efce536 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1359,16 +1359,11 @@ def forward( return predictions[logprobs.argmax(dim=-1).item()] else: # Masked Language Modeling — decode [MASK] tokens - logprobs = logits[ - input_ids == self.tokenizer.mask_token_id - ].log_softmax(dim=-1) + logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1) predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) if " " in 
predictions: predictions = predictions.split(" ") - predictions = [ - f"Prediction {i}: {p}" - for i, p in enumerate(predictions) - ] + predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)] return predictions elif return_type is None: return None diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index b9a4cfd28..ad77e0783 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -246,7 +246,7 @@ def boot( tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, - model_class: type | None = None, + model_class: Any | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. diff --git a/transformer_lens/model_bridge/supported_architectures/bert.py b/transformer_lens/model_bridge/supported_architectures/bert.py index 0c92c6993..265bfb531 100644 --- a/transformer_lens/model_bridge/supported_architectures/bert.py +++ b/transformer_lens/model_bridge/supported_architectures/bert.py @@ -136,7 +136,6 @@ def prepare_model(self, hf_model: Any) -> None: """ if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"): # NSP model — swap head components - self.component_mapping["unembed"] = UnembeddingBridge( - name="cls.seq_relationship" - ) + assert self.component_mapping is not None + self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship") self.component_mapping.pop("ln_final", None) From 33043965b0352f0934f301c1e42e27f073372a38 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 11:48:21 -0600 Subject: [PATCH 5/7] Add import --- transformer_lens/model_bridge/sources/transformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index ad77e0783..bb90121ab 100644 --- 
a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -7,6 +7,7 @@ import logging import os import warnings +from typing import Any import torch from transformers import ( From f5e12a00ced0aae468d93018dff315482c71854b Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 12:37:45 -0600 Subject: [PATCH 6/7] Attribution Patching moved to own branch --- demos/Attribution_Patching_Demo.ipynb | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index 577ac7fc5..bd9b6c707 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -161,16 +161,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import transformer_lens\n", "import transformer_lens.utils as utils\n", + "from transformer_lens.hook_points import (\n", + " HookedRootModule,\n", + " HookPoint,\n", + ") # Hooking utilities\n", "from transformer_lens import (\n", + " HookedTransformer,\n", + " HookedTransformerConfig,\n", + " FactoredMatrix,\n", " ActivationCache,\n", - ")\n", - "from transformer_lens.model_bridge import TransformerBridge" + ")" ] }, { @@ -208,7 +214,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -227,8 +233,7 @@ } ], "source": [ - "model = TransformerBridge.boot_transformers(\"gpt2\")\n", - "model.enable_compatibility_mode()\n", + "model = HookedTransformer.from_pretrained(\"gpt2-small\")\n", "model.set_use_attn_result(True)" ] }, @@ -3421,7 +3426,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -3538,8 +3543,7 @@ } ], "source": [ - "gpt2_xl = TransformerBridge.boot_transformers(\"gpt2-xl\")\n", - "gpt2_xl.enable_compatibility_mode()\n", + "gpt2_xl = 
HookedTransformer.from_pretrained(\"gpt2-xl\")\n", "clean_prompt = \"The Eiffel Tower is located in the city of\"\n", "clean_answer = \" Paris\"\n", "# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n", @@ -3632,7 +3636,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": {}, "outputs": [ { @@ -3690,7 +3694,7 @@ } ], "source": [ - "def act_patch_residual(clean_cache, corrupted_tokens, model: TransformerBridge, metric):\n", + "def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n", " if len(corrupted_tokens.shape) == 2:\n", " corrupted_tokens = corrupted_tokens[0]\n", " residual_patches = torch.zeros(\n", From f3855a4d22eda192d3b85632d438e62a64dd8638 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 4 Mar 2026 15:20:44 -0600 Subject: [PATCH 7/7] Hiding Attribution patching until its own PR --- .github/workflows/checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 72fcb403d..1b32e1a1a 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -224,7 +224,7 @@ jobs: matrix: notebook: # - "Activation_Patching_in_TL_Demo" - - "Attribution_Patching_Demo" + # - "Attribution_Patching_Demo" - "ARENA_Content" - "BERT" - "Exploratory_Analysis_Demo"