@@ -818,7 +818,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #dpas0>

-    // CHECK-COUNT-2: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+    // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
     %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #dpas0>

     tt.return
@@ -968,7 +968,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %a_mat = ttg.local_load %a : !ttg.memdesc<128x32xf16, #shared, #smem> -> tensor<128x32xf16, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<32x256xf16, #shared, #smem> -> tensor<32x256xf16, #dot_operand_b>

-    // CHECK-COUNT-128: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+    // CHECK-COUNT-128: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
     %28 = tt.dot %a_mat, %b_mat, %cst : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #dpas>
     %38 = ttg.convert_layout %28 : tensor<128x256xf32, #dpas> -> tensor<128x256xf32, #blocked>

@@ -995,7 +995,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %a_mat = ttg.local_load %a : !ttg.memdesc<32x64xf16, #shared0, #smem> -> tensor<32x64xf16, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<64x64xf16, #shared1, #smem> -> tensor<64x64xf16, #dot_operand_b>

-    // CHECK-COUNT-16: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+    // CHECK-COUNT-16: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32>
     %28 = tt.dot %a_mat, %b_mat, %cst : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #dpas>
     %38 = ttg.convert_layout %28 : tensor<32x64xf32, #dpas> -> tensor<32x64xf32, #blocked>
     %30 = tt.splat %ptr : !tt.ptr<f32> -> tensor<32x1x!tt.ptr<f32>, #blocked>
@@ -1044,7 +1044,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
     %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>

-    // CHECK-COUNT-2: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv8_fS_S_(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+    // CHECK-COUNT-2: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_fS_S_i(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
     %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
     %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>

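Note on reading the new CHECK patterns: the builtin names are mangled per the Itanium C++ ABI, so the expected signature can be read directly off the symbol. Below is a minimal decoding sketch; the parameter names (kDim, a, b, acc, operands) are hypothetical, chosen only for illustration, and the authoritative prototype is the one in the SPV_INTEL_subgroup_matrix_multiply_accumulate extension, not this reconstruction.

    /* _Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi
     *   _Z45  -> identifier of 45 characters follows
     *   i     -> int               (K dimension)
     *   Dv8_s -> vector of 8 short (A operand, packed f16)
     *   Dv8_i -> vector of 8 int   (B operand, packed f16 pairs)
     *   Dv8_f -> vector of 8 float (C accumulator)
     *   i     -> int               (matrix-operands flags)
     * Sketch of the demangled form, using OpenCL C vector types: */
    float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(
        int kDim, short8 a, int8 b, float8 acc, int operands);

In the tf32 hunk, S_ is an Itanium substitution referring back to Dv8_f, so iDv8_fS_S_i decodes to (int, float8, float8, float8, int), matching the (i32, vector<8xf32>, vector<8xf32>, vector<8xf32>, i32) type list in that CHECK line.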