Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/speculators/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def attach_verifier(
self,
verifier: str | os.PathLike | PreTrainedModel,
mode: Literal["full", "train_only"] | None = None,
) -> PreTrainedModel:
):
"""
Attach a verifier model to the speculator. The verifier is the model the
speculator attaches to for running inference/training, and it validates the
Expand Down Expand Up @@ -417,14 +417,13 @@ def attach_verifier(
"Must be one of 'full', 'train_only', or None."
)

verifier = self.resolve_verifier(verifier)
self.verifier_attachment_mode = mode or "full"
self.verifier = (
verifier if self.verifier_attachment_mode == "full" else None
self.resolve_verifier(verifier)
if self.verifier_attachment_mode == "full"
else None
) # Expect subclasses to handle references if train_only

return verifier

def detach_verifier(self):
"""
Removes the reference to the attached verifier model and frees up the
Expand Down
30 changes: 15 additions & 15 deletions src/speculators/models/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
import re
import warnings
from typing import Any, ClassVar, Literal
from typing import Any, ClassVar, Literal, cast

import torch
from pydantic import Field, field_serializer, field_validator, model_validator
Expand Down Expand Up @@ -308,7 +308,7 @@ def attach_verifier(
self,
verifier: str | os.PathLike | PreTrainedModel,
mode: Literal["full", "train_only"] | None = None,
) -> PreTrainedModel:
):
"""
Attach a verifier model to the EagleSpeculator for speculative decoding.
Utilizes the verifier's embed_tokens, rotary_emb, and lm_head layers
Expand Down Expand Up @@ -349,25 +349,25 @@ def attach_verifier(
perform generation until a full verifier is attached.
:return: None. The verifier's layers are attached to this speculator in place.
"""
verifier = super().attach_verifier(
verifier=verifier,
mode=mode,
)
super().attach_verifier(verifier=verifier, mode=mode)

# Extract layers from the verifier model
if self.verifier_attachment_mode == "train_only":
verifier_model = self.resolve_verifier(verifier)
elif self.verifier_attachment_mode == "full":
verifier_model = cast("PreTrainedModel", self.verifier)
else:
return

if hasattr(verifier, "model"):
self.embed_tokens = verifier.model.embed_tokens # type: ignore[assignment,union-attr]
self.rotary_emb = verifier.model.rotary_emb # type: ignore[assignment,union-attr]
if hasattr(verifier_model, "model"):
self.embed_tokens = verifier_model.model.embed_tokens # type: ignore[assignment,union-attr]
self.rotary_emb = verifier_model.model.rotary_emb # type: ignore[assignment,union-attr]
else:
# Bare model structure
self.embed_tokens = verifier.embed_tokens # type: ignore[assignment,attr-defined]
self.rotary_emb = verifier.rotary_emb # type: ignore[assignment,attr-defined]
self.embed_tokens = verifier_model.embed_tokens # type: ignore[assignment,attr-defined]
self.rotary_emb = verifier_model.rotary_emb # type: ignore[assignment,attr-defined]

# lm_head is always at the top level of the verifier
self.lm_head = verifier.lm_head # type: ignore[assignment,attr-defined]

return verifier
self.lm_head = verifier_model.lm_head # type: ignore[assignment,attr-defined]

def detach_verifier(self):
"""
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,15 @@ def test_speculator_model_attach_verifier_invalid(
):
model.attach_verifier(123) # type: ignore[arg-type]

model = SpeculatorTestModel(config=speculator_model_test_config)
# Invalid attachment mode
with pytest.raises(
ValueError, match="Invalid verifier_attachment_mode: invalid_mode"
):
model.attach_verifier(verifier=None, mode="invalid_mode") # type: ignore[arg-type]

# Attaching when not in detached mode
model = SpeculatorTestModel(config=speculator_model_test_config)
model.verifier_attachment_mode = "full"
with pytest.raises(
RuntimeError,
Expand Down