Commits
23 commits
26867ba
Test GPT_OSS files through porter
laxmareddyp Sep 5, 2025
f1c055b
generate API and moved files to respective folders
laxmareddyp Sep 6, 2025
d4da96c
Fix format issues
laxmareddyp Sep 6, 2025
b14cfb5
Add gpt_oss to preset loader and Fix format issues
laxmareddyp Sep 6, 2025
b675610
Add gpt_oss to preset loader
laxmareddyp Sep 6, 2025
8cf71ce
generated files through 2.5-pro model
laxmareddyp Sep 8, 2025
2242ef4
Format fix
laxmareddyp Sep 10, 2025
eb25d19
Add converter, RoPE update
laxmareddyp Sep 11, 2025
ba50a9f
Fix format
laxmareddyp Sep 11, 2025
1854d80
Fix BPE tests
laxmareddyp Sep 12, 2025
76139cd
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 12, 2025
00ec305
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 13, 2025
9447990
Update converter
laxmareddyp Sep 13, 2025
340aa85
Fix converter, checkpoints conversion and attention
laxmareddyp Sep 13, 2025
b02cfea
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Sep 24, 2025
47dcdda
Fix the parameter count and debug code
laxmareddyp Sep 24, 2025
5e16f80
Add dequantization logic to converter
laxmareddyp Sep 25, 2025
79c5664
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Oct 9, 2025
59b6930
Add YaRN support,Fix Serialisation,Fix dequantization
laxmareddyp Oct 9, 2025
8d3a658
Merge branch 'keras-team:master' into test_gpt_oss_model
laxmareddyp Nov 11, 2025
d9396c6
Fixed several pytest tests
laxmareddyp Nov 11, 2025
4a63e85
Address gpt_oss_causal_lm tests
laxmareddyp Nov 12, 2025
285253f
Fix format issues
laxmareddyp Nov 12, 2025
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -322,6 +322,18 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import (
GptOssBackbone as GptOssBackbone,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm import (
GptOssCausalLM as GptOssCausalLM,
)
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
GptOssCausalLMPreprocessor as GptOssCausalLMPreprocessor,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.hgnetv2.hgnetv2_backbone import (
HGNetV2Backbone as HGNetV2Backbone,
)
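keras_hub's public namespace is generated from these api files, so once the API is regenerated the new classes should be reachable from keras_hub.models. A minimal sketch, assuming the usual keras_hub.models re-export (not part of this diff):

import keras_hub

# Public handles added by this change, assuming keras_hub.models mirrors
# keras_hub/api/models as in current keras_hub releases.
backbone_cls = keras_hub.models.GptOssBackbone
causal_lm_cls = keras_hub.models.GptOssCausalLM
preprocessor_cls = keras_hub.models.GptOssCausalLMPreprocessor
tokenizer_cls = keras_hub.models.GptOssTokenizer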
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -47,6 +47,9 @@
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_tokenizer import (
GPTNeoXTokenizer as GPTNeoXTokenizer,
)
from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import (
GptOssTokenizer as GptOssTokenizer,
)
from keras_hub.src.models.llama.llama_tokenizer import (
LlamaTokenizer as LlamaTokenizer,
)
19 changes: 19 additions & 0 deletions keras_hub/src/models/gpt_oss/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, GptOssBackbone)
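register_presets wires any presets declared in gpt_oss_presets.py into the standard from_preset constructors. A minimal sketch of the resulting workflow, using a hypothetical preset name (the real preset names live in gpt_oss_presets.py, which is not shown in this diff):

from keras_hub.models import GptOssCausalLM

# "gpt_oss_20b_en" is a placeholder preset name for illustration only.
causal_lm = GptOssCausalLM.from_preset("gpt_oss_20b_en")
causal_lm.generate("GPT-OSS ported to Keras", max_length=64)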
274 changes: 274 additions & 0 deletions keras_hub/src/models/gpt_oss/gpt_oss_attention.py
@@ -0,0 +1,274 @@
# Copyright 2024 The KerasHub Authors
Collaborator: remove copyright banner

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


class GptOssAttention(keras.layers.Layer):
"""A cached attention layer with sliding window and sink tokens.

This layer implements the attention mechanism described in the GPT-OSS
paper. It includes grouped-query attention, rotary position embeddings,
sliding window attention, and sink tokens for improved performance on
long sequences.

Args:
num_query_heads (int): The number of query attention heads.
Collaborator: follow the type hints format from other files in the repo:
arg_name: type_hint. description

num_key_value_heads (int): The number of key and value attention
heads.
rope_max_wavelength (int, optional): The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor (float, optional): The scaling factor for the
rotary position embedding. Defaults to 1.0.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window (int, optional): The size of the sliding window.
Defaults to 4096.
dropout (float, optional): The dropout rate. Defaults to 0.
"""

def __init__(
self,
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
sliding_window=4096,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.dropout = dropout

self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self.rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
Collaborator: what is model dim?

# u = num query heads
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
self._head_dim = self._hidden_dim // self.num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(
None,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
self.value_dense.build(inputs_shape)

self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self.output_dense.build(
(None, None, self.num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

self.sinks = self.add_weight(
shape=(self.num_query_heads,),
initializer="random_normal",
dtype=self.dtype,
name="sinks",
)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"

self.built = True

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)

query = self.query_dense(hidden_states)

# Compute RoPE for queries
query = self.rotary_embedding_layer(query, start_index=start_index)

def _compute_key_value(x):
key, value = self.key_dense(x), self.value_dense(x)
# Compute RoPE for keys
key = self.rotary_embedding_layer(key, start_index=start_index)
return key, value

if cache is not None:
Collaborator: cache logic for KerasHub is located in the causal_lm file. Example: def _build_cache(self, token_ids):

key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
key, value = _compute_key_value(hidden_states)

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _compute_attention(self, query, key, value, attention_mask=None):
attention_scores = ops.einsum(self._dot_product_equation, query, key)
attention_scores = ops.multiply(
attention_scores,
ops.cast(self._inv_norm_factor, self.compute_dtype),
)

if attention_mask is not None:
# The mask is a boolean tensor, True for positions to be masked.
# Masked positions are filled with a large negative value so they
# receive (near) zero weight after the softmax.
adder = ops.cast(-1e9, self.compute_dtype)
attention_scores = ops.where(
attention_mask[:, None, None, :], adder, attention_scores
)

# Handle sink tokens by concatenating them to the logits.
b = ops.shape(query)[0]
q = ops.shape(query)[1]
sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)

# Stabilize logits before softmax for numerical stability.
max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
max_logits = ops.stop_gradient(max_logits)
combined_logits = combined_logits - max_logits

probs = ops.softmax(combined_logits, axis=-1)

# Remove the sink probabilities before computing the output.
attention_scores = probs[..., :-1]
attention_scores = ops.cast(attention_scores, self.compute_dtype)

attention_output = ops.einsum(
self._combine_equation, attention_scores, value
)

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self.sliding_window,
"dropout": self.dropout,
}
)
return config
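For reference, a minimal sketch that exercises GptOssAttention in isolation. It is not part of this diff, the shapes are assumed, and the cache layout (batch, 2, max_length, num_key_value_heads, head_dim) is inferred from the slicing in call().

import numpy as np
from keras import ops

from keras_hub.src.models.gpt_oss.gpt_oss_attention import GptOssAttention

batch, seq_len, hidden_dim = 2, 8, 64
layer = GptOssAttention(num_query_heads=8, num_key_value_heads=4)
hidden_states = ops.convert_to_tensor(
    np.random.rand(batch, seq_len, hidden_dim).astype("float32")
)

# Plain forward pass. With no attention mask, every query position attends
# to all keys plus the learned per-head sink logit.
outputs = layer(hidden_states)  # shape (2, 8, 64)

# Forward pass with a preallocated key/value cache, as a causal LM would do
# during generation (a causal/padding mask would also be passed there).
# Along axis 1, index 0 holds keys and index 1 holds values.
max_length, head_dim = 16, hidden_dim // 8
cache = ops.zeros((batch, 2, max_length, 4, head_dim), dtype="float32")
outputs, cache = layer(hidden_states, cache=cache, cache_update_index=0)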