diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f582abca34706..0308b5c79c508 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr
   1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
      `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
   2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
-     If `zero_points` is not provided, 0 is the zero point.
+     If `zero_points` is not provided, the zero point is 0, except when `data` is of uint8 type, in which case the default zero point is 8.
   3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
      to dequantize the output.
   4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+  5. For uint8 `data`, `gather_axis` must be 0, and the `zero_points` input is not supported.
 
 #### Version
 
@@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 
 #### Type Constraints
 
 <dl>
-<dt><tt>T1</tt> : tensor(int4), tensor(uint4)</dt>
+<dt><tt>T1</tt> : tensor(int4), tensor(uint4), tensor(uint8)</dt>
 <dd>Constrain quantized types.</dd>
 <dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain dequantized types.</dd>
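For reviewers, items 1-4 of the updated description boil down to an elementwise dequantize-after-gather: each quantized value is shifted by the zero point and multiplied by its per-block scale. A minimal standalone sketch of that rule for a 1-D tensor quantized along axis 0 (hypothetical shapes and values, not code from this change):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // 32 quantized values, block_size = 16, so one scale per 16 elements.
  const int64_t block_size = 16;
  std::vector<int32_t> quantized(32, 5);     // stand-in for values unpacked from `data`
  std::vector<float> scales = {0.5f, 2.0f};  // ceil(32 / 16) = 2 scales
  const int32_t zero_point = 0;              // default for int4/uint4 (8 for uint8)

  // Dequantize: y[i] = (q[i] - zero_point) * scale[i / block_size]
  std::vector<float> dequantized(quantized.size());
  for (size_t i = 0; i < quantized.size(); ++i) {
    dequantized[i] = static_cast<float>(quantized[i] - zero_point) * scales[i / block_size];
  }
  assert(dequantized[0] == 2.5f && dequantized[31] == 10.0f);
  return 0;
}
```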
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 60d9e8e747eeb..a20333e2340c4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -515,7 +515,7 @@ Do not modify directly.*
 |FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedGemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
+|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
 |GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 345b5e793a764..1a737f3a9d251 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
 class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized)>,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
index 5935663f114a3..b83164d806ffc 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -16,6 +16,21 @@ namespace onnxruntime {
 namespace contrib {
 
+namespace {
+template <typename T1>
+int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
+  return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
+}
+
+template <>
+int32_t GetDataElement(const uint8_t* data_ptr, int64_t data_idx) {
+  const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
+  // Weights are stored as (nibble2)(nibble1) in uint8_t.
+  auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
+  return data_val;
+}
+}  // namespace
+
 template <typename T1, typename Tind>
 class GatherBlockQuantized : public OpKernel {
  public:
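The uint8_t specialization of `GetDataElement` above relies on a low-nibble-first layout: the even-indexed element lives in the low nibble, the odd-indexed one in the high nibble. A self-contained sketch of the same unpacking with invented byte values, compilable independently to sanity-check the convention:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the specialization above: two 4-bit values per byte, low nibble first.
int32_t UnpackNibble(const uint8_t* data, int64_t idx) {
  const uint8_t byte = data[idx >> 1];
  return (idx & 1) ? ((byte >> 4) & 0x0F) : (byte & 0x0F);
}

int main() {
  const uint8_t packed[] = {0x51, 0x83};  // logical elements: 1, 5, 3, 8
  assert(UnpackNibble(packed, 0) == 1);
  assert(UnpackNibble(packed, 1) == 5);
  assert(UnpackNibble(packed, 2) == 3);
  assert(UnpackNibble(packed, 3) == 8);
  return 0;
}
```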
@@ -98,6 +113,12 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
   for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
     shape.push_back(data_shape[narrow<size_t>(i)]);
 
+  // When data is stored as uint8_t, each element holds two int4 values, and
+  // the shape in the ONNX model reflects that by halving the last dimension.
+  // Ex: for a true data size of 2000x3072, the ONNX model has data of shape 2000x1536.
+  // However, the output still needs to be of size 2000x3072, so we double the last dimension here.
+  uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
+  shape[shape.size() - 1] = shape.back() * components;
   p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
 
   // validate quantization parameters
@@ -106,7 +127,7 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
                     "data and scales must have the same rank.");
   for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
     ORT_RETURN_IF_NOT(i == static_cast<size_t>(p.quantize_axis)
-                          ? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
+                          ? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
                           : data_shape[i] == scales_shape[i],
                       "data and scales do not match shapes.");
   }
@@ -165,16 +186,22 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
       int64_t output_idx = output_idx_base;
       int64_t data_idx = data_idx_base;
       for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
-        auto data_val = static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
+        auto data_val = GetDataElement(data_ptr, data_idx);
 
         int64_t x = data_idx / quantize_full_block;
         int64_t y = data_idx % quantize_full_block / quantize_N;
         int64_t z = data_idx % quantize_N;
         int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
         auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
-        auto zp_val = static_cast<int32_t>(zero_points_ptr
-                                               ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
-                                               : 0);
+        int32_t zp_val;
+        if constexpr (std::is_same_v<T1, uint8_t>) {
+          // The default zero point for uint8 weights, as stored by the MatMulNBits op, is 8.
+          zp_val = 8;
+        } else {
+          zp_val = static_cast<int32_t>(zero_points_ptr
+                                            ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
+                                            : 0);
+        }
 
         output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
       }
@@ -205,7 +232,7 @@
 template <typename T1, typename Tind>
 Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
   Prepare p;
   ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
-
+  auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
   const auto& data_shape = p.data_tensor->Shape();
 
   // re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
   // re-shape the indices tensor to [gather_N]
@@ -215,7 +242,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
   //       2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
   //       3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
   //       4> pick the element from the block: value_i = data_blk[blk_ele_i]
-  const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1);
+  const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1) * components;
   const int64_t gather_axis_dim = data_shape[narrow<size_t>(p.gather_axis)];
   const int64_t gather_M = data_shape.SizeToDimension(narrow<size_t>(p.gather_axis));
   const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +256,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
   //       data_i % (quantize_axis_dim * quantize_N) / quantize_N,
   //       data_i % quantize_N)
   //       4> get scale index: (x, y / block_size_, z)
-  const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)];
+  const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)] * components;
   const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt<size_t>(p.quantize_axis) + 1);
 
   concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
@@ -273,6 +300,8 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
           .TypeConstraint("Tind", DataTypeImpl::GetTensorType<Tind>()), \
       GatherBlockQuantized<T1, Tind>);
 
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
 REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
 REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
 REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);
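The flat-index-to-scale-index arithmetic in `CopyDataAndDequantize` is the subtlest part of this change. A standalone sketch with invented dimensions (the names mirror the kernel's locals, but none of this is ORT code) tracing one data index to its scale index:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Toy configuration: data logically shaped [4, 32, 4], quantize_axis = 1, block_size = 16.
  const int64_t block_size = 16;
  const int64_t quantize_N = 4;  // product of dims after quantize_axis
  const int64_t quantize_axis_dim = 32;
  const int64_t quantize_full_block = quantize_axis_dim * quantize_N;             // 128
  const int64_t scale_rows = (quantize_axis_dim + block_size - 1) / block_size;   // 2
  const int64_t scale_full_block = scale_rows * quantize_N;                       // 8

  // Map a flat data index to its scale index, exactly as the kernel loop does.
  const int64_t data_idx = 300;
  const int64_t x = data_idx / quantize_full_block;               // 2
  const int64_t y = data_idx % quantize_full_block / quantize_N;  // 11
  const int64_t z = data_idx % quantize_N;                        // 0
  const int64_t scale_idx = x * scale_full_block + y / block_size * quantize_N + z;
  assert(scale_idx == 16);
  return 0;
}
```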
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 7b4a45ce8aa0f..d87688a62040c 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
   1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
      `block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
   2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
-     If `zero_points` is not provided, 0 is the zero point.
+     If `zero_points` is not provided, the zero point is 0, except when `data` is of uint8 type, in which case the default zero point is 8.
   3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
      to dequantize the output.
   4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+  5. For uint8 `data`, `gather_axis` must be 0, and the `zero_points` input is not supported.
 )DOC";
 
 ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
     .Input(2, "scales", "quantization scale", "T2")
     .Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
     .Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
-    .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
+    .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
     .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
     .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
     .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3637,14 +3638,19 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
       gather_axis = (gather_axis + r) % r;
       quantize_axis = (quantize_axis + r) % r;
 
+      if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
+        fail_shape_inference("gather_axis must be 0 for uint8 data");
+      }
+
       if (scales_shape.dim_size() != r) {
         fail_shape_inference("scales must have the same rank as data");
       }
 
+      uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
       for (int i = 0; i < r; ++i) {
         if (!data_shape.dim(i).has_dim_value() ||
             !scales_shape.dim(i).has_dim_value() ||
-            (i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
+            (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
             (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
           fail_shape_inference("data shape and scales shape do not match");
         }
@@ -3652,6 +3658,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
 
       // validate zero point shape
       if (ctx.hasInput(3)) {
+        if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
+          fail_type_inference("zero_points are not supported for the uint8_t data type");
+        }
+
         if (!hasInputShape(ctx, 3)) {
           fail_shape_inference("zero_points shape must be known");
         }
@@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
             ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
       }
       for (int i = 0; i < out_rank; ++i) {
+        // For the uint8_t data type, the last dimension needs to be expanded back to the actual size,
+        // because two int4 values are stored packed in a single uint8_t.
+        auto last_dimension_components = (i == out_rank - 1) ? components : 1;
         *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
             (i < gather_axis)                           ? data_shape.dim(i)
             : (i >= gather_axis && i < gather_axis + q) ? indices_shape.dim(i - gather_axis)
-                                                        : data_shape.dim(i - q + 1);
+                                                        : data_shape.dim(i - q + 1) * last_dimension_components;
       }
     });
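The inference function now expands the last output dimension by the component count. A standalone sketch of the resulting shape computation, reusing the 2000x3072-stored-as-2000x1536 example from the kernel comment (the indices shape is invented for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Packed uint8 data: 2000x3072 logical int4 values stored as 2000x1536 bytes,
  // gathered along axis 0 with 3 indices.
  const std::vector<int64_t> data_shape = {2000, 1536};
  const std::vector<int64_t> indices_shape = {3};
  const size_t gather_axis = 0;
  const int64_t components = 2;  // uint8 packs two int4 values per byte

  // Mirror the inference rule: output rank is q + (r - 1), and for uint8 the
  // last output dimension is expanded back by `components`.
  std::vector<int64_t> out_shape;
  for (size_t d = 0; d < gather_axis; ++d) out_shape.push_back(data_shape[d]);
  for (int64_t d : indices_shape) out_shape.push_back(d);
  for (size_t d = gather_axis + 1; d < data_shape.size(); ++d) out_shape.push_back(data_shape[d]);
  out_shape.back() *= components;

  assert(out_shape == (std::vector<int64_t>{3, 3072}));
  return 0;
}
```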
diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
index c4536fc56a22f..0dfe194e893e2 100644
--- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
@@ -15,6 +15,27 @@ namespace onnxruntime {
 namespace test {
 
+// When the uint8_t data type is used, GatherBlockQuantized applies MatMulNBits' conventions for storing the data.
+// That is, when no zero points are specified, a default zero point of 8 is used. This converter therefore
+// compensates by adding 8 to the data values, so that the outputs match the results that
+// would be seen with non-uint8_t data types.
+template <typename T1>
+void PackDataForUint8TypeIfNecessary(std::vector<int>& data, std::vector<int64_t>& data_shape) {
+  if (!std::is_same_v<T1, uint8_t>) {
+    return;
+  }
+  // For uint8_t, we need to pack each pair of values (after adding 8) into a single uint8_t.
+  std::vector<int> packed_data;
+  for (size_t i = 0; i < data.size(); i += 2) {
+    int low_nibble = (data[i] + 8) & 0xF;
+    int high_nibble = ((i + 1) < data.size()) ? ((data[i + 1] + 8) & 0xF) : 0;
+    int packed = (high_nibble << 4) | low_nibble;
+    packed_data.push_back(packed);
+  }
+  data = packed_data;
+  data_shape[data_shape.size() - 1] = (data_shape[data_shape.size() - 1] + 1) / 2;
+}
+
 // Combinations: types, gather_axis, quantize_axis, block_size, indices, scale shape vs data shape
 template <typename T1, typename T2, typename Tind>
 void RunGatherBlockQuantized(const std::vector<T1>& data,
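The helper biases values by 8 before packing because the kernel dequantizes uint8 data with an implicit zero point of 8; the two cancel, so the uint8 tests can share expected outputs with the int4 paths. A standalone round-trip sketch (illustrative values only) demonstrating the cancellation:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int> original = {-8, -1, 0, 7};  // endpoints and midpoints of the signed int4 range

  // Pack as the test helper does: bias by 8, low nibble first.
  std::vector<uint8_t> packed;
  for (size_t i = 0; i < original.size(); i += 2) {
    packed.push_back(static_cast<uint8_t>(((original[i + 1] + 8) & 0xF) << 4 | ((original[i] + 8) & 0xF)));
  }

  // Unpack and dequantize with the uint8 default zero point (8) and a unit scale.
  for (size_t i = 0; i < original.size(); ++i) {
    const uint8_t byte = packed[i >> 1];
    const int32_t q = (i & 1) ? ((byte >> 4) & 0x0F) : (byte & 0x0F);
    assert(static_cast<float>(q - 8) * 1.0f == static_cast<float>(original[i]));
  }
  return 0;
}
```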
@@ -96,6 +117,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis,
                            4, 5, 6, 7,
                            -4, -3, -2, -1};
   std::vector<int64_t> data_shape = {2, 3, 4};
+  PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
   std::vector<int> indices = {1};
   std::vector<int64_t> indices_shape = {1};
   std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -123,7 +145,6 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis,
 
 TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
   Test_Fail_WithZeroPoints<int8_t, float, int32_t>(0, 2, 16);
-  Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 16);
   Test_Fail_WithZeroPoints<int16_t, float, int32_t>(0, 2, 16);
   Test_Fail_WithZeroPoints<uint16_t, float, int32_t>(0, 2, 16);
   Test_Fail_WithZeroPoints<int32_t, float, int32_t>(0, 2, 16);
@@ -134,21 +155,70 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
   Test_Fail_WithZeroPoints<uint32_t, float, int32_t>(0, 2, 16);
   Test_Fail_WithZeroPoints<int64_t, float, int32_t>(0, 2, 16);
   Test_Fail_WithZeroPoints<uint64_t, float, int32_t>(0, 2, 16);
+  Test_Fail_WithZeroPoints<uint8_t, double, int32_t>(0, 2, 16);
+}
+
+template <typename T1, typename T2, typename Tind>
+void Test_Fail_WithoutZeroPoints(int64_t gather_axis,
+                                 int64_t quantize_axis,
+                                 int64_t block_size) {
+  std::vector<int> data = {-8, -7, -6, -5,
+                           -4, -3, -2, -1,
+                           0, 1, 2, 3,
+                           4, 5, 6, 7,
+                           4, 5, 6, 7,
+                           -4, -3, -2, -1};
+  std::vector<int64_t> data_shape = {2, 3, 4};
+  PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
+  std::vector<int> indices = {1};
+  std::vector<int64_t> indices_shape = {1};
+  std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+  std::vector<int64_t> scales_shape = {2, 3, 1};
+  std::vector<float> output = {8.f, 10.f, 12.f, 14.f,
+                               3.f, 4.f, 5.f, 6.f,
+                               -6.f, -4.f, -2.f, 0.f};
+  std::vector<int64_t> output_shape = {1, 3, 4};
+
+  RunGatherBlockQuantized(ToType<T1>(data),
+                          data_shape,
+                          ToType<Tind>(indices),
+                          indices_shape,
+                          ToType<T2>(scales),
+                          scales_shape,
+                          {},
+                          gather_axis,
+                          quantize_axis,
+                          block_size,
+                          ToType<T2>(output),
+                          output_shape,
+                          OpTester::ExpectResult::kExpectFailure);
+}
+
+TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) {
+  // T1 uint8_t with zero points is not yet supported.
+  Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 16);
+  Test_Fail_WithZeroPoints<uint8_t, float, int64_t>(0, 2, 16);
+  // Gather on an axis other than 0 is not supported with uint8_t.
+  Test_Fail_WithoutZeroPoints<uint8_t, float, int32_t>(1, 2, 16);
+  Test_Fail_WithoutZeroPoints<uint8_t, float, int64_t>(1, 2, 16);
 }
 
 TEST(GatherBlockQuantizedOpTest, InvalidBlockSize) {
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, 2, 8);
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, 2, 17);
+  Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 17);
 }
 
 TEST(GatherBlockQuantizedOpTest, InvalidGatherAxis) {
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(3, 2, 16);
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(-4, 2, 16);
+  Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(-4, 2, 16);
 }
 
 TEST(GatherBlockQuantizedOpTest, InvalidQuantizeAxis) {
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, 3, 16);
   Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, -4, 16);
+  Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, -4, 16);
 }
 
 template <typename T1, typename T2, typename Tind>
@@ -160,6 +230,7 @@ void Test_ShapeMismatch_WithZeroPoints() {
                            4, 5, 6, 7,
                            -4, -3, -2, -1};
   std::vector<int64_t> data_shape = {2, 3, 4};
+  PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
   std::vector<int> indices = {1};
   std::vector<int64_t> indices_shape = {1};
   std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f};
@@ -188,6 +259,7 @@ void Test_ShapeMismatch_WithZeroPoints() {
 TEST(GatherBlockQuantizedOpTest, ShapeMismatch) {
   Test_ShapeMismatch_WithZeroPoints<UInt4x2, float, int32_t>();
   Test_ShapeMismatch_WithZeroPoints<Int4x2, float, int32_t>();
+  Test_ShapeMismatch_WithZeroPoints<uint8_t, float, int32_t>();
 }
 
 template <typename T1, typename T2, typename Tind>
@@ -199,6 +271,7 @@ void Test_InvalidIndices_WithZeroPoints() {
                            4, 5, 6, 7,
                            -4, -3, -2, -1};
   std::vector<int64_t> data_shape = {2, 3, 4};
+  PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
   std::vector<int> indices = {2};
   std::vector<int64_t> indices_shape = {1};
   std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -227,6 +300,7 @@ void Test_InvalidIndices_WithZeroPoints() {
 TEST(GatherBlockQuantizedOpTest, InvalidIndices) {
   Test_InvalidIndices_WithZeroPoints<UInt4x2, float, int32_t>();
   Test_InvalidIndices_WithZeroPoints<Int4x2, float, int32_t>();
+  Test_InvalidIndices_WithZeroPoints<uint8_t, float, int32_t>();
 }
 
 template <typename T1, typename T2, typename Tind>
@@ -298,6 +372,7 @@ void Test_GatherAxis0_NoZeroPoints() {
                            4, 5, 6, 7,
                            -4, -3, -2, -1};
   std::vector<int64_t> data_shape = {2, 3, 4};
+  PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
   std::vector<int> indices = {1};
   std::vector<int64_t> indices_shape = {1};
   std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -340,6 +415,10 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) {
   Test_GatherAxis0_NoZeroPoints<Int4x2, float, int32_t>();
   Test_GatherAxis0_NoZeroPoints<UInt4x2, MLFloat16, int32_t>();
   Test_GatherAxis0_NoZeroPoints<Int4x2, MLFloat16, int32_t>();
+  Test_GatherAxis0_NoZeroPoints<uint8_t, float, int32_t>();
+  Test_GatherAxis0_NoZeroPoints<uint8_t, float, int64_t>();
+  Test_GatherAxis0_NoZeroPoints<uint8_t, MLFloat16, int32_t>();
+  Test_GatherAxis0_NoZeroPoints<uint8_t, MLFloat16, int64_t>();
 }
 
 template <typename T1, typename T2, typename Tind>