Skip to content

Commit 72ce564

Browse files
committed
check block gemm infer shape
1 parent 6e7bc6e commit 72ce564

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -199,11 +199,11 @@ std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
199199
int rank = x_shape.size();
200200
int M = 0;
201201
int N = 0;
202-
if (x_shape[rank - 1] != x_scale_shape[rank - 1] * 128){
202+
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]){
203203
PADDLE_THROW(phi::errors::Fatal(
204-
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-1] * 128 = x's dim[-1]."));
204+
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-2] * 128 = x's dim[-1]."));
205205
}
206-
if ((y_shape[rank - 1] != y_scale_shape[rank - 1] * 128) || (y_shape[rank - 2] != y_scale_shape[rank - 2] * 128)){
206+
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) || ((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])){
207207
PADDLE_THROW(phi::errors::Fatal(
208208
"cutlass_fp8_fp8_half_block_gemm_fused only support input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
209209
}

0 commit comments

Comments
 (0)