@@ -524,12 +524,12 @@ class MMHelper {
524524
525525 // E4M3
526526 else if constexpr (std::is_same_v<WeiT, e4m3_t >) {
527- int amx_rows = ( int )((K + 15 ) / 16 ) * 16 ;
528- int amx_cols = ( int )((N + 63 ) / 64 ) * 64 ;
529- if (!weight.isShadow ()) weight.Resize (amx_rows, amx_cols );
530- memset (weight.Data (), 0 , sizeof (e4m3_t ) * amx_rows * amx_cols );
527+ int packBlkSize = 32 ;
528+ size_t pack_size = xdnn_small_amx_sgemm_bf16f8bf16_packb_size (K, N, packBlkSize) ;
529+ if (!weight.isShadow ()) weight.Resize ((pack_size + N - 1 ) / N, N );
530+ memset (weight.Data (), 0 , sizeof (e4m3_t ) * pack_size );
531531 xdnn_small_amx_sgemm_bf16f8bf16_packb (trans, N, K, (const XDNN_E4M3 *)src.Data (), src.Stride (),
532- (XDNN_E4M3 *)weight.Data (), 64 );
532+ (XDNN_E4M3 *)weight.Data (), packBlkSize );
533533 }
534534 }
535535
@@ -691,7 +691,7 @@ class MMHelper {
691691
692692 // E4M3
693693 else if constexpr (std::is_same_v<WeiT, e4m3_t >) {
694- if (M <= 16 ) {
694+ if (true ) {
695695 assert (blockSize == 128 );
696696 if (lds == -1 ) lds = (K + 127 ) / 128 ;
697697 GEMMVERBOSE (" xdnn_gemm_bf16f8bf16_compute" ,
@@ -1509,7 +1509,7 @@ class MMHelper {
15091509
15101510 // E4M3
15111511 else if constexpr (std::is_same_v<WeiT, e4m3_t >) {
1512- if (M <= 16 ) {
1512+ if (true ) {
15131513 assert (blockSize == 128 );
15141514 if (lds == -1 ) lds = (K + 127 ) / 128 ;
15151515 GEMMVERBOSE (" xdnn_gemm_bf16f8bf16_compute_residential" ,
0 commit comments