[GPT-OSS] JAX implementation of GPT-OSS #861
base: main
Conversation
    supported_quantization: list[str] = [
-        "tpu_int8", "compressed-tensors", "awq", "fp8"
+        "tpu_int8", "compressed-tensors", "awq", "fp8", "mxfp4"
We are using the unquantized bf16 model, so this change shouldn't be needed, right?
Removing it causes vLLM to give errors like https://paste.googleplex.com/5778456474419200#l=53. And we will eventually optimize the fp4 version, which will also need it(?)
Can you remove this file from this PR? I'll create a separate PR for the kernel change.
For the model, can you just comment out the code that calls the kernel with the argument attention_sink? I'll uncomment it with the kernel change PR.
The output of the model will be wrong, but if someone wants to try it out before the kernel change has landed, they can just cherry-pick the commit.
Sure, I will remove this file and comment out the sink interface when merging the PR.
There is ongoing discussion about not having every model implementation carry its own layer implementations (like GptOssAttention). Can't you just add the feature to the existing attention implementation in https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/jax/attention/attention.py? At the very least, create a new class that inherits from that Attention class so we can reuse many of the components?
I initially aimed to build this attention on top of the regular attention; however, most of the methods (self.__post_init__(), self.__call__(), self.attention()) would have to be overridden (due to the sink token and the bias for q/k/v), which minimizes the benefit of inheriting from the base attention.
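For context (this is not code from the PR), a minimal JAX sketch of how a per-head attention sink changes the softmax normalization; the shapes and the `sink_logits` name are assumptions:

```python
import jax
import jax.numpy as jnp

def attention_with_sink(q, k, v, sink_logits):
    """Scaled dot-product attention with a learned per-head sink logit.

    q: [T, H, D], k/v: [S, H, D], sink_logits: [H] (shapes are assumptions).
    The sink adds one extra column to the softmax so probability mass can be
    'parked' on it instead of on real tokens.
    """
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("thd,shd->hts", q, k) * scale           # [H, T, S]
    sink = jnp.broadcast_to(sink_logits[:, None, None],
                            (*logits.shape[:2], 1))             # [H, T, 1]
    full = jnp.concatenate([logits, sink], axis=-1)             # [H, T, S+1]
    probs = jax.nn.softmax(full, axis=-1)[..., :-1]             # drop the sink column
    return jnp.einsum("hts,shd->thd", probs, v)                 # [T, H, D]
```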
    return gated_activation * (x_linear + 1)

@dataclass(kw_only=True)
class GptOssMoE(nnx.Module):
Similar to the comment about attention: is it possible to use the existing implementation from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/layers/jax/moe/moe.py ?
Similarly, GPT-OSS fuses the gate and up projections together and uses a customized SwiGLU function, whereas the base MoE creates the gate and up tensors separately. If we want to follow GPT-OSS's implementation, we would have to rewrite all of the base MoE's logic. So I think if our future optimizations show that splitting the gate/up tensors gives better performance, we could choose to inherit the MoE from the existing one(?)
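For reference, a minimal sketch of the fused gate/up layout and customized SwiGLU described above; the `(x_linear + 1)` term matches the diff snippet shown earlier, but the tensor layout, clamp limit, and alpha value are assumptions, not the PR's exact code:

```python
import jax
import jax.numpy as jnp

def fused_swiglu_mlp(x, w_gate_up, w_down, alpha=1.702, limit=7.0):
    """Expert MLP with the gate and up projections fused into a single matmul.

    x: [tokens, d_model]; w_gate_up: [d_model, 2 * d_ff]; w_down: [d_ff, d_model]
    (layout, alpha, and the clamp limit are illustrative assumptions).
    """
    gate_up = x @ w_gate_up                               # one fused projection
    x_gate, x_linear = jnp.split(gate_up, 2, axis=-1)
    x_gate = jnp.clip(x_gate, None, limit)                # clamp the gate branch
    gated_activation = x_gate * jax.nn.sigmoid(alpha * x_gate)
    return (gated_activation * (x_linear + 1)) @ w_down   # matches the '+ 1' in the diff
```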
    return ramp_func

@dataclass(kw_only=True)
class GptOssRotaryEmbedding(nnx.Module):
Similar comment as for the attention & MoE layers.
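The `ramp_func` in the diff above suggests a YaRN-style frequency ramp; purely as an illustration (assuming YaRN-style RoPE scaling, which this excerpt does not confirm), such a ramp typically looks like:

```python
import jax.numpy as jnp

def linear_ramp(low: float, high: float, dim: int):
    """Linear ramp over `dim` frequency indices: 0 below `low`, 1 above `high`.

    In YaRN-style RoPE scaling (assumed usage here), this blends interpolated
    and original rotary frequencies per dimension.
    """
    ramp = (jnp.arange(dim, dtype=jnp.float32) - low) / (high - low)
    return jnp.clip(ramp, 0.0, 1.0)
```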
| "transpose": lambda w, _: w.T, | ||
| } | ||
|
|
||
| mappings = { |
This kind of string-based dictionary mapping looks extremely fragile. Can you give a brief summary of what it does and why it needs to be done this way?
To load the weights, we need three kinds of mappings: 1. the naming map, e.g. lm_head.weight -> lm_head.input_embedding_table_DV; 2. the shape mapping, e.g. a 2D Q tensor to a 3D Q tensor; 3. the transpose, e.g. (vocab, d_model) -> (d_model, vocab). Alternatively, we could maintain three separate mapping tables, one per operation, as in DeepSeek: _transpose_map, _weight_shape_map, _loaded_to_standardized_keys. To keep things simple, I merged those into a single mapping table here.
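As an illustration of the merged-table idea described above (all keys, names, and the helper are hypothetical, not the PR's actual code):

```python
# Hypothetical merged mapping table: each loaded checkpoint name maps to the
# target parameter name plus a single transform (rename + reshape + transpose
# in one place), instead of three separate tables as in the DeepSeek loader
# (_loaded_to_standardized_keys, _weight_shape_map, _transpose_map).
mappings = {
    "lm_head.weight": {
        "target": "lm_head.input_embedding_table_DV",
        "transform": lambda w, cfg: w.T,  # (vocab, d_model) -> (d_model, vocab)
    },
    "q_proj.weight": {  # key pattern is illustrative only
        "target": "attn.q_proj_DNH",
        # 2D (num_heads * head_dim, d_model) -> 3D (d_model, num_heads, head_dim)
        "transform": lambda w, cfg: w.T.reshape(
            cfg["d_model"], cfg["num_heads"], cfg["head_dim"]),
    },
}

def map_loaded_weight(name, weight, cfg):
    """Return the target name and transformed value for one checkpoint tensor."""
    entry = mappings[name]
    return entry["target"], entry["transform"](weight, cfg)
```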
Force-pushed from 685707f to 1cf090a
Signed-off-by: Kyuyeun Kim <[email protected]>
Description
JAX implementation of GPT-OSS
Model loaded with command:
TPU_BACKEND_TYPE=jax vllm serve --model=unsloth/gpt-oss-120b-BF16 --max-model-len=1024 --max-num-batched-tokens 1024 --max-num-seqs=128 --hf-config=openai/gpt-oss-120b --no-enable-prefix-caching --disable-log-requests --gpu-memory-utilization 0.7 --tensor-parallel-size 8 --additional_config='{"skip_quantization": "True"}'
Tests
Compared logits against the HF Transformers implementation; ~1% logits difference at each stage: https://buganizer.corp.google.com/issues/452659743#comment6
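For illustration only (the actual test code is not shown in this excerpt), a per-token relative logit comparison along these lines could be:

```python
import numpy as np

def relative_logit_diff(logits_ref, logits_test):
    """Mean relative difference per token between reference (e.g. HF) and
    test (e.g. JAX) logits; arrays of shape [seq_len, vocab] are assumed."""
    diff = np.abs(logits_ref - logits_test)
    scale = np.maximum(np.abs(logits_ref), 1e-6)
    return (diff / scale).mean(axis=-1)  # ~0.01 corresponds to the ~1% reported above
```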
Checklist
Before submitting this PR, please make sure: