import math
from dataclasses import InitVar, dataclass, field
from typing import Any, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.experimental import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jaxtyping import Float

from tpu_inference import utils
from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
    ragged_paged_attention
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import RMSNorm
from tpu_inference.layers.jax.rope import GptOssRotaryEmbedding

KVCache = Tuple[jax.Array, jax.Array]

@dataclass(kw_only=True)
class GptOssAttention(nnx.Module):
    """
    JAX implementation of the GPT-OSS attention block.

    Array name suffixes follow einsum dimension labels: T = query tokens,
    S = key/value tokens, D = hidden_size, N = num_attention_heads,
    K = num_key_value_heads, H = head_dim.
    """
    hidden_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    dtype: jnp.dtype
    rngs: InitVar[nnx.Rngs]

    rope_theta: float
    initial_context_length: int = 4096
    rope_scaling_factor: float = 32.0
    rope_ntk_alpha: float = 1.0
    rope_ntk_beta: float = 32.0

    query_tnh: P = P()
    keyvalue_skh: P = P()
    attn_o_tnh: P = P()
    dnh_sharding: Sharding = ()
    dkh_sharding: Sharding = ()
    nhd_sharding: Sharding = ()
    n_sharding: Sharding = ()
    nh_sharding: Sharding = ()
    kh_sharding: Sharding = ()
    d_sharding: Sharding = ()

    random_init: bool = False
    mesh: Mesh

    def __post_init__(self, rngs: nnx.Rngs):
        """Initializes weights, biases, and the RoPE module."""
        # Dimension shorthands: D = hidden_size, N = num_attention_heads,
        # K = num_key_value_heads, H = head_dim.

        self.sm_scale = 1.0 / (self.head_dim**0.5)

        # Per-head attention sink logits (GPT-OSS adds a learned sink slot
        # to each head's softmax).
        self.sinks_N = create_param(
            rngs, shape=(self.num_attention_heads,), dtype=jnp.float32,
            sharding=self.n_sharding, random_init=self.random_init
        )

        # Q, K, V projection kernels and biases
        self.kernel_q_DNH = create_param(
            rngs, shape=(self.hidden_size, self.num_attention_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dnh_sharding, random_init=self.random_init
        )
        self.bias_q_NH = create_param(
            rngs, shape=(self.num_attention_heads, self.head_dim),
            dtype=self.dtype, sharding=self.nh_sharding, random_init=self.random_init
        )
        self.kernel_k_DKH = create_param(
            rngs, shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dkh_sharding, random_init=self.random_init
        )
        self.bias_k_KH = create_param(
            rngs, shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.kh_sharding, random_init=self.random_init
        )
        self.kernel_v_DKH = create_param(
            rngs, shape=(self.hidden_size, self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.dkh_sharding, random_init=self.random_init
        )
        self.bias_v_KH = create_param(
            rngs, shape=(self.num_key_value_heads, self.head_dim),
            dtype=self.dtype, sharding=self.kh_sharding, random_init=self.random_init
        )

        # Output projection kernel and bias
        self.kernel_o_proj_NHD = create_param(
            rngs, shape=(self.num_attention_heads, self.head_dim, self.hidden_size),
            dtype=self.dtype, sharding=self.nhd_sharding, random_init=self.random_init
        )
        self.bias_o_D = create_param(
            rngs, shape=(self.hidden_size,),
            dtype=self.dtype, sharding=self.d_sharding, random_init=self.random_init
        )

        # RoPE module with NTK-aware scaling parameters
        self.rope = GptOssRotaryEmbedding(
            head_dim=self.head_dim,
            rope_theta=self.rope_theta,
            dtype=self.dtype,
            initial_context_length=self.initial_context_length,
            rope_scaling_factor=self.rope_scaling_factor,
            rope_ntk_alpha=self.rope_ntk_alpha,
            rope_ntk_beta=self.rope_ntk_beta,
        )

    def attention(
        self,
        kv_cache: KVCache,
        q_TNH: jax.Array,
        k_SKH: jax.Array,
        v_SKH: jax.Array,
        sinks: jax.Array,
        attention_metadata: AttentionMetadata,
        mesh: Mesh,
    ) -> Tuple[KVCache, jax.Array]:
        """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
        md = attention_metadata
        kv_cache_spec = P(None, None, "model")

        in_specs = (
            self.query_tnh,     # q
            self.keyvalue_skh,  # k
            self.keyvalue_skh,  # v
            kv_cache_spec,      # kv_cache
            P(),                # md.seq_lens: replicated
            P(),                # page_indices_flat: replicated
            P(),                # query_start_loc: replicated
            P(),                # distribution: replicated
            P("model"),         # sinks: sharded over heads
        )
        out_specs = (self.attn_o_tnh, kv_cache_spec)

        def _ragged_paged_attention_wrapper(*args):
            # Pass the GPT-OSS specific parameters to the kernel
            return ragged_paged_attention(
                *args,
                sm_scale=self.sm_scale,
                sliding_window=md.sliding_window,
            )

        output_TNH, kv_cache = jax.jit(
            shard_map.shard_map(
                _ragged_paged_attention_wrapper,
                mesh=mesh,
                in_specs=in_specs,
                out_specs=out_specs,
                check_rep=False,
            ))(
                q_TNH,
                k_SKH,
                v_SKH,
                kv_cache,
                md.seq_lens,
                md.block_tables,
                md.query_start_loc,
                md.request_distribution,
                sinks,
            )
        return kv_cache, output_TNH

    def __call__(self,
                 x_TD,
                 is_prefill,
                 kv_cache: KVCache,
                 attention_metadata: AttentionMetadata,
                 use_attention_rope: bool = True):
        """Forward pass for the attention module using 3D projection kernels."""
        md = attention_metadata
        x_TD = jnp.asarray(x_TD, self.dtype)

        with jax.named_scope("q_proj"):
            q_TNH = jnp.einsum("TD,DNH->TNH", x_TD, self.kernel_q_DNH.value)
            q_TNH += self.bias_q_NH.value

        with jax.named_scope("k_proj"):
            k_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_k_DKH.value)
            k_TKH += self.bias_k_KH.value

        with jax.named_scope("v_proj"):
            v_TKH = jnp.einsum("TD,DKH->TKH", x_TD, self.kernel_v_DKH.value)
            v_TKH += self.bias_v_KH.value

        if use_attention_rope:
            q_TNH, k_TKH = self.rope(q_TNH, k_TKH, md.input_positions)

        with jax.named_scope("attn_op"):
            # Pad the head dim (H) of q, k, v up to the next multiple of 128
            padded_head_dim = ((self.head_dim - 1) // 128 + 1) * 128
            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.head_dim)))
            k_TKH = jnp.pad(k_TKH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.head_dim)))
            v_TKH = jnp.pad(v_TKH, ((0, 0), (0, 0),
                                    (0, padded_head_dim - self.head_dim)))
            new_kv_cache, attn_out_TNH = self.attention(
                kv_cache,
                q_TNH,
                k_TKH,
                v_TKH,
                self.sinks_N.value,
                md,
                self.mesh,
            )
            # Drop the head-dim padding before the output projection
            attn_out_TNH = attn_out_TNH[..., :self.head_dim]

        with jax.named_scope("o_proj"):
            output_TD = jnp.einsum("TNH,NHD->TD", attn_out_TNH, self.kernel_o_proj_NHD.value)
            output_TD += self.bias_o_D.value

        return new_kv_cache, output_TD
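

# Minimal usage sketch (commented out; not part of the module). It shows how this
# attention block might be constructed. The config numbers below are illustrative
# assumptions rather than values taken from this repository; `Mesh`, `nnx`, and
# `jnp` are the imports at the top of this file.
#
#     mesh = Mesh(jax.devices(), ("model",))
#     attn = GptOssAttention(
#         hidden_size=2880,
#         num_attention_heads=64,
#         num_key_value_heads=8,
#         head_dim=64,
#         dtype=jnp.bfloat16,
#         rngs=nnx.Rngs(0),
#         rope_theta=150000.0,
#         random_init=True,
#         mesh=mesh,
#     )
#     # new_kv_cache, out_TD = attn(x_TD, is_prefill, kv_cache, attention_metadata)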