
Commit 7348112

drprajap and mshahneo authored
[test] Fix test cases with unsupported shaped buffer sizes (#847)
[test] Fix test cases that do xegpu.load_nd on an unsupported shaped buffer. In several test cases, we use load_nd on buffers with unsupported shapes; this PR fixes one of those test cases.

Co-authored-by: Md Abdullah Shahneous Bari <[email protected]>
1 parent 390bfb0 commit 7348112

13 files changed, +510 -501 lines changed
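The fix follows one pattern across these files: tests that issued xegpu.load_nd on an f16 buffer and then widened the result now allocate the buffer as f32 and load it directly, keeping the descriptor at a shape the lowering supports. A minimal before/after sketch, condensed from the diffs below (the function names are illustrative, not from the patch):

```mlir
// Before (condensed): load an f16 tile, then widen to f32 via shape casts.
func.func @load_f16_then_widen(%src : memref<8x16xf16>) -> vector<8x16xf32> {
  %td = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
  %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
  %flat = vector.shape_cast %v : vector<8x16xf16> to vector<128xf16>
  %wide = arith.extf %flat : vector<128xf16> to vector<128xf32>
  %out = vector.shape_cast %wide : vector<128xf32> to vector<8x16xf32>
  return %out : vector<8x16xf32>
}

// After (condensed): the buffer is f32, so the tile loads at the shape
// the rest of the kernel consumes, with no widening step.
func.func @load_f32_direct(%src : memref<8x16xf32>) -> vector<8x16xf32> {
  %td = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
  %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  return %v : vector<8x16xf32>
}
```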

include/imex/ExecutionEngine/ImexRunnerUtils.h (+5)

@@ -72,6 +72,11 @@ _mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType<f16> *ptr,
                                      const float lower, const float upper,
                                      const bool genInt);
 
+extern "C" IMEX_RUNNERUTILS_EXPORT void
+_mlir_ciface_fillResource1DRandomF32(UnrankedMemRefType<float> *ptr,
+                                     const float lower, const float upper,
+                                     const bool genInt);
+
 extern "C" IMEX_RUNNERUTILS_EXPORT void
 _mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
 extern "C" IMEX_RUNNERUTILS_EXPORT void

test/Integration/Dialect/XeGPU/dynamic_memref.vc.mlir (+19 -25)

@@ -7,55 +7,49 @@
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
 module @gemm attributes {gpu.container_module} {
-  func.func @test(%A : memref<8x16xf16>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
+  func.func @test(%A : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
     %c1 = arith.constant 1 : index
-    %memref_0 = gpu.alloc host_shared () : memref<8x16xf16>
-    memref.copy %A, %memref_0 : memref<8x16xf16> to memref<8x16xf16>
+    %memref_0 = gpu.alloc host_shared () : memref<8x16xf32>
+    memref.copy %A, %memref_0 : memref<8x16xf32> to memref<8x16xf32>
     %memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
-    %memref_0_cast = memref.cast %memref_0 : memref<8x16xf16> to memref<?x?xf16>
+    %memref_0_cast = memref.cast %memref_0 : memref<8x16xf32> to memref<?x?xf32>
     %memref_1_cast = memref.cast %memref_1 : memref<8x16xf32> to memref<?x?xf32>
-    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref<?x?xf16>, %memref_1_cast : memref<?x?xf32>)
-    gpu.dealloc %memref_0 : memref<8x16xf16>
+    gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref_0_cast : memref<?x?xf32>, %memref_1_cast : memref<?x?xf32>)
+    gpu.dealloc %memref_0 : memref<8x16xf32>
     return %memref_1 : memref<8x16xf32>
   }
   gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
-    gpu.func @test_kernel(%arg0 : memref<?x?xf16>, %arg1: memref<?x?xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    gpu.func @test_kernel(%arg0 : memref<?x?xf32>, %arg1: memref<?x?xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
       %c1 = arith.constant 1 : index
       %c8 = arith.constant 8 : index
       %c16 = arith.constant 16 : index
-      %1 = xegpu.create_nd_tdesc %arg0[0, 0], [%c8, %c16], [%c16, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
-      %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-      %3 = vector.shape_cast %2 : vector<8x16xf16> to vector<128xf16>
-      %5 = arith.extf %3 : vector<128xf16> to vector<128xf32>
-      %4 = vector.shape_cast %5 : vector<128xf32> to vector<8x16xf32>
+      %1 = xegpu.create_nd_tdesc %arg0[0, 0], [%c8, %c16], [%c16, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %2 = xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
       %6 = xegpu.create_nd_tdesc %arg1[0, 0], [%c8, %c16], [%c16, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-      xegpu.store_nd %4, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+      xegpu.store_nd %2, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
       gpu.return
     }
   }
   func.func @main() attributes {llvm.emit_c_interface} {
-    %A = memref.alloc() : memref<8x16xf16>
-    %A_random = memref.cast %A : memref<8x16xf16> to memref<*xf16>
+    %A = memref.alloc() : memref<8x16xf32>
+    %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32>
     %c_gen_int = arith.constant 0 : i1
     %cf_lower = arith.constant -0.5 : f32
     %cf_upper = arith.constant 0.5 : f32
 
-    call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
+    call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()
 
-    %B = call @test(%A) : (memref<8x16xf16>) -> memref<8x16xf32>
+    %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32>
     %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32>
-    %A_cast = memref.cast %A : memref<8x16xf16> to memref<*xf16>
-    // call @printMemrefF16(%A_cast) : (memref<*xf16>) -> ()
+    %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32>
     // call @printMemrefF32(%B_cast) : (memref<*xf32>) -> ()
     // CHECK: [ALLCLOSE: TRUE]
-    call @printAllcloseF16(%A_cast, %B_cast) : (memref<*xf16>, memref<*xf32>) -> ()
+    call @printAllcloseF32(%A_cast, %B_cast) : (memref<*xf32>, memref<*xf32>) -> ()
 
-    memref.dealloc %A : memref<8x16xf16>
+    memref.dealloc %A : memref<8x16xf32>
     return
   }
-  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
   func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
-  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
-  func.func private @fillResource1DF16(memref<*xf16>, f32) attributes {llvm.emit_c_interface}
-  func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
 }
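The part of this test that survives the element-type change is the dynamic-shape path: once the buffer is cast to memref<?x?xf32>, create_nd_tdesc can no longer read the extents from the type, so the shape and strides become explicit operands. A standalone sketch of that idiom, assuming the same 8x16 row-major layout as the test (the @dyn_tdesc name is illustrative):

```mlir
// The ?x? memref type carries no extents, so create_nd_tdesc takes the
// shape [%h, %w] and row-major strides [%w, %c1] as explicit operands.
func.func @dyn_tdesc(%src : memref<8x16xf32>) -> vector<8x16xf32> {
  %c1 = arith.constant 1 : index
  %h = arith.constant 8 : index
  %w = arith.constant 16 : index
  %dyn = memref.cast %src : memref<8x16xf32> to memref<?x?xf32>
  %td = xegpu.create_nd_tdesc %dyn[0, 0], [%h, %w], [%w, %c1]
      : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
  %v = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
  return %v : vector<8x16xf32>
}
```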

test/Integration/Dialect/XeGPU/exp_f32.vc.mlir (+32 -57)

@@ -7,63 +7,51 @@
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
 module @gemm attributes {gpu.container_module} {
-  func.func @test(%A: memref<8x16xf16>, %B: memref<16x16xf16> ) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} {
+  func.func @test(%A: memref<8x16xf32>) -> (memref<8x16xf32>, memref<8x16xf32>) attributes {llvm.emit_c_interface} {
     %c1 = arith.constant 1 : index
-    %memref = gpu.alloc host_shared () : memref<8x16xf16>
-    %memref_1 = gpu.alloc host_shared () : memref<16x16xf16>
-    memref.copy %A, %memref : memref<8x16xf16> to memref<8x16xf16>
-    memref.copy %B, %memref_1 : memref<16x16xf16> to memref<16x16xf16>
+    %memref = gpu.alloc host_shared () : memref<8x16xf32>
+    memref.copy %A, %memref : memref<8x16xf32> to memref<8x16xf32>
+
     %memref_2 = gpu.alloc host_shared () : memref<8x16xf32>
     %memref_3 = gpu.alloc host_shared () : memref<8x16xf32>
-    gpu.launch_func @module0::@test_exp_larger_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<16x16xf16>, %memref_2 : memref<8x16xf32>)
-    gpu.launch_func @module1::@test_exp_generic_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf16>, %memref_1 : memref<16x16xf16>, %memref_3 : memref<8x16xf32>)
-    gpu.dealloc %memref : memref<8x16xf16>
-    gpu.dealloc %memref_1 : memref<16x16xf16>
+    gpu.launch_func @module0::@test_exp_larger_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_2 : memref<8x16xf32>)
+    gpu.launch_func @module1::@test_exp_generic_vec blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_3 : memref<8x16xf32>)
+    gpu.dealloc %memref : memref<8x16xf32>
     return %memref_2, %memref_3 : memref<8x16xf32>, memref<8x16xf32>
   }
 
   gpu.module @module0 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
-    gpu.func @test_exp_larger_vec(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    gpu.func @test_exp_larger_vec(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
       %c0 = arith.constant 0 : index
      %c16 = arith.constant 16 : index
      // load A tile
-      %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-      %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-      // load B tile
-      %b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-      %val2 = xegpu.load_nd %b_tile0 { packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
-      // do DPAS
-      %val4 = xegpu.dpas %val0, %val2 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
+      %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
      // take exp
-      %t6 = math.exp %val4 : vector<8x16xf32>
+      %t6 = math.exp %val0 : vector<8x16xf32>
      // store
      %out_tile = xegpu.create_nd_tdesc %Out [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
      xegpu.store_nd %t6, %out_tile : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
      gpu.return
    }
  }
  gpu.module @module1 attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
-    gpu.func @test_exp_generic_vec(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    gpu.func @test_exp_generic_vec(%A: memref<8x16xf32>, %Out: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %c0 = arith.constant 0 : index
      %c16 = arith.constant 16 : index
      // load A tile
-      %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-      %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-      // load B tile
-      %b_tile0 = xegpu.create_nd_tdesc %B [%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-      %val2 = xegpu.load_nd %b_tile0 {packed} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
-      // do DPAS
-      %val4 = xegpu.dpas %val0, %val2 : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
-      // extract dpas out into 16xf32 vectors
-      %cst1 = arith.constant dense<1.4426950408889634> : vector<128xf32>
-      %v0 = vector.extract %val4[0] : vector<16xf32> from vector<8x16xf32>
-      %v1 = vector.extract %val4[1] : vector<16xf32> from vector<8x16xf32>
-      %v2 = vector.extract %val4[2] : vector<16xf32> from vector<8x16xf32>
-      %v3 = vector.extract %val4[3] : vector<16xf32> from vector<8x16xf32>
-      %v4 = vector.extract %val4[4] : vector<16xf32> from vector<8x16xf32>
-      %v5 = vector.extract %val4[5] : vector<16xf32> from vector<8x16xf32>
-      %v6 = vector.extract %val4[6] : vector<16xf32> from vector<8x16xf32>
-      %v7 = vector.extract %val4[7] : vector<16xf32> from vector<8x16xf32>
+      %a_tile0 = xegpu.create_nd_tdesc %A [%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+      %val0 = xegpu.load_nd %a_tile0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+
+      // extract the loaded vector into 16xf32 vectors
+      %v0 = vector.extract %val0[0] : vector<16xf32> from vector<8x16xf32>
+      %v1 = vector.extract %val0[1] : vector<16xf32> from vector<8x16xf32>
+      %v2 = vector.extract %val0[2] : vector<16xf32> from vector<8x16xf32>
+      %v3 = vector.extract %val0[3] : vector<16xf32> from vector<8x16xf32>
+      %v4 = vector.extract %val0[4] : vector<16xf32> from vector<8x16xf32>
+      %v5 = vector.extract %val0[5] : vector<16xf32> from vector<8x16xf32>
+      %v6 = vector.extract %val0[6] : vector<16xf32> from vector<8x16xf32>
+      %v7 = vector.extract %val0[7] : vector<16xf32> from vector<8x16xf32>
      // do generic size exp
      %v0_exp = math.exp %v0 : vector<16xf32>
      %v1_exp = math.exp %v1 : vector<16xf32>
@@ -104,31 +92,19 @@ module @gemm attributes {gpu.container_module} {
     %rand_lower = arith.constant -1.0 : f32
     %rand_upper = arith.constant 1.0 : f32
     %gen_int = arith.constant 0 : i1
-    %A = memref.alloc() : memref<8x16xf16>
-    %B = memref.alloc() : memref<16x16xf16>
+    %A = memref.alloc() : memref<8x16xf32>
     %Out_cpu = memref.alloc() : memref<8x16xf32>
-    %A_random = memref.cast %A : memref<8x16xf16> to memref<*xf16>
-    %B_random = memref.cast %B : memref<16x16xf16> to memref<*xf16>
-    call @fillResource1DRandomF16(%A_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
-    call @fillResource1DRandomF16(%B_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
+    %A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32>
+    call @fillResource1DRandomF32(%A_random, %rand_lower, %rand_upper, %gen_int) : (memref<*xf32>, f32, f32, i1) -> ()
     // run GPU version
-    %Out_gpu_large, %Out_gpu_generic = call @test(%A, %B) : (memref<8x16xf16>, memref<16x16xf16>) -> (memref<8x16xf32>, memref<8x16xf32>)
+    %Out_gpu_large, %Out_gpu_generic = call @test(%A) : (memref<8x16xf32>) -> (memref<8x16xf32>, memref<8x16xf32>)
     %Out_gpu_generic_cast = memref.cast %Out_gpu_generic : memref<8x16xf32> to memref<*xf32>
     %Out_gpu_large_cast = memref.cast %Out_gpu_large : memref<8x16xf32> to memref<*xf32>
     // run CPU version
     scf.for %i = %c0 to %c8 step %c1 {
       scf.for %j = %c0 to %c16 step %c1 {
-        %v0_init = arith.constant 0.0 : f32
-        %result:1 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init) -> f32 {
-          %a0 = memref.load %A[%i, %k] : memref<8x16xf16>
-          %b0 = memref.load %B[%k, %j] : memref<16x16xf16>
-          %a0_f32 = arith.extf %a0 : f16 to f32
-          %b0_f32 = arith.extf %b0 : f16 to f32
-          %t0 = arith.mulf %a0_f32, %b0_f32 : f32
-          %v0_new = arith.addf %v0, %t0 : f32
-          scf.yield %v0_new : f32
-        }
-        %vexp = math.exp %result#0 : f32
+        %a0 = memref.load %A[%i, %j] : memref<8x16xf32>
+        %vexp = math.exp %a0 : f32
         memref.store %vexp, %Out_cpu[%i, %j] : memref<8x16xf32>
       }
     }
@@ -141,15 +117,14 @@ module @gemm attributes {gpu.container_module} {
     call @printAllcloseF32(%Out_gpu_generic_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
     call @printAllcloseF32(%Out_gpu_large_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
     // dealloc
-    memref.dealloc %A : memref<8x16xf16>
-    memref.dealloc %B : memref<16x16xf16>
+    memref.dealloc %A : memref<8x16xf32>
     memref.dealloc %Out_cpu : memref<8x16xf32>
     // gpu dealloc
     gpu.dealloc %Out_gpu_generic : memref<8x16xf32>
     gpu.dealloc %Out_gpu_large : memref<8x16xf32>
     return
   }
   func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
-  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
+  func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
   func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
 }
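With the DPAS gone from both kernels, the CPU reference in @main no longer needs its inner reduction loop: the expected output is an elementwise exp of the input. The updated loop, restated as a standalone function (hypothetical name, same structure as the diff above):

```mlir
// CPU reference after the change: the inner k-loop (dot product) is gone;
// each output element is just exp of the corresponding input element.
func.func @cpu_reference(%A : memref<8x16xf32>, %Out : memref<8x16xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  scf.for %i = %c0 to %c8 step %c1 {
    scf.for %j = %c0 to %c16 step %c1 {
      %a = memref.load %A[%i, %j] : memref<8x16xf32>
      %e = math.exp %a : f32
      memref.store %e, %Out[%i, %j] : memref<8x16xf32>
    }
  }
  return
}
```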

test/Integration/Dialect/XeGPU/fmax_f32.vc.mlir (+4 -3)

@@ -50,6 +50,7 @@ module @gemm attributes {gpu.container_module} {
     %c1 = arith.constant 1 : index
     %c8 = arith.constant 8 : index
     %c16 = arith.constant 16 : index
+
     %A = memref.alloc() : memref<8x32xf16>
     %B = memref.alloc() : memref<16x32xf16>
     %Out_cpu = memref.alloc() : memref<8x16xf32>
@@ -72,9 +73,9 @@ module @gemm attributes {gpu.container_module} {
         %v0_init = arith.constant 0.0 : f32
         %v1_init = arith.constant 0.0 : f32
         %result:2 = scf.for %k = %c0 to %c16 step %c1 iter_args(%v0 = %v0_init, %v1 = %v1_init) -> (f32, f32) {
-          %a0 = memref.load %A[%i, %k] : memref<8x32xf16>
           %1 = arith.addi %k, %c16 : index
           %2 = arith.addi %j, %c16 : index
+          %a0 = memref.load %A[%i, %k] : memref<8x32xf16>
           %a1 = memref.load %A[%i, %1] : memref<8x32xf16>
           %b0 = memref.load %B[%k, %j] : memref<16x32xf16>
           %b1 = memref.load %B[%k, %2] : memref<16x32xf16>
@@ -94,8 +95,8 @@ module @gemm attributes {gpu.container_module} {
     }
     %Out_cpu_cast = memref.cast %Out_cpu : memref<8x16xf32> to memref<*xf32>
     // print GPU and CPU outs
-    // call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> ()
-    // call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> ()
+    call @printMemrefF32(%Out_cpu_cast) : (memref<*xf32>) -> ()
+    call @printMemrefF32(%Out_gpu_cast) : (memref<*xf32>) -> ()
     // CHECK: [ALLCLOSE: TRUE]
     call @printAllcloseF32(%Out_gpu_cast, %Out_cpu_cast) : (memref<*xf32>, memref<*xf32>) -> ()
     // dealloc