@@ -40,7 +40,8 @@ class test_col_major_1 {
4040 static constexpr size_t sg_k = 512 / sg_m;
4141 static constexpr size_t dequant_s = 128 ;
4242 // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43- static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43+ // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_FP_ZERO;
4445
4546 static constexpr size_t local_kslicing = 1 ;
4647 static constexpr size_t global_kslicing = 1 ;
@@ -131,13 +132,19 @@ std::vector<fp16> convert_int4(
131132 data_type_zero_pt zero_pt) {
132133 std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
133134
134- int8_t zero_pt_i8 = zero_pt & 0xf ;
135+ int8_t zero_pt_i8;
136+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
137+ zero_pt_i8 = zero_pt & 0xf ;
135138 for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
136139 int8_t dequant_8bit = data_b & 0xf ;
137140 if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
138141 dequant_fp16[i] = scale * (dequant_8bit - 8 );
139- } else {
142+ } else if constexpr (quant_mode == quant_mode::S4_ASYM) {
140143 dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8);
144+ } else if constexpr (quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
145+ dequant_fp16[i] = scale * (dequant_8bit - 8 ) + zero_pt;
146+ } else {
147+ assert (0 );
141148 }
142149 data_b = data_b >> 4 ;
143150 }
@@ -169,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
169176 for (uint32_t j = 0 ; j < width; j += step) {
170177 int start_b_in = i * width + j;
171178 int start_scale_in = start_b_in / step;
172- int start_zero_pt_in =
173- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_FP_ZERO
180+ ? (j / step) * matrix_n + i
181+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
174182 int start_out =
175183 layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184+ data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185+ if constexpr (quant_mode != quant_mode::INT4_ASYM_FP_ZERO)
186+ zp_value = zp_value >> (4 * (i % pack_radio));
176187 for (uint32_t jj = 0 ; jj < step; jj++) {
177188 std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
178- b[start_b_in + jj],
179- scale[start_scale_in],
180- zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189+ b[start_b_in + jj], scale[start_scale_in], zp_value);
181190 for (uint32_t jjj = 0 ; jjj < dequant_fp16.size (); jjj++) {
182191 b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
183192 }
@@ -215,7 +224,10 @@ void dequantize_gemv_run(int iter) {
215224 using data_type_a = typename Test::data_type_a;
216225 using data_type_b = typename Test::data_type_b;
217226 using data_type_c = typename Test::data_type_c;
218- using data_type_zero_pt = data_type_b;
227+ using data_type_zero_pt = std::conditional_t <
228+ Test::quant_mode == quant_mode::INT4_ASYM_FP_ZERO,
229+ data_type_c,
230+ data_type_b>;
219231 using data_type_scale = fp16;
220232 using data_type_acc_in = fp16;
221233 using data_type_acc = float ;
@@ -225,7 +237,7 @@ void dequantize_gemv_run(int iter) {
225237 constexpr mem_layout layout_b = Test::layout_b;
226238
227239 constexpr size_t size_a = matrix_m * matrix_k;
228- constexpr size_t size_b = matrix_k * matrix_n / ( 2 * sizeof (data_type_b)) ;
240+ constexpr size_t size_b = matrix_k * matrix_n / 2 ;
229241
230242 constexpr size_t size_scale_k = matrix_k / dequant_s;
231243 constexpr size_t size_scale_n = matrix_n;
@@ -234,7 +246,9 @@ void dequantize_gemv_run(int iter) {
234246 constexpr size_t size_zero_pt_k = matrix_k / dequant_s;
235247 constexpr size_t size_zero_pt_n = matrix_n;
236248 constexpr size_t size_zero_pt =
237- size_zero_pt_k * size_zero_pt_n / (2 * sizeof (data_type_b));
249+ Test::quant_mode != quant_mode::INT4_ASYM_FP_ZERO
250+ ? size_zero_pt_k * size_zero_pt_n / 2
251+ : size_zero_pt_k * size_zero_pt_n;
238252
239253 constexpr size_t size_c = matrix_m * matrix_n;
240254 constexpr size_t size_bias = matrix_n;
@@ -405,16 +419,18 @@ void dequantize_gemv_run(int iter) {
405419 scale_h[i] = INFINITY;
406420 }
407421 for (unsigned i = 0 ; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) {
408- if constexpr (std::is_same_v<int4x2, data_type_b >) {
422+ if constexpr (std::is_same_v<int4x2, data_type_zero_pt >) {
409423 zero_pt_h[i] = random_uint8 ();
410424#ifdef UT_DEBUG
411425 zero_pt_h[i] = 0x12 << i;
412426#endif
413- } else if constexpr (std::is_same_v<int4x8, data_type_b >) {
427+ } else if constexpr (std::is_same_v<int4x8, data_type_zero_pt >) {
414428 zero_pt_h[i] = random_uint32 ();
415429#ifdef UT_DEBUG
416430 zero_pt_h[i] = 0x33333333 ;
417431#endif
432+ } else if constexpr (std::is_same_v<fp16, data_type_zero_pt>) {
433+ zero_pt_h[i] = random_float ();
418434 }
419435 }
420436
@@ -491,7 +507,9 @@ void dequantize_gemv_run(int iter) {
491507 Acc_d,
492508 Cnt_d,
493509 epilogue_args);
494- } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510+ } else if constexpr (
511+ compute_policy::quant_mode == quant_mode::S4_ASYM ||
512+ compute_policy::quant_mode == quant_mode::INT4_ASYM_FP_ZERO) {
495513 gemm_arg =
496514 typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
497515 matrix_m,
0 commit comments