23 commits
26867ba
Test GPT_OSS files through porter
laxmareddyp Sep 5, 2025
f1c055b
generate API and moved files to respective folders
laxmareddyp Sep 6, 2025
d4da96c
Fix format issues
laxmareddyp Sep 6, 2025
b14cfb5
Add gpt_oss to preset loader and Fix format issues
laxmareddyp Sep 6, 2025
b675610
Add gpt_oss to preset loader
laxmareddyp Sep 6, 2025
8cf71ce
generated files through 2.5-pro model
laxmareddyp Sep 8, 2025
2242ef4
Format fix
laxmareddyp Sep 10, 2025
eb25d19
Add converter, RoPE update
laxmareddyp Sep 11, 2025
ba50a9f
Fix format
laxmareddyp Sep 11, 2025
1854d80
Fix BPE tests
laxmareddyp Sep 12, 2025
76139cd
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 12, 2025
00ec305
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 13, 2025
9447990
Update converter
laxmareddyp Sep 13, 2025
340aa85
Fix converter, checkpoints conversion and attention
laxmareddyp Sep 13, 2025
b02cfea
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 24, 2025
47dcdda
Fix the parameter count and debug code
laxmareddyp Sep 24, 2025
5e16f80
Add dequantization logic to converter
laxmareddyp Sep 25, 2025
79c5664
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Oct 9, 2025
59b6930
Add YaRN support, fix serialisation, fix dequantization
laxmareddyp Oct 9, 2025
8d3a658
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Nov 11, 2025
d9396c6
Fixed several pytest tests
laxmareddyp Nov 11, 2025
4a63e85
Address gpt_oss_causal_lm tests
laxmareddyp Nov 12, 2025
285253f
Fix format issues
laxmareddyp Nov 12, 2025
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -322,6 +322,18 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
GptOssBackbone as GptOssBackbone,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
GptOssCausalLM as GptOssCausalLM,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
HGNetV2Backbone as HGNetV2Backbone,
)
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -47,6 +47,9 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.llama.llama_tokenizer import (
LlamaTokenizer as LlamaTokenizer,
)
5 changes: 5 additions & 0 deletions keras_hub/src/models/gpt_oss/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, GptOssBackbone)
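Once registered, these presets resolve through the standard KerasHub `from_preset` constructors. As a minimal sketch (the preset name below is a hypothetical placeholder, not one defined in this PR):

import keras_hub

# "gpt_oss_placeholder" is illustrative; real names come from gpt_oss_presets.py.
backbone = keras_hub.models.GptOssBackbone.from_preset("gpt_oss_placeholder")
tokenizer = keras_hub.models.GptOssTokenizer.from_preset("gpt_oss_placeholder")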
313 changes: 313 additions & 0 deletions keras_hub/src/models/gpt_oss/gpt_oss_attention.py
@@ -0,0 +1,313 @@
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


class CachedGptOssAttention(keras.layers.Layer):
"""A cached attention layer for GPT-OSS with sink tokens and sliding window.

This layer implements the attention mechanism for the GPT-OSS model,
including grouped query attention (GQA), rotary positional embeddings
(RoPE), and special handling for "sink" tokens, which are added to the
attention logits before the softmax. It also supports caching for
efficient generation.

Args:
num_query_heads: Number of attention heads for queries.
num_key_value_heads: Number of attention heads for keys and values.
If `num_query_heads != num_key_value_heads`, grouped query attention
is used.
rope_max_wavelength: The maximum wavelength for the rotary embedding.
rope_scaling_factor: Scaling factor for rotary embeddings.
kernel_initializer: Initializer for the dense layer kernels.
sliding_window: The size of the sliding window for attention.
Tokens outside this window are masked. This parameter is stored
for configuration; the actual masking is expected to be applied
via the `attention_mask` input.
dropout: Dropout rate for attention probabilities.
use_bias: Whether to include bias terms in the dense projections.
**kwargs: Additional keyword arguments passed to the base Layer class.
"""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
sliding_window=4096,
dropout=0,
use_bias=False,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.dropout = dropout
self.use_bias = use_bias

if self.num_query_heads % self.num_key_value_heads != 0:
raise ValueError(
f"num_query_heads({self.num_query_heads})must be divisible by"
f"num_key_value_heads ({self.num_key_value_heads})"
)
self.num_key_value_groups = (
self.num_query_heads // self.num_key_value_heads
)
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_factor = rope_scaling_factor

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
Collaborator comment: what is model dim?
# u = num query heads
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
self._head_dim = self._hidden_dim // self.num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="q_proj",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="k_proj",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="v_proj",
)
self.value_dense.build(inputs_shape)

stddev = (
self._kernel_initializer.stddev
if hasattr(self._kernel_initializer, "stddev")
else 0.02
)
self.sinks = self.add_weight(
name="sinks",
shape=(self.num_query_heads,),
initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=stddev
),
dtype=self.dtype_policy,
)

self.softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
use_bias=self.use_bias,
dtype=self.dtype_policy,
name="o_proj",
)
self.output_dense.build(
(None, None, self.num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)

query = self.query_dense(hidden_states)

# Compute RoPE for queries
query = self.rotary_embedding_layer(query, start_index=start_index)

def _compute_key_value(x):
key, value = self.key_dense(x), self.value_dense(x)
# Compute RoPE for keys
key = self.rotary_embedding_layer(key, start_index=start_index)
return key, value

if cache is not None:
Collaborator comment: cache logic for KerasHub is located in the causal_lm file, e.g. `def _build_cache(self, token_ids):`.
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
key, value = _compute_key_value(hidden_states)
if self.num_key_value_groups > 1:
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask, training=training
)

attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _use_fused_attention_op(self):
# GPT-OSS attention includes "sink" tokens which are added to the logits
# before softmax. The Keras `ops.dot_product_attention` does not support
# this custom modification to the logits. Therefore, we must use the
# manual attention calculation path.
return False
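
# Note: the sink mechanism appends one extra learnable logit per head
# before the softmax, so probabilities are normalized over key_len + 1
# entries and the sink column is then discarded. The kept attention
# weights therefore sum to less than 1, which a fused attention op
# cannot reproduce.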

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
):
# `_use_fused_attention_op` is explicitly False for this layer
# due to the sink token mechanism.

# 1. Calculate raw attention scores
attention_scores = ops.einsum(self._dot_product_equation, query, key)
attention_scores = ops.multiply(
attention_scores,
ops.cast(self._inv_norm_factor, self.compute_dtype),
)

# 2. Apply attention mask (if any)
if attention_mask is not None:
if ops.ndim(attention_mask) == 3:
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_scores = attention_scores + attention_mask

# 3. Prepare and concatenate sink tokens
# sinks shape: (num_query_heads,)
sinks_expanded = ops.reshape(
self.sinks, (1, self.num_query_heads, 1, 1)
)
# The attention_scores shape is (batch, num_heads, query_len, key_len)
sinks_expanded = ops.broadcast_to(
sinks_expanded, ops.shape(attention_scores)[:-1] + (1,)
)

# Concatenate attention scores with sinks along the last dimension
# Resulting shape: (batch, num_query_heads, query_len, key_len + 1)
combined_logits = ops.concatenate(
[attention_scores, sinks_expanded], axis=-1
)

# 4. Apply numerical stability clamping before softmax
max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
combined_logits = combined_logits - max_logits

# 5. Apply softmax
# Softmax is applied to the combined logits (scores + sinks)
probs = self.softmax(combined_logits) # self.softmax is float32

# 6. Drop the sink token probability to get final attention weights
# scores = probs[..., :-1]
scores = ops.slice(
probs,
[0, 0, 0, 0],
ops.shape(probs)[:-1] + (ops.shape(probs)[-1] - 1,),
)

# 7. Cast to compute_dtype (dropout is handled outside this method)
attention_weights = ops.cast(scores, self.compute_dtype)

# 8. Compute weighted sum of values
attention_output = ops.einsum(
self._combine_equation, attention_weights, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self.sliding_window,
"dropout": self.dropout,
"use_bias": self.use_bias,
}
)
return config
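
For reference, a minimal usage sketch of the layer as defined above; the shapes and hyperparameters are illustrative assumptions, not values taken from this PR:

import numpy as np

from keras_hub.src.models.gpt_oss.gpt_oss_attention import CachedGptOssAttention

# Illustrative sizes only.
batch, seq_len, hidden_dim = 2, 8, 64
layer = CachedGptOssAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    sliding_window=4096,
)
hidden_states = np.random.rand(batch, seq_len, hidden_dim).astype("float32")
# Without a cache the layer returns only the attention output,
# shaped (batch, seq_len, hidden_dim).
output = layer(hidden_states)

When a cache is supplied, the code above expects it stacked as (batch, 2, key_len, num_key_value_heads, head_dim) and returns an (attention_output, cache) tuple.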