Skip to content

Commit 8ac2fb6

Browse files
authored
refine XeGPU definition (#849)
- add verification for create_tdesc regarding to the chunk size and total size - update load_gather and store_scatter definition to reveal the transpose effect
1 parent eb8c81a commit 8ac2fb6

File tree

7 files changed

+243
-16
lines changed

7 files changed

+243
-16
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
From 8a734652353bdd85b9cc7d2426e7395404372d72 Mon Sep 17 00:00:00 2001
2+
From: Chao Chen <[email protected]>
3+
Date: Wed, 28 Aug 2024 23:57:49 +0000
4+
Subject: [PATCH] refine the XeGPU definition - add verification for
5+
scattered tensordesc regarding to chunk size and total size - refine
6+
load_gather and store_scatter to reveal transpose effect
7+
8+
---
9+
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 40 +++++++++++------
10+
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 1 +
11+
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 44 ++++++++++++++++---
12+
3 files changed, 65 insertions(+), 20 deletions(-)
13+
14+
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
15+
index a3922bbad2b3..3e0c6f243fd4 100644
16+
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
17+
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
18+
@@ -413,24 +413,28 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
19+
implying each element in the array corresponds to a work-item (SIMT lane)
20+
in the subgroup.
21+
22+
+ The first dimension of the result TensorDesc corresponds to work-items, so it should
23+
+ match the dimension of offsets. It may also has a second dimension corresponding to
24+
+ the chunk_size if the chunk size is larger than 1.
25+
+
26+
Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
27+
```mlir
28+
%a = memref.alloc() : memref<1024xf32>
29+
- %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32, chunk_size_per_lane = 1>
30+
+ %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
31+
```
32+
33+
Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
34+
It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
35+
```mlir
36+
%0 = memref.alloc() : memref<1024xf32>
37+
- %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>
38+
+ %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
39+
```
40+
41+
Example 3. It is similar to Example 2, but there is some overlaps among workitems.
42+
It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
43+
```mlir
44+
%0 = memref.alloc() : memref<1024xf32>
45+
- %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>>
46+
+ %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>>
47+
```
48+
}];
49+
50+
@@ -500,28 +504,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
51+
52+
let description = [{ It (aka. load) load data per each work-item. The output
53+
describes the data being loaded at the subgroup level, so its size is
54+
- consistent with the number of work-items in a subgroup. When `chunk_size_per_lane`
55+
- attribute is larger than 1 in TensorDesc, the output vector will be 2D vector,
56+
- with dim-1 correspoding to the chunk size.
57+
+ consistent with the number of work-items in a subgroup. When the chunk size
58+
+ is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
59+
+ to work-items, and dim-0 corresponding to the chunk_size loaded by each work-item.
60+
+ Specially, there is a transpose effect on the result (as compared to the TensorDesc)
61+
+ due to the hardware implementation. Therefore, a transpose attribute is introduced
62+
+ on purpose, making sure users are aware of this implicit transformation.
63+
64+
The mask operand masks out memory access so that it is safe to pass out-of-boundary
65+
addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
66+
67+
Example:
68+
```mlir
69+
- %2 = xegpu.load %1, %0 {transpose = [1, 0],
70+
+ %2 = xegpu.load %1, %0 {transpose,
71+
l1_hint = #xegpu.cache_hint<cached>,
72+
l2_hint = #xegpu.cache_hint<uncached>,
73+
l3_hint = #xegpu.cache_hint<uncached>}
74+
- : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
75+
- -> vector<16xf32>
76+
+ : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_scope=global>>,
77+
+ vector<16xi1> -> vector<16xf32>
78+
```
79+
80+
}];
81+
82+
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
83+
XeGPU_MaskType: $mask,
84+
- OptionalAttr<DenseI64ArrayAttr>: $transpose,
85+
+ OptionalAttr<UnitAttr>: $transpose,
86+
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
87+
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
88+
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
89+
@@ -553,11 +560,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
90+
let hasVerifier = 1;
91+
}
92+
93+
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDesc"]>,
94+
- AllElementTypesMatch<["value", "TensorDesc"]>]> {
95+
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
96+
+ AllElementTypesMatch<["value", "TensorDesc"]>]> {
97+
let summary = "store data to scattered memory locations.";
98+
- let description = [{ It (aka. store) stores data to scattered memory locations.
99+
- It has similar semantic to `load_gather`.
100+
+ let description = [{ It (aka. store) stores data to scattered memory locations. The value is
101+
+ typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
102+
+ a 2D vector instead. For the later case, dim-1 of the value correspods to the simd lanes
103+
+ and the dim-0 of the value corresponds to the chunk_size stored per lane. So `store_scatter`
104+
+ has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
105+
+ introduced on purpose, making sure users are aware of this implicit transformation.
106+
107+
Example:
108+
```mlir
109+
@@ -572,6 +583,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
110+
XeGPU_ValueType: $value,
111+
XeGPU_TensorDesc: $TensorDesc,
112+
XeGPU_MaskType: $mask,
113+
+ OptionalAttr<UnitAttr>: $transpose,
114+
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
115+
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
116+
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
117+
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
118+
index 0eab601bbaac..555c232ff1f0 100644
119+
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
120+
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
121+
@@ -57,6 +57,7 @@ ScatterTensorDescAttr ScatterTensorDescAttr::get(mlir::MLIRContext *context,
122+
//===----------------------------------------------------------------------===//
123+
// XeGPU_TensorDescType
124+
//===----------------------------------------------------------------------===//
125+
+
126+
mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
127+
llvm::SmallVector<int64_t> shape;
128+
mlir::Type elementType;
129+
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
130+
index c9e399a7149f..b35a639540aa 100644
131+
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
132+
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
133+
@@ -305,6 +305,26 @@ LogicalResult CreateDescOp::verify() {
134+
135+
auto chunkSize = tdescTy.getChunkSize();
136+
137+
+ // check chunk_size
138+
+ llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, 16, 32, 64, 128, 256};
139+
+ if (!llvm::is_contained(supportedChunkSizes, chunkSize))
140+
+ return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, 8, 16, 32, 64, 128, or 256.");
141+
+
142+
+ // check total size
143+
+ auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
144+
+ auto bitsPerLane = elemBits * chunkSize;
145+
+ if (bitsPerLane % 32) {
146+
+ // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
147+
+ // For 32-bit data, the hardware can support larger larger chunk size. So
148+
+ // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
149+
+ // But this requires the total size is 32 bit aligned to make the optimization work.
150+
+ return emitOpError("access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
151+
+ }
152+
+
153+
+ auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
154+
+ if (elemBits * tdescTy.getNumElements() > lscConstraints)
155+
+ return emitOpError("total access size (simd_lanes * chunk_size * sizeof(elemTy)) is upto 512 bytes.");
156+
+
157+
SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
158+
if (chunkSize != 1)
159+
shape.push_back(chunkSize);
160+
@@ -370,14 +390,13 @@ LogicalResult LoadGatherOp::verify() {
161+
if (tdescShape[0] != maskShape[0])
162+
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
163+
164+
- if (getTransposeAttr()) {
165+
- auto trans = getTranspose().value();
166+
- if (tdescShape.size() < trans.size())
167+
- emitWarning("Invalid transpose attr. It is ignored.");
168+
- else
169+
- transpose(trans, tdescShape);
170+
+ if (tdescTy.getRank() == 2) {
171+
+ if (!getTransposeAttr())
172+
+ return emitOpError("load_gather has to be transposed.");
173+
+ transpose({1, 0}, tdescShape);
174+
}
175+
176+
+
177+
if (valueShape != tdescShape)
178+
return emitOpError("Unexpected result shape")
179+
<< "(Expected shape: " << makeString(tdescShape)
180+
@@ -404,11 +423,24 @@ LogicalResult StoreScatterOp::verify() {
181+
return emitOpError("invlid l3_hint: ") << getL3HintAttr();
182+
183+
auto maskTy = getMaskType();
184+
+ auto valueTy = getValueType();
185+
auto maskShape = getShapeOf(maskTy);
186+
auto tdescShape = getShapeOf(tdescTy);
187+
+ auto valueShape = getShapeOf(valueTy);
188+
if (tdescShape[0] != maskShape[0])
189+
return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
190+
191+
+ if (tdescTy.getRank() == 2) {
192+
+ if (!getTransposeAttr())
193+
+ return emitOpError("load_gather has to be transposed.");
194+
+ transpose({1, 0}, tdescShape);
195+
+ }
196+
+
197+
+ if (valueShape != tdescShape)
198+
+ return emitOpError("Unexpected value shape")
199+
+ << "(Expected shape: " << makeString(tdescShape)
200+
+ << ", Given shape: " << makeString(valueShape) << ").\n";
201+
+
202+
return success();
203+
}
204+
//===----------------------------------------------------------------------===//
205+
--
206+
2.34.1

test/Conversion/XeGPUToVC/loadgather.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ module @gemm attributes {gpu.container_module} {
3232

3333
// CHECK: %[[OLD:.*]] = arith.constant dense<0> : vector<16xi32>
3434
// CHECK: %[[LOAD_RES:.*]] = func.call @llvm.genx.raw.send2.v16i32.v16i1.v16i64({{.*}}, %[[MASK]], {{.*}}, %[[IN_PAYLOAD]], %[[OLD]]) : (i8, i8, vector<16xi1>, i8, i8, i8, i32, i32, vector<16xindex>, vector<16xi32>) -> vector<16xi32>
35-
%loaded = xegpu.load %tdesc_in, %mask : !xegpu.tensor_desc<16x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<16x2xf16>
35+
%loaded = xegpu.load %tdesc_in, %mask {transpose} : !xegpu.tensor_desc<16x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<2x16xf16>
3636
// CHECK: %[[POST_OP_ELEMENT_TYPE_CAST:.*]] = vector.bitcast %[[LOAD_RES]] : vector<16xi32> to vector<32xf16>
3737

3838
// CHECK: %[[PRE_OP_ELEMENT_TYPE_CAST:.*]] = vector.bitcast %[[POST_OP_ELEMENT_TYPE_CAST]] : vector<32xf16> to vector<16xi32>
3939
// CHECK: func.call @llvm.genx.raw.sends2.noresult.v16i1.v16i64.v16i32({{.*}}, %[[MASK]], {{.*}}, %[[OUT_PAYLOAD]], %[[PRE_OP_ELEMENT_TYPE_CAST]]) : (i8, i8, vector<16xi1>, i8, i8, i8, i32, i32, vector<16xindex>, vector<16xi32>) -> ()
40-
xegpu.store %loaded, %tdesc_out, %mask : vector<16x2xf16>, !xegpu.tensor_desc<16x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
40+
xegpu.store %loaded, %tdesc_out, %mask {transpose} : vector<2x16xf16>, !xegpu.tensor_desc<16x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
4141

4242
gpu.return
4343
}

test/Conversion/XeGPUToVC/loadgather_dpas.mlir

+5-8
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,8 @@ module @gemm attributes {gpu.container_module} {
2323

2424
// CHECK: %[[OLD:.*]] = arith.constant dense<0> : vector<64xi32>
2525
// CHECK: %[[LOAD_RES:.*]] = func.call @llvm.genx.raw.send2.v64i32.v16i1.v16i64({{.*}}, %[[MASK]], {{.*}}, %[[IN_PAYLOAD]], %[[OLD]]) : (i8, i8, vector<16xi1>, i8, i8, i8, i32, i32, vector<16xindex>, vector<64xi32>) -> vector<64xi32>
26-
%3 = xegpu.load %0, %mask : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<16x8xf16>
27-
2826
// CHECK: %[[LOADA_v128f16:.*]] = vector.bitcast %[[LOAD_RES]] : vector<64xi32> to vector<128xf16>
29-
%66 = vector.shape_cast %3: vector<16x8xf16> to vector<128xf16>
30-
%6 = vector.shape_cast %66: vector<128xf16> to vector<8x16xf16>
27+
%3 = xegpu.load %0, %mask {transpose} : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<8x16xf16>
3128

3229
// CHECK: %[[B_STRUCT:.*]]= arith.constant dense<0> : vector<4xi64>
3330
// CHECK: %[[B_BASEPTR:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<16x16xf16> -> index
@@ -46,7 +43,7 @@ module @gemm attributes {gpu.container_module} {
4643

4744
// CHECK: %[[LOADA_v64i32:.*]] = vector.bitcast %[[LOADA_v128f16]] : vector<128xf16> to vector<64xi32>
4845
// CHECK: %[[C_ACC_v128f32:.*]] = func.call @llvm.genx.dpas.nosrc0.v128f32.v128i32.v64i32(%{{.*}}, %[[LOADA_v64i32]], %{{.*}}) : (vector<128xi32>, vector<64xi32>, i32) -> vector<128xf32>
49-
%5 = xegpu.dpas %6, %4 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
46+
%5 = xegpu.dpas %3, %4 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
5047

5148
// CHECK: %[[OUT_OFFSET:.*]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
5249
%offsets2 = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
@@ -59,11 +56,11 @@ module @gemm attributes {gpu.container_module} {
5956
// CHECK: %[[OUT_ELEMENTWISE_OFFSET:.*]] = arith.muli %[[OUT_ELEMENT_BYTEWIDTH]], %[[OUT_OFFSET]] : vector<16xindex>
6057
// CHECK: %[[OUT_PAYLOAD:.*]] = arith.addi %[[OUT_PAYLOAD_BASEPTR_SHUFFLED]], %[[OUT_ELEMENTWISE_OFFSET]] : vector<16xindex>
6158
%2 = xegpu.create_tdesc %arg2, %offsets2 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
62-
%7 = vector.shape_cast %5: vector<8x16xf32> to vector<128xf32>
63-
%8 = vector.shape_cast %7: vector<128xf32> to vector<16x8xf32>
59+
// %7 = vector.shape_cast %5: vector<8x16xf32> to vector<128xf32>
60+
// %8 = vector.shape_cast %7: vector<128xf32> to vector<16x8xf32>
6461

6562
// CHECK: func.call @llvm.genx.raw.sends2.noresult.v16i1.v16i64.v128f32({{.*}}, %[[MASK]], {{.*}}, %[[OUT_PAYLOAD]], %[[C_ACC_v128f32]]) : (i8, i8, vector<16xi1>, i8, i8, i8, i32, i32, vector<16xindex>, vector<128xf32>) -> ()
66-
xegpu.store %8, %2, %mask : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
63+
xegpu.store %5, %2, %mask {transpose} : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>
6764

6865
gpu.return
6966
}

test/Dialect/XeGPU/IR/invalid_vc.mlir

+24
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,27 @@ func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) {
6767
: !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<8x8x4xf16>
6868
return
6969
}
70+
71+
// -----
72+
func.func @test_create_tdesc_oversized(%src: ui64, %offsets : vector<16xindex>) {
73+
// expected-error@+1 {{total access size (simd_lanes * chunk_size * sizeof(elemTy)) is upto 512 bytes}}
74+
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
75+
-> !xegpu.tensor_desc<16x16xf32, #xegpu.scatter_tdesc_attr<chunk_size = 16>>
76+
return
77+
}
78+
79+
// -----
80+
func.func @test_create_tdesc_invalid_chunk_size(%src: ui64, %offsets : vector<16xindex>) {
81+
// expected-error@+1 {{Invalid chunk_size. Supported values are 1, 2, 3, 4, 8, 16, 32, 64, 128, or 256.}}
82+
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
83+
-> !xegpu.tensor_desc<16x7xf32, #xegpu.scatter_tdesc_attr<chunk_size = 7>>
84+
return
85+
}
86+
87+
// -----
88+
func.func @test_create_tdesc_unaligned(%src: ui64, %offsets : vector<16xindex>) {
89+
// expected-error@+1 {{access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned}}
90+
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
91+
-> !xegpu.tensor_desc<16x3xf16, #xegpu.scatter_tdesc_attr<chunk_size = 3>>
92+
return
93+
}

test/Dialect/XeGPU/IR/load_gather_vc.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ func.func @test_load_gather_vc_2(%src: ui64, %offsets : vector<16xindex>) {
2828
%1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex>
2929
-> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
3030

31-
//CHECK: {{.*}} = xegpu.load {{.*}}, {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}>
31+
//CHECK: {{.*}} = xegpu.load {{.*}}, {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
3232
//CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1> -> vector<8x16xf32>
33-
%2 = xegpu.load %1, %0 {transpose = array<i64: 1, 0>, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
33+
%2 = xegpu.load %1, %0 {transpose, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}
3434
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1> -> vector<8x16xf32>
3535
return
3636
}

test/Integration/Dialect/XeGPU/loadgather_chunk_size_f32.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ module @gemm attributes {gpu.container_module} {
3636
%mask = arith.constant dense<[1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0]> : vector<16xi1>
3737
%tdesc_in = xegpu.create_tdesc %in, %offsets : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
3838
%tdesc_out = xegpu.create_tdesc %out, %offsets : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
39-
%loaded = xegpu.load %tdesc_in, %mask : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<16x2xf32>
40-
xegpu.store %loaded, %tdesc_out, %mask : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
39+
%loaded = xegpu.load %tdesc_in, %mask {transpose} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<2x16xf32>
40+
xegpu.store %loaded, %tdesc_out, %mask {transpose} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
4141
gpu.return
4242
}
4343
}

test/Integration/Dialect/XeGPU/loadgather_chunk_size_i32.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ module @gemm attributes {gpu.container_module} {
3636
%mask = arith.constant dense<[1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0]> : vector<16xi1>
3737
%tdesc_in = xegpu.create_tdesc %in, %offsets : memref<?xi32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
3838
%tdesc_out = xegpu.create_tdesc %out, %offsets : memref<?xi32>, vector<16xindex> -> !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
39-
%loaded = xegpu.load %tdesc_in, %mask : !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<16x2xi32>
40-
xegpu.store %loaded, %tdesc_out, %mask : vector<16x2xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
39+
%loaded = xegpu.load %tdesc_in, %mask {transpose} : !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1> -> vector<2x16xi32>
40+
xegpu.store %loaded, %tdesc_out, %mask {transpose} : vector<2x16xi32>, !xegpu.tensor_desc<16x2xi32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<16xi1>
4141
gpu.return
4242
}
4343
}

0 commit comments

Comments
 (0)