
Conversation

@bzgoogle (Collaborator) commented Oct 14, 2025

Description

JAX implementation of GPT-OSS

Model loaded with command:
TPU_BACKEND_TYPE=jax vllm serve --model=unsloth/gpt-oss-120b-BF16 --max-model-len=1024 --max-num-batched-tokens 1024 --max-num-seqs=128 --hf-config=openai/gpt-oss-120b --no-enable-prefix-caching --disable-log-requests --gpu-memory-utilization 0.7 --tensor-parallel-size 8 --additional_config='{"skip_quantization": "True"}'

Tests

Compared logits against the HF Transformers implementation; the logits difference at each stage is ~1%: https://buganizer.corp.google.com/issues/452659743#comment6
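
For reference, a minimal sketch of the kind of comparison described above, assuming the per-stage logits from the HF Transformers run and the JAX run have been dumped to .npy files (the file names are illustrative, not part of this PR):

```python
import numpy as np

# Illustrative paths; in practice these would be dumped from each backend.
hf_logits = np.load("hf_logits.npy")    # HF Transformers reference
jax_logits = np.load("jax_logits.npy")  # tpu-inference JAX implementation

rel_diff = np.abs(jax_logits - hf_logits) / (np.abs(hf_logits) + 1e-6)
print(f"mean relative logits difference: {rel_diff.mean():.2%}")  # expect ~1%
```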

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.


@bzgoogle bzgoogle requested review from kyuyeunk and qihqi October 17, 2025 21:14
@bzgoogle bzgoogle marked this pull request as ready for review October 17, 2025 21:19

supported_quantization: list[str] = [
-    "tpu_int8", "compressed-tensors", "awq", "fp8"
+    "tpu_int8", "compressed-tensors", "awq", "fp8", "mxfp4"
Collaborator:
We are using the unquantized bf16 model, so this change shouldn't be needed, right?

bzgoogle (Collaborator, Author):
Removing it causes vLLM to raise errors like this one: https://paste.googleplex.com/5778456474419200#l=53. And we will eventually optimize the fp4 version, which will presumably also need it.

Collaborator:
Can you remove this file from this PR? I'll create a separate PR for the kernel change.

For the model, can you just comment out the code that calls the kernel with the attention_sink argument? I'll uncomment it in the kernel-change PR.

The output of the model will be wrong, but if someone wants to try it out before the kernel change has landed, they can just cherry-pick the commit.

bzgoogle (Collaborator, Author):
Sure, I'll remove this file and comment out the sink interface when merging the PR.

Collaborator:
There is ongoing discussion about not having every model implementation carry its own layer implementations (like GptOssAttention). Can't you just add the feature to the existing attention implementation in https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/jax/attention/attention.py? At the very least, create a new class that inherits from that Attention class so we can reuse many of the components.

bzgoogle (Collaborator, Author):
I originally aimed to build this attention on top of the regular attention; however, most of the methods (self.__post_init__(), self.__call__(), self.attention()) would have to be overridden (because of the sink token and the q/k/v biases), which minimizes the benefit of inheriting from the base attention.
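
For context, here is a minimal sketch of the sink mechanism referenced above; this is not the kernel or class used in this PR, and the shapes, names, and omission of masking are simplifications. The idea is that each head carries a learned sink logit that participates in the softmax normalization but contributes nothing to the output.

```python
import jax.numpy as jnp
from jax import nn

def attention_with_sink(q, k, v, sink_logits):
    # q, k, v: [num_heads, seq, head_dim]; sink_logits: [num_heads]
    # Causal masking is omitted for brevity.
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("hqd,hkd->hqk", q, k) * scale               # [H, Q, K]
    sink = jnp.broadcast_to(sink_logits[:, None, None],
                            (*scores.shape[:2], 1))                 # [H, Q, 1]
    probs = nn.softmax(jnp.concatenate([scores, sink], axis=-1), axis=-1)
    probs = probs[..., :-1]  # drop the sink column; its mass is discarded
    return jnp.einsum("hqk,hkd->hqd", probs, v)
```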

return gated_activation * (x_linear + 1)

@dataclass(kw_only=True)
class GptOssMoE(nnx.Module):
Collaborator:
Similar to the comment about attention: is it possible to use the existing implementation from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/jax/moe/moe.py?

bzgoogle (Collaborator, Author):
Similarly, GPT-OSS fuses the gate and up projections together and uses a customized SwiGLU function, whereas the base MoE creates the gate and up tensors separately. If we want to follow GPT-OSS's implementation, we would have to rewrite all of the base MoE's logic. So if future optimization shows that splitting the gate/up tensors gives better performance, we could choose to inherit the MoE from the existing one.
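
To illustrate the point, here is a minimal sketch of a fused gate/up expert MLP with the customized activation from the diff context below; the 1.702 sigmoid scaling, the half-split layout of the fused tensor, and the omission of clamping are assumptions, not necessarily what this PR does.

```python
import jax.numpy as jnp
from jax import nn

def fused_swiglu_mlp(x, w_gate_up, w_down, alpha=1.702):
    # x: [tokens, d_model]; w_gate_up: [d_model, 2 * d_ff]; w_down: [d_ff, d_model]
    gate_up = x @ w_gate_up
    x_gate, x_linear = jnp.split(gate_up, 2, axis=-1)        # fused gate/up halves
    gated_activation = x_gate * nn.sigmoid(alpha * x_gate)   # SiLU-like gate
    return (gated_activation * (x_linear + 1)) @ w_down      # matches the "+ 1" term in the diff
```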

return ramp_func

@dataclass(kw_only=True)
class GptOssRotaryEmbedding(nnx.Module):
Collaborator:
Similar comment as for the attention & MoE layers.

"transpose": lambda w, _: w.T,
}

mappings = {
Collaborator:
This kind of string-based dictionary mapping looks extremely fragile. Can you give a brief summary of what it does and why it needs to be done this way?

bzgoogle (Collaborator, Author):
To load the weights we need three kinds of mappings: 1. name mapping, e.g. lm_head.weight -> lm_head.input_embedding_table_DV; 2. shape mapping, e.g. reshaping the 2D Q tensor into a 3D Q tensor; 3. transposition, e.g. (vocab, d_model) -> (d_model, vocab). Alternatively, we could maintain three separate mapping tables, one per operation, as DeepSeek does with _transpose_map, _weight_shape_map, and _loaded_to_standardized_keys. For simplicity, I merged those into a single mapping table here.
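
For illustration, a minimal sketch of what such a merged table could look like, where each entry carries the target name plus a transform that handles reshaping/transposing. All names, shapes, and the convert helper below are hypothetical, not the PR's actual tables.

```python
import jax.numpy as jnp

NUM_HEADS, HEAD_DIM = 64, 64  # illustrative values

# Each loaded name maps to (target name, transform applied to the raw weight).
mappings = {
    "lm_head.weight": (
        "lm_head.input_embedding_table_DV",
        lambda w: w.T,                                    # (vocab, d_model) -> (d_model, vocab)
    ),
    "model.layers.0.self_attn.q_proj.weight": (
        "layers.0.attn.kernel_q_DNH",
        lambda w: w.T.reshape(-1, NUM_HEADS, HEAD_DIM),   # 2D -> 3D per-head layout
    ),
}

def convert(loaded_weights):
    converted = {}
    for name, weight in loaded_weights.items():
        target, transform = mappings.get(name, (name, lambda w: w))
        converted[target] = transform(jnp.asarray(weight))
    return converted
```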

@karan (Collaborator) commented Oct 21, 2025

Does NEW_MODEL_DESIGN=True TPU_BACKEND_TYPE=jax actually load the JAX implementation? I thought we needed to set MODEL_IMPL_TYPE=flax_nnx (ref).

@kyuyeunk (Collaborator):

> Does NEW_MODEL_DESIGN=True TPU_BACKEND_TYPE=jax actually load the JAX implementation? I thought we needed to set MODEL_IMPL_TYPE=flax_nnx (ref).

I don't think NEW_MODEL_DESIGN=True TPU_BACKEND_TYPE=jax is needed, but I do know that we don't need MODEL_IMPL_TYPE=flax_nnx because that's the default behavior.

@kyuyeunk force-pushed the gpt-oss-exp branch 2 times, most recently from 685707f to 1cf090a on October 24, 2025 05:04