diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f582abca34706..0308b5c79c508 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
- If `zero_points` is not provided, 0 is the zero point.
+ If `zero_points` is not provided, 0 is the zero point, except when `data` is of uint8 type, in which case the default zero point is 8.
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+ 5. For uint8 `data`, `gather_axis` must be 0.
#### Version
@@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-<dt><tt>T1</tt> : tensor(int4), tensor(uint4)</dt>
+<dt><tt>T1</tt> : tensor(int4), tensor(uint4), tensor(uint8)</dt>
<dd>Constrain quantized types.</dd>
<dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain dequantized types.</dd>
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 60d9e8e747eeb..a20333e2340c4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -515,7 +515,7 @@ Do not modify directly.*
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedGemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
+|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
|GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 345b5e793a764..1a737f3a9d251 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
+class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized)>,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
index 5935663f114a3..b83164d806ffc 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc
@@ -16,6 +16,21 @@
namespace onnxruntime {
namespace contrib {
+namespace {
+template <typename T1>
+int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
+ return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
+}
+
+template <>
+int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
+ const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
+ // Weights are stored as (nibble2)(nibble1) in uint8_t.
+ auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
+ return data_val;
+}
+} // namespace
+
template <typename T1, typename Tind>
class GatherBlockQuantized : public OpKernel {
public:
@@ -98,6 +113,12 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex
for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
shape.push_back(data_shape[narrow<size_t>(i)]);
+ // When data is stored as uint8_t, each element holds two int4 values.
+ // The shape in the onnx model reflects that by making the last dimension half the number of values.
+ // Ex: for a true data size of 2000x3072, the onnx model would have data of shape 2000x1536.
+ // However, the output still needs to be of size 2000x3072, so we double the last dimension here.
+ uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
+ shape[shape.size() - 1] = shape.back() * components;
p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
// validate quantization parameters
@@ -106,7 +127,7 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex
"data and scales must have the same rank.");
for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
ORT_RETURN_IF_NOT(i == static_cast<size_t>(p.quantize_axis)
- ? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
+ ? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
: data_shape[i] == scales_shape[i],
"data and scales do not match shapes.");
}
@@ -165,16 +186,22 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr,
int64_t output_idx = output_idx_base;
int64_t data_idx = data_idx_base;
for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
- auto data_val = static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
+ auto data_val = GetDataElement(data_ptr, data_idx);
int64_t x = data_idx / quantize_full_block;
int64_t y = data_idx % quantize_full_block / quantize_N;
int64_t z = data_idx % quantize_N;
int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
- auto zp_val = static_cast<int32_t>(zero_points_ptr
- ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
- : 0);
+ int32_t zp_val;
+ if constexpr (std::is_same_v<T1, uint8_t>) {
+ // The default zero point for uint8 weights, as stored by the MatMulNBits op, is 8.
+ zp_val = 8;
+ } else {
+ zp_val = static_cast<int32_t>(zero_points_ptr
+ ? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
+ : 0);
+ }
output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
}
@@ -205,7 +232,7 @@ template
Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
Prepare p;
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
-
+ auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
const auto& data_shape = p.data_tensor->Shape();
// re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
// re-shape the indices tensor to [gather_N]
@@ -215,7 +242,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
// 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
// 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
// 4> pick the element from the block: value_i = data_blk[blk_ele_i]
- const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1);
+ const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1) * components;
const int64_t gather_axis_dim = data_shape[narrow<size_t>(p.gather_axis)];
const int64_t gather_M = data_shape.SizeToDimension(narrow<size_t>(p.gather_axis));
const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +256,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
// data_i % (quantize_axis_dim * quantize_N) / quantize_N,
// data_i % quantize_N)
// 4> get scale index: (x, y / block_size_, z)
- const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)];
+ const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)] * components;
const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt<size_t>(p.quantize_axis) + 1);
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
@@ -273,6 +300,8 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const {
.TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \
GatherBlockQuantized);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
+REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 7b4a45ce8aa0f..d87688a62040c 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
`block_size` must be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ...
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
- If `zero_points` is not provided, 0 is the zero point.
+ If `zero_points` is not provided, 0 is the zero point, except when `data` is of uint8 type, in which case the default zero point is 8.
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
to dequantize the output.
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
+ 5. For uint8 `data`, `gather_axis` must be 0.
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
.Input(2, "scales", "quantization scale", "T2")
.Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
.Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
- .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
+ .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
.TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3637,14 +3638,19 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
gather_axis = (gather_axis + r) % r;
quantize_axis = (quantize_axis + r) % r;
+ if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
+ fail_shape_inference("gather_axis must be 0, for uint8 data");
+ }
+
if (scales_shape.dim_size() != r) {
fail_shape_inference("scales must have the same rank as data");
}
+ uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
for (int i = 0; i < r; ++i) {
if (!data_shape.dim(i).has_dim_value() ||
!scales_shape.dim(i).has_dim_value() ||
- (i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
+ (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
(i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
fail_shape_inference("data shape and scales shape do not match");
}
@@ -3652,6 +3658,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
// validate zero point shape
if (ctx.hasInput(3)) {
+ if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
+ fail_type_inference("zero_points are not supported for uint8_t data type");
+ }
+
if (!hasInputShape(ctx, 3)) {
fail_shape_inference("zero_points shape must be known");
}
@@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
}
for (int i = 0; i < out_rank; ++i) {
+ // For uint8_t data, the last dimension needs to be expanded back to the actual value count,
+ // because two int4 values are packed into a single uint8_t.
+ auto last_dimension_components = (i == out_rank - 1) ? components : 1;
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
(i < gather_axis)
? data_shape.dim(i)
: (i >= gather_axis && i < gather_axis + q)
? indices_shape.dim(i - gather_axis)
- : data_shape.dim(i - q + 1);
+ : data_shape.dim(i - q + 1) * last_dimension_components;
}
});
diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
index c4536fc56a22f..0dfe194e893e2 100644
--- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
+++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc
@@ -15,6 +15,27 @@
namespace onnxruntime {
namespace test {
+// When the uint8_t data type is used, GatherBlockQuantized applies MatMulNBits' conventions for storing the data.
+// That is, when no zero points are specified, a default zero point of 8 is used. This converter hence
+// compensates for that by adding 8 to the data values, so that the outputs match the results that
+// would be seen with non-uint8_t data types.
+template <typename T1>
+void PackDataForUint8TypeIfNecessary(std::vector<int>& data, std::vector<int64_t>& data_shape) {
+ if (!std::is_same_v<T1, uint8_t>) {
+ return;
+ }
+ // For uint8_t, we need to pack each pair of values (after adding 8) into a single uint8_t
+ std::vector<int> packed_data;
+ for (size_t i = 0; i < data.size(); i += 2) {
+ int low_nibble = (data[i] + 8) & 0xF;
+ int high_nibble = ((i + 1) < data.size()) ? ((data[i + 1] + 8) & 0xF) : 0;
+ int packed = (high_nibble << 4) | low_nibble;
+ packed_data.push_back(packed);
+ }
+ data = packed_data;
+ data_shape[data_shape.size() - 1] = (data_shape[data_shape.size() - 1] + 1) / 2;
+}
+
// Combinations: types, gather_axis, quantize_axis, block_size, indices, scale shape vs data shape
template <typename T1, typename T2, typename Tind>
void RunGatherBlockQuantized(const std::vector<T1>& data,
@@ -96,6 +117,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis,
4, 5, 6, 7,
-4, -3, -2, -1};
std::vector<int64_t> data_shape = {2, 3, 4};
+ PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
std::vector<int> indices = {1};
std::vector<int64_t> indices_shape = {1};
std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -123,7 +145,6 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis,
TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
Test_Fail_WithZeroPoints(0, 2, 16);
- Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
@@ -134,21 +155,70 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) {
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
Test_Fail_WithZeroPoints(0, 2, 16);
+ Test_Fail_WithZeroPoints(0, 2, 16);
+}
+
+template <typename T1, typename T2, typename Tind>
+void Test_Fail_WithoutZeroPoints(int64_t gather_axis,
+ int64_t quantize_axis,
+ int64_t block_size) {
+ std::vector<int> data = {-8, -7, -6, -5,
+ -4, -3, -2, -1,
+ 0, 1, 2, 3,
+ 4, 5, 6, 7,
+ 4, 5, 6, 7,
+ -4, -3, -2, -1};
+ std::vector<int64_t> data_shape = {2, 3, 4};
+ PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
+ std::vector<int> indices = {1};
+ std::vector<int64_t> indices_shape = {1};
+ std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
+ std::vector<int64_t> scales_shape = {2, 3, 1};
+ std::vector<float> output = {8.f, 10.f, 12.f, 14.f,
+ 3.f, 4.f, 5.f, 6.f,
+ -6.f, -4.f, -2.f, 0.f};
+ std::vector<int64_t> output_shape = {1, 3, 4};
+
+ RunGatherBlockQuantized(ToType<T1>(data),
+ data_shape,
+ ToType<Tind>(indices),
+ indices_shape,
+ ToType<T2>(scales),
+ scales_shape,
+ {},
+ gather_axis,
+ quantize_axis,
+ block_size,
+ ToType<T2>(output),
+ output_shape,
+ OpTester::ExpectResult::kExpectFailure);
+}
+
+TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) {
+ // T1 uint8_t with zero points is not yet supported.
+ Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 16);
+ Test_Fail_WithZeroPoints<uint8_t, float, int64_t>(0, 2, 16);
+ // Gather on an axis other than 0 is not supported with uint8_t.
+ Test_Fail_WithoutZeroPoints<uint8_t, float, int32_t>(1, 2, 16);
+ Test_Fail_WithoutZeroPoints<uint8_t, float, int64_t>(1, 2, 16);
}
TEST(GatherBlockQuantizedOpTest, InvalidBlockSize) {
Test_Fail_WithZeroPoints<Int4x2, float, int32_t>(0, 2, 8);
Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, 2, 17);
+ Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, 2, 17);
}
TEST(GatherBlockQuantizedOpTest, InvalidGatherAxis) {
Test_Fail_WithZeroPoints<Int4x2, float, int32_t>(3, 2, 16);
Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(-4, 2, 16);
+ Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(-4, 2, 16);
}
TEST(GatherBlockQuantizedOpTest, InvalidQuantizeAxis) {
Test_Fail_WithZeroPoints<Int4x2, float, int32_t>(0, 3, 16);
Test_Fail_WithZeroPoints<UInt4x2, float, int32_t>(0, -4, 16);
+ Test_Fail_WithZeroPoints<uint8_t, float, int32_t>(0, -4, 16);
}
template <typename T1, typename T2, typename Tind>
@@ -160,6 +230,7 @@ void Test_ShapeMismatch_WithZeroPoints() {
4, 5, 6, 7,
-4, -3, -2, -1};
std::vector<int64_t> data_shape = {2, 3, 4};
+ PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
std::vector<int> indices = {1};
std::vector<int64_t> indices_shape = {1};
std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f};
@@ -188,6 +259,7 @@ void Test_ShapeMismatch_WithZeroPoints() {
TEST(GatherBlockQuantizedOpTest, ShapeMismatch) {
Test_ShapeMismatch_WithZeroPoints<Int4x2, float, int32_t>();
Test_ShapeMismatch_WithZeroPoints<UInt4x2, float, int32_t>();
+ Test_ShapeMismatch_WithZeroPoints<uint8_t, float, int32_t>();
}
template <typename T1, typename T2, typename Tind>
@@ -199,6 +271,7 @@ void Test_InvalidIndices_WithZeroPoints() {
4, 5, 6, 7,
-4, -3, -2, -1};
std::vector<int64_t> data_shape = {2, 3, 4};
+ PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
std::vector<int> indices = {2};
std::vector<int64_t> indices_shape = {1};
std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -227,6 +300,7 @@ void Test_InvalidIndices_WithZeroPoints() {
TEST(GatherBlockQuantizedOpTest, InvalidIndices) {
Test_InvalidIndices_WithZeroPoints<Int4x2, float, int32_t>();
Test_InvalidIndices_WithZeroPoints<UInt4x2, float, int32_t>();
+ Test_InvalidIndices_WithZeroPoints<uint8_t, float, int32_t>();
}
template <typename T1, typename T2, typename Tind>
@@ -298,6 +372,7 @@ void Test_GatherAxis0_NoZeroPoints() {
4, 5, 6, 7,
-4, -3, -2, -1};
std::vector<int64_t> data_shape = {2, 3, 4};
+ PackDataForUint8TypeIfNecessary<T1>(data, data_shape);
std::vector<int> indices = {1};
std::vector<int64_t> indices_shape = {1};
std::vector<float> scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f};
@@ -340,6 +415,10 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) {
Test_GatherAxis0_NoZeroPoints();
Test_GatherAxis0_NoZeroPoints();
Test_GatherAxis0_NoZeroPoints();
+ Test_GatherAxis0_NoZeroPoints<uint8_t, float, int32_t>();
+ Test_GatherAxis0_NoZeroPoints<uint8_t, float, int64_t>();
+ Test_GatherAxis0_NoZeroPoints<uint8_t, MLFloat16, int32_t>();
+ Test_GatherAxis0_NoZeroPoints<uint8_t, MLFloat16, int64_t>();
}
template <typename T1, typename T2, typename Tind>