Add support for uint8_t as data type for GatherBlockQuantized #24239

Merged · 8 commits · Apr 4, 2025
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized)>,
47 changes: 39 additions & 8 deletions onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -16,6 +16,21 @@
namespace onnxruntime {
namespace contrib {

namespace {
template <typename T1>
int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
}

template <>
int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
// Weights are stored as (nibble2)(nibble1) in uint8_t.
auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
return data_val;
}
} // namespace

template <typename T1, typename Tind>
class GatherBlockQuantized : public OpKernel {
public:
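The GetDataElement<uint8_t> specialization above relies on two int4 weights being packed into each byte, low nibble first. A minimal standalone sketch of that layout (illustrative only, not part of the diff; the values are made up):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Pack element 6 (even index -> low nibble) and element 11 (odd index -> high nibble).
  const uint8_t packed = static_cast<uint8_t>((11 << 4) | 6);  // 0xB6

  const int32_t elem0 = packed & 0x0F;         // even data_idx -> low nibble -> 6
  const int32_t elem1 = (packed >> 4) & 0x0F;  // odd data_idx  -> high nibble -> 11
  std::printf("elem0=%d elem1=%d\n", elem0, elem1);  // elem0=6 elem1=11
  return 0;
}
```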
Expand Down Expand Up @@ -98,6 +113,12 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
shape.push_back(data_shape[narrow<size_t>(i)]);

// When data is stored as uint8_t, each element holds two int4 values.
// The shape in the ONNX model reflects that: its last dimension is half the true number of values.
// Ex: for a true data size of 2000x3072, the ONNX model stores data of shape 2000x1536.
// However, the output still needs to be of size 2000x3072, so the last dimension is doubled here.
uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
shape[shape.size() - 1] = shape.back() * components;
p.output_tensor = context->Output(0, TensorShape(std::move(shape)));

// validate quantization parameters
@@ -106,7 +127,7 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
"data and scales must have the same rank.");
for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(i == static_cast<size_t>(p.quantize_axis)
? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
: data_shape[i] == scales_shape[i],
"data and scales do not match shapes.");
}
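The updated scales check above can be sanity-checked with a small worked example (a sketch with hypothetical sizes, not code from the PR): uint8_t data of shape [2000, 1536] packs 2000x3072 int4 values, so with quantize_axis = 1 and block_size_ = 128 the scales must be [2000, (3072 + 127) / 128] = [2000, 24].

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> data_shape{2000, 1536};    // packed uint8_t storage
  const std::vector<int64_t> scales_shape{2000, 24};
  const int64_t components = 2;                         // two int4 values per uint8_t
  const int64_t block_size = 128;
  const size_t quantize_axis = 1;

  for (size_t i = 0; i < data_shape.size(); ++i) {
    const int64_t expected = (i == quantize_axis)
                                 ? (data_shape[i] * components + block_size - 1) / block_size
                                 : data_shape[i];
    assert(expected == scales_shape[i] && "data and scales do not match shapes.");
  }
  return 0;
}
```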
@@ -165,16 +186,24 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
int64_t output_idx = output_idx_base;
int64_t data_idx = data_idx_base;
for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
auto data_val = static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
auto data_val = GetDataElement(data_ptr, data_idx);

int64_t x = data_idx / quantize_full_block;
int64_t y = data_idx % quantize_full_block / quantize_N;
int64_t z = data_idx % quantize_N;
int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
auto zp_val = static_cast<int32_t>(zero_points_ptr
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
: 0);
int32_t zp_val;
if constexpr (std::is_same_v<T1, uint8_t>) {
// The default zero point for uint8 weights as stored by the MatMulNBits op is 8.
// zero_points and data share the same type T1, so when T1 is uint8_t the zero
// point can be read by indexing zero_points_ptr directly.
zp_val = static_cast<int32_t>(zero_points_ptr ? zero_points_ptr[scale_idx] : 8);
} else {
zp_val = static_cast<int32_t>(zero_points_ptr
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
: 0);
}

output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
}
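For the uint8_t path, the loop above reduces to subtracting the zero point (default 8 when no zero_points input is provided, per the comment in the diff) and multiplying by the scale. A minimal sketch with made-up values:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t packed = 0x3A;       // high nibble = 3, low nibble = 10
  const float scale = 0.05f;
  const int32_t zero_point = 8;      // default zero point for the uint8_t (packed int4) path

  const int32_t q_even = packed & 0x0F;        // 10
  const int32_t q_odd = (packed >> 4) & 0x0F;  // 3

  const float deq_even = static_cast<float>(q_even - zero_point) * scale;  //  0.10f
  const float deq_odd = static_cast<float>(q_odd - zero_point) * scale;    // -0.25f
  std::printf("%f %f\n", deq_even, deq_odd);
  return 0;
}
```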
@@ -205,7 +234,7 @@ template <typename T1, typename Tind>
Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));

auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
const auto& data_shape = p.data_tensor->Shape();
// re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
// re-shape the indices tensor to [gather_N]
@@ -215,7 +244,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
// 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
// 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
// 4> pick the element from the block: value_i = data_blk[blk_ele_i]
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1);
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1) * components;
const int64_t gather_axis_dim = data_shape[narrow<size_t>(p.gather_axis)];
const int64_t gather_M = data_shape.SizeToDimension(narrow<size_t>(p.gather_axis));
const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +258,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
// data_i % (quantize_axis_dim * quantize_N) / quantize_N,
// data_i % quantize_N)
// 4> get scale index: (x, y / block_size_, z)
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)];
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)] * components;
const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt<size_t>(p.quantize_axis) + 1);

concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
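The index decomposition described in the comments above maps a flat element index to (x, y, z) coordinates and then to a scale index. A rough sketch of that arithmetic with hypothetical sizes (quantize axis last, so quantize_N = 1):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t quantize_axis_dim = 3072;  // already multiplied by components for uint8_t data
  const int64_t quantize_N = 1;            // product of dims after the quantize axis
  const int64_t block_size = 128;

  const int64_t quantize_full_block = quantize_axis_dim * quantize_N;
  const int64_t scale_full_block = (quantize_axis_dim + block_size - 1) / block_size * quantize_N;

  const int64_t data_idx = 5000;  // flat index into the dequantized (unpacked) data
  const int64_t x = data_idx / quantize_full_block;               // 1
  const int64_t y = data_idx % quantize_full_block / quantize_N;  // 1928
  const int64_t z = data_idx % quantize_N;                        // 0
  const int64_t scale_idx = x * scale_full_block + y / block_size * quantize_N + z;  // 1 * 24 + 15 = 39
  std::printf("scale_idx=%lld\n", static_cast<long long>(scale_idx));
  return 0;
}
```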
@@ -273,6 +302,8 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<Tind>()), \
GatherBlockQuantized<T1, Tind>);

REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);
7 changes: 4 additions & 3 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3602,7 +3602,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
.Input(2, "scales", "quantization scale", "T2")
.Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
.Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
.TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3641,10 +3641,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
fail_shape_inference("scales must have the same rank as data");
}

uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
for (int i = 0; i < r; ++i) {
if (!data_shape.dim(i).has_dim_value() ||
!scales_shape.dim(i).has_dim_value() ||
(i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
fail_shape_inference("data shape and scales shape do not match");
}
Expand Down Expand Up @@ -3680,7 +3681,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
? data_shape.dim(i)
: (i >= gather_axis && i < gather_axis + q)
? indices_shape.dim(i - gather_axis)
: data_shape.dim(i - q + 1);
: data_shape.dim(i - q + 1) * components;
}
});
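Following the shape-inference rule above, the output rank is q + (r - 1): leading data dims, then the indices dims in place of the gather axis, then the trailing data dims with the packed dimension widened by the component count. A rough sketch with hypothetical shapes (gather_axis = 0, uint8_t data, so components = 2):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> data_shape{2000, 1536};  // packed storage of 2000 x 3072 int4 values
  const std::vector<int64_t> indices_shape{4, 32};
  const int64_t gather_axis = 0;
  const int64_t components = 2;

  const int64_t r = static_cast<int64_t>(data_shape.size());
  const int64_t q = static_cast<int64_t>(indices_shape.size());

  std::vector<int64_t> output_shape;  // rank q + (r - 1)
  for (int64_t i = 0; i < q + r - 1; ++i) {
    if (i < gather_axis) {
      output_shape.push_back(data_shape[i]);
    } else if (i < gather_axis + q) {
      output_shape.push_back(indices_shape[i - gather_axis]);
    } else {
      output_shape.push_back(data_shape[i - q + 1] * components);  // unpack the trailing dimension
    }
  }
  for (const int64_t d : output_shape) std::printf("%lld ", static_cast<long long>(d));  // 4 32 3072
  std::printf("\n");
  return 0;
}
```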
