@@ -40,14 +40,15 @@ class test_col_major_1 {
4040 static constexpr size_t sg_k = 1024 / 1 ;
4141 static constexpr size_t dequant_s = 128 ;
4242 // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
43- static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
43+ // static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
44+ static constexpr quant_mode quant_mode = quant_mode::INT4_ASYM_ZERO_NO_DEGRAD;
4445
4546 static constexpr size_t local_kslicing = 1 ;
4647 static constexpr size_t global_kslicing = 1 ;
4748 static constexpr mem_layout layout_a = mem_layout::row_major;
4849 static constexpr mem_layout layout_b = mem_layout::col_major;
4950 static constexpr mma_engine mma_eng = mma_engine::fpu;
50- static constexpr gpu_arch arch = gpu_arch::XeHpc ;
51+ static constexpr gpu_arch arch = gpu_arch::XeHpg ;
5152 using data_type_a = fp16;
5253 using data_type_b = int4x8;
5354 using data_type_c = fp16;
@@ -131,7 +132,9 @@ std::vector<fp16> convert_int4(
131132 data_type_zero_pt zero_pt) {
132133 std::vector<fp16> dequant_fp16 (sizeof (data_type_b) * 2 );
133134
134- int8_t zero_pt_i8 = zero_pt & 0xf ;
135+ int8_t zero_pt_i8;
136+ if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
137+ zero_pt_i8 = zero_pt & 0xf ;
135138 for (uint32_t i = 0 ; i < dequant_fp16.size (); i++) {
136139 int8_t dequant_8bit = data_b & 0xf ;
137140 if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
@@ -173,15 +176,17 @@ std::vector<data_type_acc_in> dequantize_weight(
173176 for (uint32_t j = 0 ; j < width; j += step) {
174177 int start_b_in = i * width + j;
175178 int start_scale_in = start_b_in / step;
176- int start_zero_pt_in =
177- (j / step) * (matrix_n / pack_radio) + i / pack_radio;
179+ int start_zero_pt_in = quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD
180+ ? (j / step) * matrix_n + i
181+ : (j / step) * (matrix_n / pack_radio) + i / pack_radio;
178182 int start_out =
179183 layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio;
184+ data_type_zero_pt zp_value = zero_pt[start_zero_pt_in];
185+ if constexpr (quant_mode != quant_mode::INT4_ASYM_ZERO_NO_DEGRAD)
186+ zp_value = zp_value >> (4 * (i % pack_radio));
180187 for (uint32_t jj = 0 ; jj < step; jj++) {
181188 std::vector<fp16> dequant_fp16 = convert_int4<quant_mode>(
182- b[start_b_in + jj],
183- scale[start_scale_in],
184- zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio)));
189+ b[start_b_in + jj], scale[start_scale_in], zp_value);
185190 for (uint32_t jjj = 0 ; jjj < dequant_fp16.size (); jjj++) {
186191 b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj];
187192 }
@@ -502,7 +507,9 @@ void dequantize_gemv_run(int iter) {
502507 Acc_d,
503508 Cnt_d,
504509 epilogue_args);
505- } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
510+ } else if constexpr (
511+ compute_policy::quant_mode == quant_mode::S4_ASYM ||
512+ compute_policy::quant_mode == quant_mode::INT4_ASYM_ZERO_NO_DEGRAD) {
506513 gemm_arg =
507514 typename gemm_op_t ::template arguments_t <compute_policy::quant_mode>(
508515 matrix_m,
0 commit comments