Merged
5 changes: 3 additions & 2 deletions docs/ContribOperators.md
@@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size` must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
If `zero_points` is not provided, 0 is the zero point.
If `zero_points` is not provided, the zero point is 0, except when `data` is of uint8 type, in which case the default zero point is 8.
Contributor
@tianleiwu May 1, 2025

Why is the default zero point 8 for uint8? That does not sound reasonable to me.
Normally, the default is the middle value 2^(bits - 1), i.e. 128 for 8 bits and 8 for 4 bits.

Maybe add a description that this operator only supports 4 bits.

Contributor Author

This uint8 stores two packed uint4s because that is how MatMulNBits works. To resolve this issue, I was recently discussing adding a `bits` attribute; that would let the uint8_t be interpreted as packed uint4s or a single uint8.

3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
5. For uint8 data, the `gather_axis` must be 0.
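The per-element math described in steps 2 and 3 can be sketched as follows. This is illustrative only; `Dequantize` is a hypothetical helper, not part of the operator's API:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch of the dequantization described above:
// output = (quantized_value - zero_point) * scale.
// The zero point defaults to 0, or to 8 for uint8 (packed uint4) data.
inline float Dequantize(int32_t q, int32_t zero_point, float scale) {
  return static_cast<float>(q - zero_point) * scale;
}
```

For example, a 4-bit value of 11 with the uint8 default zero point 8 and a scale of 0.5 dequantizes to 1.5.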

#### Version

@@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(int4), tensor(uint4)</dt>
<dt><tt>T1</tt> : tensor(int4), tensor(uint4), tensor(uint8)</dt>
<dd>Constrain quantized types.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain dequantized types.</dd>
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -515,7 +515,7 @@ Do not modify directly.*
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedGemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)<br/> **T2** = tensor(float), tensor(float16)<br/> **Tind** = tensor(int32), tensor(int64)|
|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)<br/> **Tind** = tensor(int32), tensor(int64)|
|GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized)>,
45 changes: 37 additions & 8 deletions onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -16,6 +16,21 @@
namespace onnxruntime {
namespace contrib {

namespace {
template <typename T1>
int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
}

template <>
int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
// Weights are stored as (nibble2)(nibble1) in uint8_t.
auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
return data_val;
}
} // namespace
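A standalone sketch of the `uint8_t` specialization above: two uint4 weights are packed per byte, with the even index in the low nibble. `UnpackNibble` is a hypothetical free function mirroring `GetDataElement<uint8_t>`, not part of the PR:

```cpp
#include <cassert>
#include <cstdint>

// Two uint4 weights are packed per byte, low nibble at the even index.
inline int32_t UnpackNibble(const uint8_t* data, int64_t idx) {
  const uint8_t byte = data[idx >> 1];
  // Odd index -> high nibble, even index -> low nibble.
  return (idx & 1) ? ((byte >> 4) & 0x0F) : (byte & 0x0F);
}
```

For instance, the byte 0xB4 packs the values 4 (low nibble, index 0) and 11 (high nibble, index 1).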

template <typename T1, typename Tind>
class GatherBlockQuantized : public OpKernel {
public:
@@ -98,6 +113,12 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
shape.push_back(data_shape[narrow<size_t>(i)]);

// When data is stored as uint8_t, each element holds two int4 values.
// The shape in the ONNX model reflects that by halving the last dimension.
// Ex: for a true data size of 2000x3072, the ONNX model stores data of shape 2000x1536.
// However, the output still needs to be of size 2000x3072, so we double the last dimension here.
uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
shape[shape.size() - 1] = shape.back() * components;
p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
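The shape adjustment above can be sketched as a standalone helper (hypothetical `LogicalOutputShape`, for illustration only): for uint8 (packed uint4) data, the stored last dimension is half the logical width, so the output shape doubles it back.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// For packed-uint8 data, expand the stored last dimension back to the
// logical number of int4 values; other types are unchanged.
inline std::vector<int64_t> LogicalOutputShape(std::vector<int64_t> shape,
                                               bool packed_uint8) {
  shape.back() *= packed_uint8 ? 2 : 1;
  return shape;
}
```

For example, data stored as 2000x1536 uint8 yields a 2000x3072 output.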

// validate quantization parameters
@@ -106,7 +127,7 @@
"data and scales must have the same rank.");
for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(i == static_cast<size_t>(p.quantize_axis)
? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
: data_shape[i] == scales_shape[i],
"data and scales do not match shapes.");
}
@@ -165,16 +186,22 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
int64_t output_idx = output_idx_base;
int64_t data_idx = data_idx_base;
for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
auto data_val = static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
auto data_val = GetDataElement(data_ptr, data_idx);

int64_t x = data_idx / quantize_full_block;
int64_t y = data_idx % quantize_full_block / quantize_N;
int64_t z = data_idx % quantize_N;
int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
auto zp_val = static_cast<int32_t>(zero_points_ptr
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
: 0);
int32_t zp_val;
if constexpr (std::is_same_v<T1, uint8_t>) {
// The default zero point for uint8 weights as stored by MatMulNBits op is 8.
zp_val = 8;
} else {
zp_val = static_cast<int32_t>(zero_points_ptr
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
: 0);
}

output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
}
@@ -205,7 +232,7 @@ template <typename T1, typename Tind>
Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));

auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
const auto& data_shape = p.data_tensor->Shape();
// re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
// re-shape the indices tensor to [gather_N]
@@ -215,7 +242,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
// 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
// 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
// 4> pick the element from the block: value_i = data_blk[blk_ele_i]
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1);
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1) * components;
const int64_t gather_axis_dim = data_shape[narrow<size_t>(p.gather_axis)];
const int64_t gather_M = data_shape.SizeToDimension(narrow<size_t>(p.gather_axis));
const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +256,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
// data_i % (quantize_axis_dim * quantize_N) / quantize_N,
// data_i % quantize_N)
// 4> get scale index: (x, y / block_size_, z)
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)];
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)] * components;
const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt<size_t>(p.quantize_axis) + 1);
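The index decomposition in the comment above can be sketched as a standalone function (hypothetical `ScaleIndex`, with an assumed `scale_axis_dim` parameterization for illustration): the flat data index splits into (x, y, z), and y collapses by `block_size` to pick the scale row.

```cpp
#include <cassert>
#include <cstdint>

// Map a flat data index to its flat scale index, following
// x = data_i / (quantize_axis_dim * quantize_N),
// y = data_i % (quantize_axis_dim * quantize_N) / quantize_N,
// z = data_i % quantize_N, scale at (x, y / block_size, z).
inline int64_t ScaleIndex(int64_t data_i, int64_t quantize_axis_dim,
                          int64_t quantize_N, int64_t block_size,
                          int64_t scale_axis_dim) {
  const int64_t full_block = quantize_axis_dim * quantize_N;
  const int64_t x = data_i / full_block;
  const int64_t y = data_i % full_block / quantize_N;
  const int64_t z = data_i % quantize_N;
  return x * scale_axis_dim * quantize_N + (y / block_size) * quantize_N + z;
}
```

With quantize_axis_dim = 64, quantize_N = 4, and block_size = 16, scales have 64 / 16 = 4 rows along the quantize axis; data index 70 decomposes into (x, y, z) = (0, 17, 2) and lands in scale row 1.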

concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
@@ -273,6 +300,8 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<Tind>()), \
GatherBlockQuantized<T1, Tind>);

REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);
21 changes: 17 additions & 4 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size` must be a power of 2 and not smaller than 16, e.g. 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
If `zero_points` is not provided, 0 is the zero point.
If `zero_points` is not provided, the zero point is 0, except when `data` is of uint8 type, in which case the default zero point is 8.
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
5. For uint8 data, the `gather_axis` must be 0.
)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
.Input(2, "scales", "quantization scale", "T2")
.Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
.Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
.TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3637,21 +3638,30 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
gather_axis = (gather_axis + r) % r;
quantize_axis = (quantize_axis + r) % r;

if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
fail_shape_inference("gather_axis must be 0 for uint8 data");
}

if (scales_shape.dim_size() != r) {
fail_shape_inference("scales must have the same rank as data");
}

uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
for (int i = 0; i < r; ++i) {
if (!data_shape.dim(i).has_dim_value() ||
!scales_shape.dim(i).has_dim_value() ||
(i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
fail_shape_inference("data shape and scales shape do not match");
}
}

// validate zero point shape
if (ctx.hasInput(3)) {
if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
fail_type_inference("zero_points are not supported for uint8_t data type");
}

if (!hasInputShape(ctx, 3)) {
fail_shape_inference("zero_points shape must be known");
}
@@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
}
for (int i = 0; i < out_rank; ++i) {
// For uint8_t data, the last dimension needs to be expanded back to the actual dimension,
// because two int4 values are stored packed in a single uint8_t.
auto last_dimension_components = (i == out_rank - 1) ? components : 1;
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
(i < gather_axis)
? data_shape.dim(i)
: (i >= gather_axis && i < gather_axis + q)
? indices_shape.dim(i - gather_axis)
: data_shape.dim(i - q + 1);
: data_shape.dim(i - q + 1) * last_dimension_components;
}
});
