Skip to content

Commit

Permalink
Add option to load checkpoints with transposed Gating Einsum.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 734225807
  • Loading branch information
casaro authored and Flax Authors committed Mar 6, 2025
1 parent bf4c322 commit ceb7aed
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion examples/gemma/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Iterable
import dataclasses
import enum
import functools
from typing import Any

from flax import nnx
Expand Down Expand Up @@ -65,6 +66,7 @@ class TransformerConfig:
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
)
attn_logits_soft_cap: float | None = None
transpose_gating_einsum: bool = False
local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
use_qk_norm: bool = False
Expand Down Expand Up @@ -205,6 +207,7 @@ def gemma_9b(cls):


def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]:
"""Maps linen variable names to nnx variable names."""
new_key = []
for k in key:
if k.startswith('layer_'):
Expand All @@ -228,8 +231,12 @@ def _assign_linen_params_to_nnx_state(
state: dict[tuple[str, ...], Any],
mapped_path: tuple[str | int, ...],
val: Any,
transpose_gating_einsum: bool,
) -> dict[tuple[str, ...], Any]:
"""Splits and maybe transposes gate_proj."""
if 'gate_proj' in mapped_path:
if transpose_gating_einsum:
val = jnp.swapaxes(val, 1, 2)
state[mapped_path].value = val[0]
state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1]
else:
Expand All @@ -246,11 +253,15 @@ def from_params(
) -> Transformer:
if config is None:
config = TransformerConfig.from_params(params)
assign_val_fn = functools.partial(
_assign_linen_params_to_nnx_state,
transpose_gating_einsum=config.transpose_gating_einsum,
)
return helpers.module_from_linen_variables(
module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
variables=params['transformer'],
map_key_fn=_map_linen_var_names,
assign_val_fn=_assign_linen_params_to_nnx_state,
assign_val_fn=assign_val_fn,
)

def __init__(
Expand Down

0 comments on commit ceb7aed

Please sign in to comment.