Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 66 additions & 76 deletions examples/models/lfm2/short_conv.py
Original file line number Diff line number Diff line change
@@ -1,112 +1,102 @@
from typing import Optional
from __future__ import annotations

import torch
from executorch.examples.models.llama.attention import ForwardOptions
from executorch.examples.models.llama.feed_forward import FeedForward

from executorch.examples.models.llama.norm import RMSNorm
from torch import nn


class ShortConv(nn.Module):
def __init__(
self,
dim: int,
L_cache: int = 3,
bias: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""Depthwise short convolution with dual state management.

Supports two modes:
1. State-as-IO: caller passes conv_state in and receives new state back.
Required for AOTI which cannot re-trace mutable buffer mutations.
2. Internal buffer: uses register_buffer + copy_() for XNNPack/portable
backends where mutable buffers are handled natively.
"""

def __init__(self, dim: int, L_cache: int = 3, *, bias: bool = False) -> None:
super().__init__()
assert L_cache == 3, f"Manual depthwise conv only supports L_cache=3, got {L_cache}"
self.dim = dim
self.L_cache = L_cache
self.device = device
self.dtype = dtype
self.bias = bias

self.conv = nn.Conv1d(
dim,
dim,
kernel_size=L_cache,
padding=0, ## we don't need padding since we handle it manually
groups=dim,
bias=bias,
)

conv_state = torch.zeros(
1, ## batch size is assumed to be 1 for now
dim,
L_cache - 1,
device="cpu",
)
self.register_buffer("conv_state", conv_state)

## better performance in Executorch with separate projections
self.conv = nn.Conv1d(dim, dim, kernel_size=L_cache, padding=0, groups=dim, bias=bias)
self.B_proj = nn.Linear(dim, dim, bias=bias)
self.C_proj = nn.Linear(dim, dim, bias=bias)
self.x_proj = nn.Linear(dim, dim, bias=bias)

self.out_proj = nn.Linear(dim, dim, bias=bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size, seqlen, dim = x.size()
assert batch_size == 1, "batch_size must be 1"

B = self.B_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)
C = self.C_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)
x = self.x_proj(x).transpose(-1, -2) # (batch_size, dim, seq_len)

Bx = B * x # (batch_size, dim, seq_len)
self.register_buffer(
"conv_state",
torch.zeros(1, dim, L_cache - 1),
)

## This is where we handle padding
## By default, the conv_state is initialized to 0.
# So, assuming prefill is done on an empty cache, concatenating conv_state to the beginning of the sequence acts similarly to
## using nn.Conv1d(padding=L_cache-1) (for prefill) without manual padding.
## However, the manual padding has the added benefit of being correct during decode, when the cache is not initialized to 0.
Bx = torch.cat(
[self.conv_state, Bx], dim=-1
) # (batch_size, dim, seq_len + L_cache - 1)
def forward(
self, x: torch.Tensor, conv_state: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
if conv_state is None:
conv_state = self.conv_state

## Update the conv_state
new_conv_state = Bx[
..., -(self.L_cache - 1) :
] # (batch_size, dim, L_cache - 1)
with torch.no_grad():
self.conv_state.copy_(new_conv_state)
B = self.B_proj(x).transpose(-1, -2)
C = self.C_proj(x).transpose(-1, -2)
x = self.x_proj(x).transpose(-1, -2)

conv_out = self.conv(Bx)[..., : x.size(-1)] # (batch_size, dim, seq_len)
y = C * conv_out # (batch_size, dim, seq_len)
Bx = torch.cat([conv_state, B * x], dim=-1)
new_conv_state = Bx[..., -(self.L_cache - 1) :]

y = y.transpose(-1, -2) # (batch_size, seq_len, dim)
y = y.contiguous() # (batch_size, seq_len, dim)
y = self.out_proj(y) # (batch_size, seq_len, dim)
return y
# Manual depthwise conv — Triton has no template for nn.Conv1d
# with groups=dim and dynamic sequence length.
w = self.conv.weight[:, 0, :]
conv_out = Bx[..., :-2] * w[:, 0:1] + Bx[..., 1:-1] * w[:, 1:2] + Bx[..., 2:] * w[:, 2:3]

def reset_cache(self):
self.conv_state.zero_()
y = self.out_proj((C * conv_out).transpose(-1, -2).contiguous())
Comment on lines +50 to +55
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ShortConv.forward implements the convolution manually using self.conv.weight, but it ignores self.conv.bias when bias=True. This makes the bias argument silently incorrect. Either add the bias term to conv_out or enforce bias=False (e.g., via an assertion and/or by removing the parameter) to avoid surprising behavior.

Copilot uses AI. Check for mistakes.
return y, new_conv_state


class ShortConvBlock(nn.Module):
def __init__(self, dim: int, hidden_dim: int, norm_eps: float):
def __init__(self, dim: int, hidden_dim: int, norm_eps: float, layer_idx: int = -1) -> None:
super().__init__()
self.L_cache = 3 # hardcode 3 for now
self.conv = ShortConv(dim, self.L_cache, bias=False)
self.layer_idx = layer_idx
self.conv = ShortConv(dim, L_cache=3, bias=False)
self.feed_forward = FeedForward(dim, hidden_dim)
self.ffn_norm = RMSNorm(dim, norm_eps)
# use attention_norm instead of operator_norm to unify with TransformerBlock
self.attention_norm = RMSNorm(dim, norm_eps)

def forward(
self,
x,
freqs_cos=None,
freqs_sin=None,
_unused_attn_options: Optional[ForwardOptions] = None,
): # x: 1xN
h = self.conv.forward(self.attention_norm(x))
x: torch.Tensor,
freqs_cos: torch.Tensor | None = None,
freqs_sin: torch.Tensor | None = None,
attn_options: ForwardOptions | None = None,
) -> tuple[torch.Tensor, dict]:
# State-as-IO: read from attn_options if provided (CUDA/AOTI path)
conv_state = None
if attn_options is not None:
conv_states = attn_options.get("conv_states")
if conv_states is not None:
conv_state = conv_states.get(self.layer_idx)

h, new_conv_state = self.conv(self.attention_norm(x), conv_state)
h = x + h
out = h + self.feed_forward(self.ffn_norm(h))
return out, None

def reset_cache(self):
self.conv.reset_cache()
# Write back state
update: dict = {}
if attn_options is not None and "conv_states" in attn_options:
if conv_state is not None:
conv_state.copy_(new_conv_state)
states = dict(attn_options["conv_states"])
states[self.layer_idx] = new_conv_state
update["conv_states"] = states
Comment on lines +88 to +93
Copy link

Copilot AI Apr 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When attn_options contains conv_states, this block mutates the provided state via conv_state.copy_(...) but then stores new_conv_state (a freshly allocated tensor from cat) back into the returned conv_states dict. In Transformer._forward_layers, that returned dict is merged into attn_options_, so the next layer call will read a non-static tensor and can break the intended AOTI "static address" state path. Also, dict(attn_options["conv_states"]) will throw if the key exists but the value is None. Consider: (1) reading conv_states = attn_options.get("conv_states") and ensuring it’s a dict before copying, and (2) if conv_state is provided, keep that same tensor in the returned mapping (after the in-place update) rather than replacing it with new_conv_state.

Copilot uses AI. Check for mistakes.
else:
# XNNPack/portable path: persist via internal buffer
with torch.no_grad():
self.conv.conv_state.copy_(new_conv_state)

return out, update

def reset_cache(self) -> None:
self.conv.conv_state.zero_()
13 changes: 13 additions & 0 deletions examples/models/lfm2_5_vl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.lfm2_5_vl.convert_weights import convert_weights
from executorch.examples.models.lfm2_5_vl.model import Lfm2p5VlModel

__all__ = [
"convert_weights",
"Lfm2p5VlModel",
]
33 changes: 33 additions & 0 deletions examples/models/lfm2_5_vl/config/lfm2_5_vl_1_6b_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"dim": 2048,
"ffn_dim_multiplier": 1,
"hidden_dim": 8192,
"n_heads": 32,
"n_kv_heads": 8,
"n_layers": 16,
"norm_eps": 1e-5,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 65536,
"use_hf_rope": true,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"layer_types": [
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv"
]
}
33 changes: 33 additions & 0 deletions examples/models/lfm2_5_vl/config/lfm2_5_vl_450m_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"dim": 1024,
"ffn_dim_multiplier": 1,
"hidden_dim": 4608,
"n_heads": 16,
"n_kv_heads": 8,
"n_layers": 16,
"norm_eps": 1e-5,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 65536,
"use_hf_rope": true,
"use_qk_norm": true,
"qk_norm_before_rope": true,
"layer_types": [
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv",
"full_attention",
"conv"
]
}
81 changes: 81 additions & 0 deletions examples/models/lfm2_5_vl/convert_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Convert LFM2.5-VL text decoder weights from HuggingFace to ET format."""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from executorch.examples.models.checkpoint import get_mapped_key
from safetensors.torch import load_file

# Maps HuggingFace LFM2.5-VL language-model parameter names to Meta/ET
# checkpoint names. "{}" is a layer-index placeholder consumed by
# get_mapped_key. Keys not listed here (e.g. conv.in_proj) fall back to a
# simple prefix strip in lfm2_5_vl_to_meta.
_LFM2_5_VL_TO_META: dict[str, str] = {
    "model.language_model.embed_tokens.weight": "tok_embeddings.weight",
    "model.language_model.embedding_norm.weight": "norm.weight",
    "model.language_model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.language_model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.language_model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.language_model.layers.{}.self_attn.out_proj.weight": "layers.{}.attention.wo.weight",
    "model.language_model.layers.{}.self_attn.q_layernorm.weight": "layers.{}.attention.q_norm_fn.weight",
    "model.language_model.layers.{}.self_attn.k_layernorm.weight": "layers.{}.attention.k_norm_fn.weight",
    "model.language_model.layers.{}.operator_norm.weight": "layers.{}.attention_norm.weight",
    "model.language_model.layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
    "model.language_model.layers.{}.feed_forward.w1.weight": "layers.{}.feed_forward.w1.weight",
    "model.language_model.layers.{}.feed_forward.w2.weight": "layers.{}.feed_forward.w2.weight",
    "model.language_model.layers.{}.feed_forward.w3.weight": "layers.{}.feed_forward.w3.weight",
    "model.language_model.layers.{}.conv.conv.weight": "layers.{}.conv.conv.weight",
    "model.language_model.layers.{}.conv.out_proj.weight": "layers.{}.conv.out_proj.weight",
    "model.language_model.lm_head.weight": "output.weight",
}

# Names of the three projections the fused conv.in_proj weight is split into,
# in chunk order along dim 0 (see torch.chunk in lfm2_5_vl_to_meta).
_IN_PROJ_SPLITS = ("B_proj", "C_proj", "x_proj")


def lfm2_5_vl_to_meta(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Extract and remap language model weights from a full VL state dict.

    Only keys under ``model.language_model.`` are kept. Known keys are renamed
    via the mapping table; unknown keys keep their name minus the prefix. The
    fused ``conv.in_proj`` weight is split into three equal projections.
    """
    prefix = "model.language_model."
    remapped: dict[str, torch.Tensor] = {}

    for hf_key, tensor in state_dict.items():
        if not hf_key.startswith(prefix):
            # Vision-tower / projector weights are not part of the text decoder.
            continue

        try:
            meta_key = get_mapped_key(hf_key, _LFM2_5_VL_TO_META)
        except Exception:
            # Unmapped language-model key: keep it, minus the common prefix.
            meta_key = hf_key.removeprefix(prefix)

        if meta_key.endswith(".conv.in_proj.weight"):
            # Split the fused projection into B/C/x parts along dim 0.
            chunks = torch.chunk(tensor, 3, dim=0)
            for proj_name, chunk in zip(_IN_PROJ_SPLITS, chunks):
                remapped[meta_key.replace("in_proj", proj_name)] = chunk
        else:
            remapped[meta_key] = tensor

    # Tied embeddings: if no explicit lm_head was present, reuse the
    # token-embedding matrix as the output head.
    if "output.weight" not in remapped:
        remapped["output.weight"] = remapped["tok_embeddings.weight"]

    return remapped


def convert_weights(input_dir: str, output_file: str) -> None:
    """Convert a HuggingFace safetensors checkpoint to an ET ``.pt`` file.

    Reads ``model.safetensors`` from *input_dir*, remaps the language-model
    weights, and serializes the result to *output_file* with ``torch.save``.
    """
    raw_state = load_file(str(Path(input_dir) / "model.safetensors"))
    converted = lfm2_5_vl_to_meta(raw_state)
    torch.save(converted, output_file)
    print(f"Saved {len(converted)} tensors to {output_file}")


def main() -> None:
    """CLI entry point: parse the two positional paths and run the conversion."""
    parser = argparse.ArgumentParser(description="Convert LFM2.5-VL weights to ET format.")
    parser.add_argument("input_dir", help="Directory containing model.safetensors.")
    parser.add_argument("output", help="Output .pt checkpoint path.")
    ns = parser.parse_args()
    convert_weights(ns.input_dir, ns.output)


if __name__ == "__main__":
main()
Loading
Loading