Skip to content

Commit 5c647e7

Browse files
authored
[XeGPU] Support loading 1x16xf16 (#852)
support loading 1x16xf16
1 parent 8ac2fb6 commit 5c647e7

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

lib/Conversion/XeGPUToVC/XeGPUToVC.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,9 @@ class LoadStorePrefetchNdToRawSendPattern : public OpConversionPattern<OpType> {
625625
auto execSize = createIntConstant(i8Type, 0);
626626
auto pred = createIntConstant(i1Type, 1);
627627
auto numSrc1 = createIntConstant(i8Type, 1);
628-
unsigned numDstVal = newType.getNumElements() / 16;
628+
// numDstVal: "Dest Length" is a value with the unit of a register
629+
unsigned numDstVal =
630+
(newType.getNumElements() + 16 - 1) / 16; // TODO: clarify 16
629631
if (rank == 1) {
630632
numDstVal *= 2;
631633
}
@@ -879,7 +881,7 @@ class GatherScatterToRawSend : public OpConversionPattern<OpType> {
879881
auto execSize = createIntConstant(i8Type, 4);
880882
auto pred = adaptor.getMask();
881883
auto numSrc1 = createIntConstant(i8Type, 2);
882-
unsigned numDstVal = newType.getNumElements() / 16;
884+
unsigned numDstVal = (newType.getNumElements() + 16 - 1) / 16;
883885
auto numDst = createIntConstant(i8Type, numDstVal);
884886
// 15 for ugm
885887
auto sfid = createIntConstant(i8Type, 15);
@@ -972,7 +974,7 @@ class AtomicToLsc : public OpConversionPattern<::mlir::xegpu::AtomicRMWOp> {
972974
auto immOffset = createIntConstant(i32Type, 0);
973975
unsigned dataSize = encodeDataum(vecType.getElementType());
974976
auto dataumSize = createIntConstant(i8Type, dataSize);
975-
unsigned numDstVal = newType.getNumElements() / 16;
977+
unsigned numDstVal = (newType.getNumElements() + 16 - 1) / 16;
976978
auto lscVecSize = 0;
977979
if (numDstVal <= 4) {
978980
lscVecSize = numDstVal;
Original file line numberDiff line numberDiff line change
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \
// RUN: --runner imex-cpu-runner -e main \
// RUN: --entry-point-result=void \
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck

// Regression test for loading a 1x16xf16 tile via xegpu.load_nd: the kernel
// loads two adjacent 1x16xf16 tiles from a 1x32 buffer, widens them to f32,
// and stores them back; main() then checks the f32 output is allclose to the
// f16 input.
module @gemm attributes {gpu.container_module} {
  // Host-side wrapper: copies %arg0 into device-visible memory, launches the
  // kernel on a single workitem, and returns the f32 result buffer.
  func.func @test(%arg0: memref<1x32xf16>) -> memref<1x32xf32> attributes {llvm.emit_c_interface} {
    %c1 = arith.constant 1 : index
    %memref = gpu.alloc host_shared () : memref<1x32xf16>
    memref.copy %arg0, %memref : memref<1x32xf16> to memref<1x32xf16>
    %memref_1 = gpu.alloc host_shared () : memref<1x32xf32>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<1x32xf16>, %memref_1 : memref<1x32xf32>)
    gpu.dealloc %memref : memref<1x32xf16>
    // NOTE(review): %memref_1 is returned to the caller and intentionally not
    // deallocated here.
    return %memref_1 : memref<1x32xf32>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
    gpu.func @test_kernel(%arg0: memref<1x32xf16>, %arg1: memref<1x32xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      // Two 1x16 tile descriptors covering columns [0,16) and [16,32).
      %src_tdesc_0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16>
      %src_tdesc_1 = xegpu.create_nd_tdesc %arg0[0, 16] : memref<1x32xf16> -> !xegpu.tensor_desc<1x16xf16>

      // The 1x16xf16 loads are the case this test exercises.
      %src_loaded_0 = xegpu.load_nd %src_tdesc_0 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16>
      %src_loaded_1 = xegpu.load_nd %src_tdesc_1 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16>

      // Widen f16 -> f32 so the output can be compared against the input.
      %src_loaded_0_f32 = arith.extf %src_loaded_0: vector<1x16xf16> to vector<1x16xf32>
      %src_loaded_1_f32 = arith.extf %src_loaded_1: vector<1x16xf16> to vector<1x16xf32>

      %dest_tdesc_0 = xegpu.create_nd_tdesc %arg1[0, 0] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32>
      %dest_tdesc_1 = xegpu.create_nd_tdesc %arg1[0, 16] : memref<1x32xf32> -> !xegpu.tensor_desc<1x16xf32>

      xegpu.store_nd %src_loaded_0_f32, %dest_tdesc_0 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
      xegpu.store_nd %src_loaded_1_f32, %dest_tdesc_1 : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>

      gpu.return
    }
  }
  // Entry point: fill A with random values in [-2, 2], run the kernel, and
  // verify output ~= input via the allclose runner utility.
  func.func @main() attributes {llvm.emit_c_interface} {
    %A = memref.alloc() : memref<1x32xf16> // 1x32 to ensure surface pitch >= 64
    %A_random = memref.cast %A : memref<1x32xf16> to memref<*xf16>
    %c_gen_int = arith.constant 1 : i1
    %cf_lower = arith.constant -2.0 : f32
    %cf_upper = arith.constant 2.0 : f32
    call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()

    %B = call @test(%A) : (memref<1x32xf16>) -> memref<1x32xf32>
    %A_cast = memref.cast %A : memref<1x32xf16> to memref<*xf16>
    %B_cast = memref.cast %B : memref<1x32xf32> to memref<*xf32>
    // Debug aids, disabled by default:
    // call @printMemrefF16(%A_cast) : (memref<*xf16>) -> ()
    // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()

    // CHECK: [ALLCLOSE: TRUE]
    call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> ()
    return
  }
  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
  func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
}

0 commit comments

Comments
 (0)