Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 8f0abc4

Browse files
committed
fix group_qkv
1 parent 93c8ad1 commit 8f0abc4

File tree

4 files changed: 31 additions (+31) and 21 deletions (-21).

include/group/epilogue/impl/default_xe.hpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,8 +100,6 @@ class epilogue_t<
100100
subgroup::msg_type_v<mat_tile_desc, mem_desc_c_t>;
101101
using matC_payload_t = subgroup::
102102
mem_payload_t<mem_desc_c_t, mat_tile_desc, msg_type_c, arch_tag>;
103-
using matC_payload_t = subgroup::
104-
mem_payload_t<mem_desc_c_t, mat_tile_desc, msg_type_c, arch_tag>;
105103
update_sg_tile_tdesc(g, mem_desc_c);
106104
matC_t matC;
107105
matC_payload_t matC_payload(mem_desc_c);

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -443,7 +443,10 @@ struct mem_payload_t<
443443
pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype);
444444
width_in_elems = mem_tdesc.shape.x;
445445
height_in_elems = mem_tdesc.shape.y;
446-
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
446+
payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes +
447+
mem_tdesc.shape.y * sizeof(dtype)
448+
: (mem_tdesc.shape.y - 1) * pitch_in_bytes +
449+
mem_tdesc.shape.x * sizeof(dtype);
447450
uint32_t offset_x = mem_tdesc.coord.x;
448451
uint32_t offset_y = mem_tdesc.coord.y;
449452
base_offset = mem_transpose
@@ -464,7 +467,10 @@ struct mem_payload_t<
464467
uint32_t offset_y = surface_offset_y;
465468
width_in_elems = surface_width;
466469
height_in_elems = surface_height;
467-
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
470+
payload_bytes = mem_transpose ? (surface_offset_x - 1) * pitch_in_bytes +
471+
surface_offset_y * sizeof(dtype)
472+
: (surface_offset_y - 1) * pitch_in_bytes +
473+
surface_offset_x * sizeof(dtype);
468474
base_offset = mem_transpose
469475
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
470476
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -477,7 +483,11 @@ struct mem_payload_t<
477483
uint32_t offset_y = mem_tdesc.coord.y;
478484
width_in_elems = mem_tdesc.shape.x;
479485
height_in_elems = mem_tdesc.shape.y;
480-
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
486+
payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes +
487+
mem_tdesc.shape.y * sizeof(dtype)
488+
: (mem_tdesc.shape.y - 1) * pitch_in_bytes +
489+
mem_tdesc.shape.x * sizeof(dtype);
490+
481491
base_offset = mem_transpose
482492
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
483493
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -496,7 +506,9 @@ struct mem_payload_t<
496506
uint32_t offset_y = surface_offset_y;
497507
width_in_elems = surface_width;
498508
height_in_elems = surface_height;
499-
payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
509+
payload_bytes = mem_transpose
510+
? (surface_width - 1) * pitch_in_bytes + surface_height * sizeof(dtype)
511+
: (surface_height - 1) * pitch_in_bytes + surface_width * sizeof(dtype);
500512
base_offset = mem_transpose
501513
? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
502514
: offset_y * pitch_in_bytes + offset_x * sizeof(dtype);

include/subgroup/tile/impl/store_xe.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -282,8 +282,8 @@ tile_store(tile_t& tile, payload_t& payload) {
282282
using dtype = typename payload_t::dtype;
283283
static constexpr uint32_t store_len = tile_t::tile_elems;
284284
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
285-
286-
if (payload.base_offset <= payload.payload_bytes) {
285+
if (payload.base_offset + store_len * sizeof(dtype) <=
286+
payload.payload_bytes) {
287287
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
288288
static constexpr uint32_t max_store_vec_len =
289289
load_store_attr::max_store_vec_len;

tests/integration/gemv/int4/main.cpp

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ class test_col_major_1 {
4747
static constexpr mem_layout layout_a = mem_layout::row_major;
4848
static constexpr mem_layout layout_b = mem_layout::col_major;
4949
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeLpg;
50+
static constexpr gpu_arch arch = gpu_arch::XeHpc;
5151
using data_type_a = fp16;
5252
using data_type_b = int4x8;
5353
using data_type_c = fp16;
@@ -108,12 +108,12 @@ int gemm_result_validate(
108108
bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation");
109109

110110
#ifdef UT_DEBUG
111-
for (uint32_t i = 0; i < m; i++) {
112-
for (uint32_t j = 0; j < n; j++) {
113-
std::cout << float(sycl::half(C[i * n + j])) << " ";
114-
}
115-
std::cout << std::endl;
116-
}
111+
// for (uint32_t i = 0; i < m; i++) {
112+
// for (uint32_t j = 0; j < n; j++) {
113+
// std::cout << float(sycl::half(C[i * n + j])) << " ";
114+
// }
115+
// std::cout << std::endl;
116+
// }
117117
#endif
118118
std::cout << (!result ? "FAILED\n" : "PASSED\n");
119119
return result ? 0 : 1;
@@ -185,12 +185,12 @@ std::vector<data_type_acc_in> dequantize_weight(
185185
}
186186
}
187187
#ifdef UT_DEBUG
188-
for (uint32_t i = 0; i < matrix_n; i++) {
189-
for (uint32_t j = 0; j < matrix_k; j++) {
190-
std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
191-
}
192-
std::cout << std::endl;
193-
}
188+
// for (uint32_t i = 0; i < matrix_n; i++) {
189+
// for (uint32_t j = 0; j < matrix_k; j++) {
190+
// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
191+
// }
192+
// std::cout << std::endl;
193+
// }
194194
#endif
195195
return b_out;
196196
}

0 commit comments

Comments
 (0)