Add support for uint8_t as data type for GatherBlockQuantized #24239
Conversation
You can commit the suggested changes from lintrunner.
onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
…d.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
In theory we'd need to rev the opdef, but I believe there is no harm in this case.
Shall there be a test?
Yes, there shall. Can you tell me where the existing tests are and how to run them locally so I can add to them? I am new to making ORT CPU changes.
The tests here seem to indicate that I would fail because I added uint8_t support, yet I am passing the CI.
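(For anyone following along: the CPU contrib-op tests live under onnxruntime/test/contrib_ops/ and run through the onnxruntime_test_all gtest binary. A minimal end-to-end check can also be sketched with the onnxruntime Python API; this is an assumption-laden sketch, not an official test — the attribute set (gather_axis/quantize_axis/block_size plus a bits attribute for packed uint8) and all sizes are hypothetical, and it would only run on a build that includes this change.)

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper, numpy_helper

rows, cols, block_size = 4, 16, 16          # hypothetical sizes
data = numpy_helper.from_array(
    np.arange(rows * cols // 2, dtype=np.uint8).reshape(rows, cols // 2), "data")
scales = numpy_helper.from_array(
    np.ones((rows, cols // block_size), dtype=np.float32), "scales")
node = helper.make_node(
    "GatherBlockQuantized", ["data", "indices", "scales"], ["out"],
    domain="com.microsoft", gather_axis=0, quantize_axis=1,
    block_size=block_size, bits=4)           # 'bits' for packed uint8 is an assumption
graph = helper.make_graph(
    [node], "g",
    [helper.make_tensor_value_info("indices", TensorProto.INT64, [2])],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, cols])],
    [data, scales])
model = helper.make_model(graph, opset_imports=[
    helper.make_opsetid("", 21), helper.make_opsetid("com.microsoft", 1)])
sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CPUExecutionProvider"])
print(sess.run(None, {"indices": np.array([0, 2], dtype=np.int64)})[0])
```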
I believe some .md files under the docs should be updated as well.
Could you point me to it, please? I see a documentation string in contrib_defs.cc, which has been updated. Is there some other documentation?
Done, updated.
Description
This change adds support for GatherBlockQuantized to use uint8_t as the data type, with the same semantics as MatMulNBits. Zero points and gather axes other than 0 are not yet supported, in order to keep the change scoped.
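To make the semantics concrete, here is a minimal numpy sketch of the dequantization this implies, assuming MatMulNBits-style packing (two 4-bit values per uint8, low nibble first) and an implicit zero point of 8 since explicit zero points are not supported; the function name and packing order are assumptions, not the kernel's actual code.

```python
import numpy as np

def gather_block_quantized_uint8(data, indices, scales, block_size=32):
    """Sketch of GatherBlockQuantized semantics for uint8 (4-bit packed) data.

    data:    (rows, cols // 2) uint8, two 4-bit values per byte
    indices: (n,) int64 row ids (gather_axis == 0, the only axis supported here)
    scales:  (rows, cols // block_size) float32, one scale per block
    """
    packed = data[indices]                  # gather rows first
    lo = (packed & 0x0F).astype(np.int32)   # low nibble  (assumed to come first)
    hi = (packed >> 4).astype(np.int32)     # high nibble (assumed to come second)
    q = np.stack([lo, hi], axis=-1).reshape(len(indices), -1)
    zero_point = 8                          # assumed default, as in MatMulNBits
    s = np.repeat(scales[indices], block_size, axis=1)
    return (q - zero_point) * s
```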
Motivation and Context
With newer LLMs like Phi4 that are trained with shared embeddings, the weights of the lm_head matrix and the embeddings table are exactly the same. These embeddings are huge: unquantized they are 1.2GB in Phi4 mini instruct, and even at int4 quantization the weights are still 300MB. We can go a step further and have the two ops, the lm_head MatMulNBits and GatherBlockQuantized, share the same weights, which would save 300MB of model size.
The two things that hinder that are the shape expectations of GatherBlockQuantized and the data types it supports for data. The shape can be solved via a simple Reshape op, but the data type needs code changes, and that is what this change does.
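As an illustrative sketch of that weight-sharing edit using the onnx helper API (tensor names, sizes, and the exact attribute set, including bits, are assumptions, not taken from the actual Phi4 graph):

```python
import numpy as np
from onnx import helper, numpy_helper

# Hypothetical sizes: vocab x hidden, 4-bit quantized in blocks of 32.
N, K, block_size = 200064, 3072, 32
blob = block_size // 2                      # bytes per 4-bit block

# One shared initializer, stored in the MatMulNBits layout.
shared = numpy_helper.from_array(
    np.zeros((N, K // block_size, blob), dtype=np.uint8), "shared_qweight")
embed_shape = numpy_helper.from_array(
    np.array([N, K // 2], dtype=np.int64), "embed_shape")

# lm_head consumes the shared weight directly.
lm_head = helper.make_node(
    "MatMulNBits", ["hidden_states", "shared_qweight", "lm_head_scales"],
    ["logits"], domain="com.microsoft", K=K, N=N, bits=4, block_size=block_size)

# The embedding path reshapes the same bytes to 2-D and gathers rows.
reshape = helper.make_node(
    "Reshape", ["shared_qweight", "embed_shape"], ["qweight_2d"])
embed = helper.make_node(
    "GatherBlockQuantized", ["qweight_2d", "input_ids", "embed_scales"],
    ["embeddings"], domain="com.microsoft",
    gather_axis=0, quantize_axis=1, block_size=block_size, bits=4)
```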
Here is Phi4 modified with shared weights between the lm_head MatMulNBits and GatherBlockQuantized; this model is just 2.1GB on disk.
