33 changes: 33 additions & 0 deletions src/fairseq2/assets/cards/models/olmo2.yaml
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

name: olmo2
model_family: olmo2
checkpoint: "hg://allenai/OLMo-2-0425-1B"
tokenizer: "hg://allenai/OLMo-2-0425-1B"
tokenizer_family: olmo2

---

name: olmo2-0425-1b
base: olmo2
model_arch: olmo2-0425-1b

---

name: olmo2-1124-7b
base: olmo2
model_arch: olmo2-1124-7b
checkpoint: "hg://allenai/OLMo-2-1124-7B"
tokenizer: "hg://allenai/OLMo-2-1124-7B"

---

name: olmo2-1124-13b
base: olmo2
model_arch: olmo2-1124-13b
checkpoint: "hg://allenai/OLMo-2-1124-13B"
tokenizer: "hg://allenai/OLMo-2-1124-13B"
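
These asset cards map fairseq2 model names to the Hugging Face checkpoints and tokenizers they are loaded from; the olmo2-0425-1b, olmo2-1124-7b, and olmo2-1124-13b cards inherit model_family and tokenizer_family from the olmo2 base card and select a concrete model_arch plus checkpoint location. A rough usage sketch, assuming the hub object returned by get_olmo2_model_hub (added later in this PR) resolves these card names via a load method, which is an assumption mirroring other fairseq2 hubs rather than something shown in this diff:

from fairseq2.models.olmo2 import get_olmo2_model_hub

# Hypothetical: resolve the "olmo2-0425-1b" card above and materialize the
# model from the referenced "hg://allenai/OLMo-2-0425-1B" checkpoint.
hub = get_olmo2_model_hub()
model = hub.load("olmo2-0425-1b")  # method name assumed, not taken from this diff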
22 changes: 22 additions & 0 deletions src/fairseq2/composition/models.py
@@ -67,6 +67,13 @@
create_nllb_model,
register_nllb_configs,
)
from fairseq2.models.olmo2 import (
OLMO2_FAMILY,
OLMO2Config,
convert_olmo2_state_dict,
create_olmo2_model,
register_olmo2_configs,
)
from fairseq2.models.qwen import (
QWEN_FAMILY,
QwenConfig,
@@ -296,6 +303,21 @@ def _register_model_families(container: DependencyContainer) -> None:

register_nllb_configs(container)

# OLMo2
register_model_family(
container,
OLMO2_FAMILY,
kls=TransformerLM,
config_kls=OLMO2Config,
factory=create_olmo2_model,
state_dict_converter=convert_olmo2_state_dict,
compiler=compile_transformer_lm,
fsdp_applier=apply_fsdp_to_transformer_lm,
layerwise_ac_applier=apply_ac_to_transformer_lm,
)

register_olmo2_configs(container)

# Qwen
register_model_family(
container,
15 changes: 15 additions & 0 deletions src/fairseq2/composition/tokenizers.py
@@ -35,6 +35,12 @@
NllbTokenizerConfig,
load_nllb_tokenizer,
)
from fairseq2.models.olmo2 import (
OLMO2_FAMILY,
OlmoTokenizer,
OlmoTokenizerConfig,
load_olmo_tokenizer,
)
from fairseq2.models.qwen import (
QWEN_FAMILY,
QwenTokenizer,
@@ -163,6 +169,15 @@ def _register_tokenizer_families(container: DependencyContainer) -> None:
loader=load_nllb_tokenizer,
)

# OLMo2
register_tokenizer_family(
container,
OLMO2_FAMILY,
kls=OlmoTokenizer,
config_kls=OlmoTokenizerConfig,
loader=load_olmo_tokenizer,
)

# S2T Transformer
register_tokenizer_family(
container,
25 changes: 25 additions & 0 deletions src/fairseq2/models/olmo2/__init__.py
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.olmo2.attention import (
OLMO2MultiheadAttention as OLMO2MultiheadAttention,
)
from fairseq2.models.olmo2.config import OLMO2_FAMILY as OLMO2_FAMILY
from fairseq2.models.olmo2.config import OLMO2Config as OLMO2Config
from fairseq2.models.olmo2.config import (
register_olmo2_configs as register_olmo2_configs,
)
from fairseq2.models.olmo2.factory import OLMO2Factory as OLMO2Factory
from fairseq2.models.olmo2.factory import create_olmo2_model as create_olmo2_model
from fairseq2.models.olmo2.hub import get_olmo2_model_hub as get_olmo2_model_hub
from fairseq2.models.olmo2.interop import (
convert_olmo2_state_dict as convert_olmo2_state_dict,
)
from fairseq2.models.olmo2.tokenizer import OlmoTokenizer as OlmoTokenizer
from fairseq2.models.olmo2.tokenizer import OlmoTokenizerConfig as OlmoTokenizerConfig
from fairseq2.models.olmo2.tokenizer import load_olmo_tokenizer as load_olmo_tokenizer
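
Taken together, these re-exports form the public OLMo2 surface. A minimal sketch of exercising the factory path directly, assuming OLMO2Config is a config dataclass with usable defaults and that create_olmo2_model takes a config and returns a TransformerLM; both are inferred from the register_model_family call in composition/models.py above, not verified against the implementation:

from fairseq2.models.olmo2 import OLMO2Config, create_olmo2_model

# Build a randomly initialized OLMo2 model from the default config. Loading a
# named card goes through the same registered factory, plus the state-dict
# converter (convert_olmo2_state_dict) for Hugging Face checkpoints.
config = OLMO2Config()             # assumed: defaults describe a valid architecture
model = create_olmo2_model(config)
print(sum(p.numel() for p in model.parameters()))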
151 changes: 151 additions & 0 deletions src/fairseq2/models/olmo2/attention.py
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""OLMO2-specific attention module with Q/K normalization applied before reshaping.

Note: OLMO2MultiheadAttention inherits from StandardMultiheadAttention (marked @final)
because the only difference is the order of normalization in _project_q() and _project_kv().
Reimplementing the entire class would duplicate ~150 lines of boilerplate code for __init__,
projection setup, and forward logic. The type checker warning is therefore suppressed:
subclassing the @final class is a deliberate, narrowly scoped exception required by OLMO2's design.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from torch import Tensor
from typing_extensions import override

from fairseq2.data_type import DataType
from fairseq2.device import Device
from fairseq2.gang import Gangs
from fairseq2.models.transformer import StandardMultiheadAttention
from fairseq2.models.transformer.sdpa.base import SDPA
from fairseq2.nn import (
BatchLayout,
IncrementalStateBag,
LayerNorm,
Linear,
PositionEncoder,
Projection,
)


class OLMO2MultiheadAttention(StandardMultiheadAttention): # type: ignore[misc]
"""OLMO2 Multi-head Attention with Q/K normalization applied BEFORE reshaping.

The key difference from StandardMultiheadAttention is the order of operations:
- Standard: Project → Reshape → Normalize → RoPE
- OLMO2: Project → Normalize → Reshape → RoPE

This is why OLMO2's Q/K norm weights span the full projection width (e.g. [2048] for
the 1B model) instead of the per-head width [128] (head_dim).
"""

def __init__(
self,
model_dim: int,
num_heads: int,
sdpa: SDPA,
*,
head_dim: int | None = None,
num_key_value_heads: int | None = None,
kv_dim: int | None = None,
q_proj: Projection | None = None,
k_proj: Projection | None = None,
v_proj: Projection | None = None,
qkv_proj_init_fn: Callable[[Linear], None] | None = None,
q_norm: LayerNorm | None = None,
k_norm: LayerNorm | None = None,
pos_encoder: PositionEncoder | None = None,
output_proj: Projection | None = None,
output_proj_init_fn: Callable[[Linear], None] | None = None,
bias: bool = True,
output_proj_bias: bool | None = None,
state_factory: Any = None,
gangs: Gangs | None = None,
device: Device | None = None,
dtype: DataType | None = None,
) -> None:
"""Initialize OLMO2 Multi-head Attention.

All parameters are passed to StandardMultiheadAttention, but the normalization
order is different.
"""
super().__init__(
model_dim=model_dim,
num_heads=num_heads,
sdpa=sdpa,
head_dim=head_dim,
num_key_value_heads=num_key_value_heads,
kv_dim=kv_dim,
q_proj=q_proj,
k_proj=k_proj,
v_proj=v_proj,
qkv_proj_init_fn=qkv_proj_init_fn,
q_norm=q_norm,
k_norm=k_norm,
pos_encoder=pos_encoder,
output_proj=output_proj,
output_proj_init_fn=output_proj_init_fn,
bias=bias,
output_proj_bias=output_proj_bias,
state_factory=state_factory,
gangs=gangs,
device=device,
dtype=dtype,
)

@override
def _project_q(
self,
seqs: Tensor,
seqs_layout: BatchLayout,
state_bag: IncrementalStateBag | None = None,
) -> Tensor:
# (N, S, M) -> (N, S, K_proj)
q = self.q_proj(seqs)

# OLMO2-specific: Apply normalization BEFORE reshaping
if self.q_norm is not None:
q = self.q_norm(q)

# Reshape (N, S, K_proj) -> (N, S, H, K_h)
q = q.unflatten(-1, (-1, self.head_dim))

if self.pos_encoder is not None:
q = self.pos_encoder(q, seqs_layout, state_bag=state_bag)

return q

@override
def _project_kv(
self,
keys: Tensor,
keys_layout: BatchLayout,
values: Tensor,
state_bag: IncrementalStateBag | None = None,
) -> tuple[Tensor, Tensor]:
# (N, S, K) -> (N, S, K_proj)
k = self.k_proj(keys)
# (N, S, V) -> (N, S, V_proj)
v = self.v_proj(values)

# OLMO2-specific: Apply normalization BEFORE reshaping
if self.k_norm is not None:
k = self.k_norm(k)

# Reshape (N, S, K_proj) -> (N, S, H, K_h)
k = k.unflatten(-1, (-1, self.head_dim))
# Reshape (N, S, V_proj) -> (N, S, H, V_h)
v = v.unflatten(-1, (-1, self.head_dim))

if self.pos_encoder is not None:
k = self.pos_encoder(k, keys_layout, state_bag=state_bag)

return k, v
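
To make the docstring's shape argument concrete, here is a standalone PyTorch sketch, independent of fairseq2, using 1B-style sizes (model_dim=2048, 16 heads, head_dim=128) to contrast the two normalization orders. RMSNorm is used purely for illustration (it needs PyTorch >= 2.4); any elementwise norm shows the same weight-shape difference:

import torch
from torch import nn

model_dim, num_heads = 2048, 16
head_dim = model_dim // num_heads  # 128

q_proj = nn.Linear(model_dim, model_dim, bias=False)
x = torch.randn(2, 10, model_dim)  # (N, S, M)

# Standard order: project -> reshape -> normalize per head.
# The norm operates over head_dim, so its weight has shape [128].
norm_per_head = nn.RMSNorm(head_dim)
q_std = norm_per_head(q_proj(x).unflatten(-1, (num_heads, head_dim)))

# OLMo2 order: project -> normalize over the full projection -> reshape.
# The norm operates over model_dim, so its weight has shape [2048], which is
# why OLMo2 checkpoints carry full-width q_norm/k_norm weights.
norm_full = nn.RMSNorm(model_dim)
q_olmo = norm_full(q_proj(x)).unflatten(-1, (num_heads, head_dim))

print(norm_per_head.weight.shape)  # torch.Size([128])
print(norm_full.weight.shape)      # torch.Size([2048])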