Skip to content

Commit 296d8d0

Browse files
Change DPAS from OCL to SPV
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 465b0b6 commit 296d8d0

File tree

7 files changed

+126
-86
lines changed

7 files changed

+126
-86
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

+4-4
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
818818
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
819819
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #dpas0>
820820

821-
// CHECK-COUNT-2: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
821+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
822822
%D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #dpas0>
823823

824824
tt.return
@@ -968,7 +968,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
968968
%a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
969969
%b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>
970970

971-
// CHECK-COUNT-128: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
971+
// CHECK-COUNT-128: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
972972
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #dpas>
973973
%38 = ttg.convert_layout %28 : tensor<128x256xf32, #dpas> -> tensor<128x256xf32, #blocked>
974974

@@ -995,7 +995,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
995995
%a_mat = ttg.local_load %a : !ttg.memdesc<32x64xf16, #shared0, #smem> -> tensor<32x64xf16, #dot_operand_a>
996996
%b_mat = ttg.local_load %b : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #dot_operand_b>
997997

998-
// CHECK-COUNT-16: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
998+
// CHECK-COUNT-16: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
999999
%28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #dpas>
10001000
%38 = ttg.convert_layout %28 : tensor<32x64xf32, #dpas> -> tensor<32x64xf32, #blocked>
10011001
%30 = tt.splat %ptr : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1044,7 +1044,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
10441044
%a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
10451045
%b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
10461046

1047-
// CHECK-COUNT-2: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv8_fS_S_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
1047+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
10481048
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
10491049
%38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
10501050

0 commit comments

Comments
 (0)