Skip to content

Commit

Permalink
check block gemm infer shape
Browse files (browse the repository at this point in the history)
  • Loading branch information
ckl117 committed Feb 22, 2025
1 parent 6e7bc6e commit 72ce564
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_block_gemm.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -199,11 +199,11 @@ std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
int rank = x_shape.size();
int M = 0;
int N = 0;
if (x_shape[rank - 1] != x_scale_shape[rank - 1] * 128){
if ((x_shape[rank - 1] + 127) / 128 != x_scale_shape[rank - 2]){
PADDLE_THROW(phi::errors::Fatal(
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-1] * 128 = x's dim[-1]."));
"cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-2] * 128 = x's dim[-1]."));
}
if ((y_shape[rank - 1] != y_scale_shape[rank - 1] * 128) || (y_shape[rank - 2] != y_scale_shape[rank - 2] * 128)){
if (((y_shape[rank - 1] + 127) / 128 != y_scale_shape[rank - 1]) || ((y_shape[rank - 2] + 127) / 128 != y_scale_shape[rank - 2])){
PADDLE_THROW(phi::errors::Fatal(
"cutlass_fp8_fp8_half_block_gemm_fused only support input y_scale's dim[-2:] * 128 = y's dim[-2:]."));
}
Expand Down

0 comments on commit 72ce564

Please sign in to comment.