diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index a46b49976..9338fd30e 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -28,16 +28,66 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 39, "metadata": {}, - "outputs": [], - "source": "# NBVAL_IGNORE_OUTPUT\nimport os\n\n# Janky code to do different setup when run in a Colab notebook vs VSCode\nDEVELOPMENT_MODE = False\nIN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\ntry:\n import google.colab\n\n IN_COLAB = True\n print(\"Running as a Colab notebook\")\n\n # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n # # Install another version of node that makes PySvelte work way faster\n # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n # %pip install git+https://github.com/neelnanda-io/PySvelte.git\nexcept:\n IN_COLAB = False\n\nif not IN_GITHUB and not IN_COLAB:\n print(\"Running as a Jupyter notebook - intended for development only!\")\n from IPython import get_ipython\n\n ipython = get_ipython()\n # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n ipython.run_line_magic(\"load_ext\", \"autoreload\")\n ipython.run_line_magic(\"autoreload\", \"2\")\n\nif IN_COLAB:\n %pip install transformer_lens\n %pip install circuitsvis" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n", + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "import os\n", + "\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "\n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + "\n", + "if not IN_GITHUB and not IN_COLAB:\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", + "\n", + "if IN_COLAB:\n", + " %pip install transformer_lens\n", + " %pip install circuitsvis" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", @@ -51,27 +101,27 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
\n", + "
\n", " " ], "text/plain": [ - "" + "" ] }, - "execution_count": 3, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -85,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -94,12 +144,12 @@ "\n", "from transformers import AutoTokenizer\n", "\n", - "from transformer_lens import HookedEncoder, BertNextSentencePrediction" + "from transformer_lens.model_bridge import TransformerBridge" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -119,30 +169,29 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 44, "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Support for BERT in TransformerLens is currently experimental, until such a time when it has feature parity with HookedTransformer and has been tested on real research tasks. Until then, backward compatibility is not guaranteed. Please see the docs for information on the limitations of the current implementation.\n", - "If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Moving model to device: mps\n", - "Loaded pretrained model bert-base-cased into HookedTransformer\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1d4b75dcfcbf488da7196992cde5c9bb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading weights: 0%| | 0/202 [00:00 "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). @@ -160,6 +161,8 @@ def boot_transformers( tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. trust_remote_code: Whether to trust remote code for custom model architectures. + model_class: Optional HuggingFace model class to use instead of the default + auto-detected class (e.g., BertForNextSentencePrediction). Returns: The bridge to the loaded model. @@ -174,6 +177,7 @@ def boot_transformers( tokenizer=tokenizer, load_weights=load_weights, trust_remote_code=trust_remote_code, + model_class=model_class, ) @property @@ -1206,7 +1210,7 @@ def forward( Args: input: Input to the model - return_type: Type of output to return ('logits', 'loss', 'both', None) + return_type: Type of output to return ('logits', 'loss', 'both', 'predictions', None) loss_per_token: Whether to return loss per token prepend_bos: Whether to prepend BOS token padding_side: Which side to pad on @@ -1341,6 +1345,26 @@ def forward( ), f"Expected logits tensor, got {type(logits)}" loss = self.loss_fn(logits, input_ids, per_token=loss_per_token) return (logits, loss) + elif return_type == "predictions": + assert ( + self.tokenizer is not None + ), "Must have a tokenizer to use return_type='predictions'" + if logits.shape[-1] == 2: + # Next Sentence Prediction — 2-class output + logprobs = logits.log_softmax(dim=-1) + predictions = [ + "The sentences are sequential", + "The sentences are NOT sequential", + ] + return predictions[logprobs.argmax(dim=-1).item()] + else: + # Masked Language Modeling — decode [MASK] tokens + logprobs = logits[input_ids == self.tokenizer.mask_token_id].log_softmax(dim=-1) + predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) + if " " in predictions: + predictions = predictions.split(" ") + predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)] + return predictions elif return_type is None: return None else: diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 3449d0e1e..bb90121ab 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -7,6 +7,7 @@ import logging import os import warnings +from typing import Any import torch from transformers import ( @@ -246,6 +247,7 @@ def boot( tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, + model_class: Any | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. @@ -256,6 +258,9 @@ def boot( dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. + model_class: Optional HuggingFace model class to use instead of the default auto-detected + class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding + adapter is selected automatically (e.g., BertForNextSentencePrediction). Returns: The bridge to the loaded model. @@ -301,7 +306,8 @@ def boot( if device is None: device = get_device() adapter.cfg.device = str(device) - model_class = get_hf_model_class_for_architecture(architecture) + if model_class is None: + model_class = get_hf_model_class_for_architecture(architecture) # Ensure pad_token_id exists on HF config. Transformers v5 raises AttributeError # for missing config attributes (instead of returning None), which crashes models # like Phi-1 that access config.pad_token_id during __init__. diff --git a/transformer_lens/model_bridge/supported_architectures/bert.py b/transformer_lens/model_bridge/supported_architectures/bert.py index 634e2975b..265bfb531 100644 --- a/transformer_lens/model_bridge/supported_architectures/bert.py +++ b/transformer_lens/model_bridge/supported_architectures/bert.py @@ -82,8 +82,9 @@ def __init__(self, cfg: Any) -> None: } # Set up component mapping - # The bridge loads BertForMaskedLM, so core model paths need the 'bert.' prefix. - # The MLM head (cls.predictions) is at the top level of BertForMaskedLM. + # Core model paths use the 'bert.' prefix. The head components (unembed, + # ln_final) are set to MLM defaults here and adjusted in prepare_model() + # if the actual HF model is a different task variant (e.g., NSP). self.component_mapping = { "embed": EmbeddingBridge(name="bert.embeddings.word_embeddings"), "pos_embed": PosEmbedBridge(name="bert.embeddings.position_embeddings"), @@ -125,3 +126,16 @@ def __init__(self, cfg: Any) -> None: name="cls.predictions.transform.LayerNorm", config=self.cfg ), } + + def prepare_model(self, hf_model: Any) -> None: + """Adjust component mapping based on the actual HF model variant. + + BertForMaskedLM has cls.predictions (MLM head). + BertForNextSentencePrediction has cls.seq_relationship (NSP head) + and no MLM-specific LayerNorm. + """ + if hasattr(hf_model, "cls") and hasattr(hf_model.cls, "seq_relationship"): + # NSP model — swap head components + assert self.component_mapping is not None + self.component_mapping["unembed"] = UnembeddingBridge(name="cls.seq_relationship") + self.component_mapping.pop("ln_final", None)