File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
csrc/gpu/fp8_gemm_with_cutlass Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -199,11 +199,11 @@ std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfBlockGemmFusedInferShape(
199
199
int rank = x_shape.size ();
200
200
int M = 0 ;
201
201
int N = 0 ;
202
- if (x_shape[rank - 1 ] != x_scale_shape[rank - 1 ] * 128 ){
202
+ if (( x_shape[rank - 1 ] + 127 ) / 128 != x_scale_shape[rank - 2 ] ){
203
203
PADDLE_THROW (phi::errors::Fatal (
204
- " cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-1 ] * 128 = x's dim[-1]." ));
204
+ " cutlass_fp8_fp8_half_block_gemm_fused only support x_scale's dim[-2 ] * 128 = x's dim[-1]." ));
205
205
}
206
- if ((y_shape[rank - 1 ] != y_scale_shape[rank - 1 ] * 128 ) || (y_shape[rank - 2 ] != y_scale_shape[rank - 2 ] * 128 )){
206
+ if ((( y_shape[rank - 1 ] + 127 ) / 128 != y_scale_shape[rank - 1 ]) || (( y_shape[rank - 2 ] + 127 ) / 128 != y_scale_shape[rank - 2 ])){
207
207
PADDLE_THROW (phi::errors::Fatal (
208
208
" cutlass_fp8_fp8_half_block_gemm_fused only support input y_scale's dim[-2:] * 128 = y's dim[-2:]." ));
209
209
}
You can’t perform that action at this time.
0 commit comments