
Commit 6242fc3

bzgoogle authored and kyuyeunk committed
[GPT-OSS] Initial draft of all blocks and model.py needed for GPT-OSS
1 parent 8c7e7bb commit 6242fc3

6 files changed: +772 additions, -3 deletions
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
import math
from dataclasses import InitVar, dataclass, field
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.experimental import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jaxtyping import Float

from tpu_inference import utils
from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
    ragged_paged_attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import RMSNorm
from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

KVCache = Tuple[jax.Array, jax.Array]


@dataclass(kw_only=True)
class GptOssAttention(nnx.Module):
    """JAX implementation of the GPT-OSS attention block."""
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    dtype: jnp.dtype
    rngs: InitVar[nnx.Rngs]

    rope_theta: float
    initial_context_length: int = 4096
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

    query_tnh: P = P()
    keyvalue_skh: P = P()
    attn_o_tnh: P = P()
    dnh_sharding: Sharding = ()
    dkh_sharding: Sharding = ()
    nhd_sharding: Sharding = ()
    n_sharding: Sharding = ()
    nh_sharding: Sharding = ()
    kh_sharding: Sharding = ()
    d_sharding: Sharding = ()

    random_init: bool = False
    mesh: Mesh

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes weights, biases, and the RoPE module."""
        # Dimension names: D = hidden_size, N = num_attention_heads,
        # K = num_key_value_heads, H = head_dim.

        self.sm_scale = 1.0 / (self.head_dim ** 0.5)

        self.sinks_N = create_param(
            rngs, shape=(self.num_attention_heads,), dtype=jnp.float32,
            sharding=self.n_sharding, random_init=self.random_init
        )

        # Q, K, V projection kernels
        self.kernel_q_DNH = create_param(
            rngs, shape=(self.hidden_size, self.num_attention_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dnh_sharding, random_init=self.random_init
        )
        self.bias_q_NH = create_param(
            rngs, shape=(self.num_attention_heads, self.head_dim),
            dtype=self.dtype, sharding=self.nh_sharding, random_init=self.random_init
        )
        self.kernel_k_DKH = create_param(
            rngs, shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dkh_sharding, random_init=self.random_init
        )
        self.bias_k_KH = create_param(
            rngs, shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.kh_sharding, random_init=self.random_init
        )
        self.kernel_v_DKH = create_param(
            rngs, shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dkh_sharding, random_init=self.random_init
        )
        self.bias_v_KH = create_param(
            rngs, shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.kh_sharding, random_init=self.random_init
        )
        # Output projection kernel
        self.kernel_o_proj_NHD = create_param(
            rngs, shape=(self.num_attention_heads, self.head_dim, self.hidden_size),
            dtype=self.dtype, sharding=self.nhd_sharding, random_init=self.random_init
        )
        self.bias_o_D = create_param(
            rngs, shape=(self.hidden_size,),
            dtype=self.dtype, sharding=self.d_sharding, random_init=self.random_init
        )

        # RoPE module
        self.rope = GptOssRotaryEmbedding(
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=self.dtype,
            initial_context_length=self.initial_context_length,
            rope_scaling_factor=self.rope_scaling_factor,
            rope_ntk_alpha=self.rope_ntk_alpha,
            rope_ntk_beta=self.rope_ntk_beta
        )

    def attention(
        self,
        kv_cache: KVCache,
        q_TNH: jax.Array,
        k_SKH: jax.Array,
        v_SKH: jax.Array,
        sinks: jax.Array,
        attention_metadata: AttentionMetadata,
        mesh: Mesh,
    ) -> Tuple[KVCache, jax.Array]:
        """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
        md = attention_metadata
        kv_cache_spec = P(None, None, "model")

        in_specs = (
            self.query_tnh,     # q
            self.keyvalue_skh,  # k
            self.keyvalue_skh,  # v
            kv_cache_spec,      # kv_cache
            P(),                # md.seq_lens: replicated
            P(),                # page_indices_flat: replicated
            P(),                # query_start_loc: replicated
            P(),                # distribution: replicated
            P("model"),         # sinks
        )
        out_specs = (self.attn_o_tnh, kv_cache_spec)

        def _ragged_paged_attention_wrapper(*args):
            # Pass the GPT-OSS-specific parameters to the kernel.
            return ragged_paged_attention(
                *args,
                sm_scale=self.sm_scale,
                sliding_window=md.sliding_window,
            )

        output_TNH, kv_cache = jax.jit(
            shard_map.shard_map(
                _ragged_paged_attention_wrapper,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
                check_rep=False,
            ))(
                q_TNH,
                k_SKH,
                v_SKH,
                kv_cache,
                md.seq_lens,
                md.block_tables,
                md.query_start_loc,
                md.request_distribution,
                sinks,
            )
        return kv_cache, output_TNH

    def __call__(self,
                 x_TD,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Forward pass for the attention module using 3D kernels."""
        md = attention_metadata
        x_TD = jnp.asarray(x_TD, self.dtype)

        with jax.named_scope("q_proj"):
            q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, self.kernel_q_DNH.value)
            q_TNH += self.bias_q_NH.value

        with jax.named_scope("k_proj"):
            k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_k_DKH.value)
            k_TKH += self.bias_k_KH.value

        with jax.named_scope("v_proj"):
            v_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_v_DKH.value)
            v_TKH += self.bias_v_KH.value

        if use_attention_rope:
            q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)

        with jax.named_scope("attn_op"):
            # Pad the H dim of q, k, v up to the next multiple of 128.
            multiple_of_128 = ((self.head_dim - 1) // 128 + 1) * 128
            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
                                    (0, multiple_of_128 - self.head_dim)))
            k_TKH = jnp.pad(k_TKH, ((0, 0), (0, 0),
                                    (0, multiple_of_128 - self.head_dim)))
            v_TKH = jnp.pad(v_TKH, ((0, 0), (0, 0),
                                    (0, multiple_of_128 - self.head_dim)))
            new_kv_cache, attn_out_TNH = self.attention(
                kv_cache,
                q_TNH,
                k_TKH,
                v_TKH,
                self.sinks_N.value,
                md,
                self.mesh
            )
            attn_out_TNH = attn_out_TNH[..., :self.head_dim]

        with jax.named_scope("o_proj"):
            output_TD = jnp.einsum("TNH,NHD->TD", attn_out_TNH, self.kernel_o_proj_NHD.value)
            output_TD += self.bias_o_D.value

        return new_kv_cache, output_TD
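
For readers following the attention block above, here is a minimal, self-contained sketch of the head-dim padding pattern used inside the attn_op scope. It uses only jax.numpy; the pad_head_dim helper name and the shapes are illustrative, not part of this commit.

import jax.numpy as jnp

def pad_head_dim(x_TNH, head_dim: int):
    # Round head_dim up to the next multiple of 128, mirroring the
    # arithmetic in GptOssAttention.__call__ above.
    padded = ((head_dim - 1) // 128 + 1) * 128
    return jnp.pad(x_TNH, ((0, 0), (0, 0), (0, padded - head_dim)))

# Illustrative shapes: 4 tokens, 8 heads, head_dim 64 -> padded to 128.
q = jnp.ones((4, 8, 64))
q_padded = pad_head_dim(q, head_dim=64)
assert q_padded.shape == (4, 8, 128)

# After the kernel runs, the output is sliced back to the true head_dim,
# as done with attn_out_TNH[..., :self.head_dim] above.
q_unpadded = q_padded[..., :64]
assert q_unpadded.shape == q.shape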
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import enum
from dataclasses import InitVar, dataclass
from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.sharding import PartitionSpec
from jaxtyping import Float
from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
from qwix._src.providers import ptq

from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import FlaxUtils
from tpu_inference.layers.jax.moe.moe import MoE, Router
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    manually_quantize_qwix_activation, manually_quantize_qwix_weight)

modeling_flax_utils = FlaxUtils()


@dataclass(kw_only=True)
class GptOssRouter(Router):
    """Router module for Mixture-of-Experts (MoE) layers.

    Determines which experts each token is routed to, based on the input.
    """
    e_sharding: Sharding = ()

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes the parent's kernel and adds the new bias parameter."""
        super().__post_init__(rngs)

        self.bias_E = create_param(rngs,
                                   shape=(self.num_experts,),
                                   dtype=self.dtype,
                                   sharding=self.e_sharding,
                                   random_init=self.random_init)

    def __call__(self, x_TD: Float):
        """Overrides the parent's forward pass to include the bias."""
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD, self.kernel_DE.value)
        router_logits_TE += self.bias_E.value

        weights_TX, selected_experts_TX = jax.lax.top_k(
            router_logits_TE, self.num_experts_per_tok)

        normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype), axis=-1)

        return normalized_weights_TX, selected_experts_TX


def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
    """Implements the specific SwiGLU variant from the golden implementation."""
    # Even channels gate, odd channels pass through linearly.
    x_glu, x_linear = x[..., ::2], x[..., 1::2]

    x_glu = jnp.clip(x_glu, None, limit)          # clip from above only
    x_linear = jnp.clip(x_linear, -limit, limit)  # clip on both sides

    gated_activation = x_glu * jax.nn.sigmoid(alpha * x_glu)

    return gated_activation * (x_linear + 1)


@dataclass(kw_only=True)
class GptOssMoE(nnx.Module):
    """JAX implementation of the GPT-OSS Mixture-of-Experts MLP block."""
    dtype: jnp.dtype
    hidden_size: int
    intermediate_size_moe: int
    num_local_experts: int
    router: GptOssRouter
    rngs: InitVar[nnx.Rngs]

    swiglu_limit: float = 7.0
    swiglu_alpha: float = 1.702

    # Sharding specifications
    activation_ffw_td: Sharding
    edf_sharding: Sharding
    efd_sharding: Sharding
    ed_sharding: Sharding

    random_init: bool = False

    def __call__(self, x_TD: Float) -> Float:
        """Performs the forward pass for the GPT-OSS MoE layer."""
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        weights_TX, indices_TX = self.router(x_TD)

        one_hot_mask_TXE = jax.nn.one_hot(indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
        combined_weights_TE = jnp.sum(one_hot_mask_TXE * weights_TX[..., None], axis=1)

        # First MLP layer (up-projection)
        with jax.named_scope("MLP #1"):
            up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD, self.mlp1_weight_EDF2.value)
            up_proj_TEF2 += self.mlp1_bias_EF2.value

        fuse_TEF = _swiglu(up_proj_TEF2, alpha=self.swiglu_alpha, limit=self.swiglu_limit)

        # Second MLP layer (down-projection)
        with jax.named_scope("MLP #2"):
            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF, self.mlp2_weight_EFD.value)
            down_proj_TED += self.mlp2_bias_ED.value

        # Weighted sum of expert outputs
        with jax.named_scope("sum"):
            output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, combined_weights_TE)

        return output_TD.astype(self.dtype)

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes all weights and biases for the MoE block."""
        D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts

        # MLP #1 weights (combined gate and up-projection) and bias
        self.mlp1_weight_EDF2 = create_param(
            rngs, shape=(E, D, F * 2), dtype=self.dtype,
            sharding=self.edf_sharding, random_init=self.random_init
        )
        self.mlp1_bias_EF2 = create_param(
            rngs, shape=(E, F * 2), dtype=self.dtype,
            sharding=self.ed_sharding, random_init=self.random_init
        )

        # MLP #2 weights (down-projection) and bias
        self.mlp2_weight_EFD = create_param(
            rngs, shape=(E, F, D), dtype=self.dtype,
            sharding=self.efd_sharding, random_init=self.random_init
        )
        self.mlp2_bias_ED = create_param(
            rngs, shape=(E, D), dtype=self.dtype,
            sharding=self.ed_sharding, random_init=self.random_init
        )
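
As a sanity check on the _swiglu variant above (interleaved gate/linear channels, a one-sided clip on the gate, and the (x_linear + 1) term), the following standalone snippet reproduces the same computation with plain JAX. The swiglu_variant name and the toy shapes are illustrative only, not part of this commit.

import jax
import jax.numpy as jnp

def swiglu_variant(x, alpha=1.702, limit=7.0):
    # Interleaved layout: even channels gate, odd channels are linear,
    # matching _swiglu above.
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    x_glu = jnp.clip(x_glu, None, limit)          # clip from above only
    x_linear = jnp.clip(x_linear, -limit, limit)  # clip on both sides
    return x_glu * jax.nn.sigmoid(alpha * x_glu) * (x_linear + 1)

# Toy check: one token, one expert, F*2 = 8 interleaved channels.
x = jnp.arange(8, dtype=jnp.float32).reshape(1, 1, 8) - 4.0
y = swiglu_variant(x)
print(y.shape)  # (1, 1, 4) -- half the last dimension, like fuse_TEF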

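The router-to-MoE handoff (top-k, softmax over the selected logits, then scattering the per-token weights back to a dense [T, E] matrix) can be exercised in isolation as sketched below; the sizes T, E, X and the PRNG seed are arbitrary, chosen only for illustration.

import jax
import jax.numpy as jnp

# Hypothetical sizes: 3 tokens (T), 4 experts (E), top-2 routing (X).
T, E, X = 3, 4, 2
logits_TE = jax.random.normal(jax.random.PRNGKey(0), (T, E))

# Top-k then softmax over the selected logits, as in GptOssRouter.__call__.
weights_TX, indices_TX = jax.lax.top_k(logits_TE, X)
weights_TX = jax.nn.softmax(weights_TX, axis=-1)

# Scatter the weights back to a dense [T, E] matrix, as in GptOssMoE.__call__;
# unselected experts get weight 0 and contribute nothing to the final
# einsum('TED,TE -> TD', ...).
one_hot_TXE = jax.nn.one_hot(indices_TX, E)
combined_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
print(combined_TE.sum(axis=-1))  # each row sums to 1.0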