@@ -21,9 +21,9 @@
 
 
 def get_context_mla_test_cases():
-    dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # FP8 not supported for trt < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     s_list = [
         16,
@@ -47,7 +47,7 @@ def get_context_mla_test_cases():
         for b in b_list:
             for s in s_list:  # [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072]:
                 for dtype in dtype_list:
-                    for tp_size in [1, 2, 4, 8, 16, 32, 64]:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                         if b * s > 32768:
                             continue
                         # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
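Note: the `if b * s > 32768: continue` guard bounds the total number of context tokens per case so the sweep stays tractable. A minimal standalone sketch of that pruning, using hypothetical small subsets of the real lists:

```python
# Sketch: keep only (batch, seq) pairs whose total context tokens fit the cap.
CAP = 32768
b_list = [1, 64, 256]        # hypothetical subset of the real b_list
s_list = [16, 1024, 131072]  # hypothetical subset of the real s_list

kept = [(b, s) for b in b_list for s in s_list if b * s <= CAP]
print(kept)  # [(1, 16), (1, 1024), (64, 16), (256, 16)]
# e.g. (64, 1024) is pruned: 64 * 1024 = 65536 > 32768
```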
@@ -72,9 +72,9 @@ def get_context_mla_test_cases():
 
 
 def get_generation_mla_test_cases():
-    dtype_list = [tensorrt_llm.bindings.DataType.FP8, tensorrt_llm.bindings.DataType.BF16]
+    dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # FP8 not supported for trt < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     for n in n_list:
         for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
             for s in [
@@ -97,7 +97,7 @@ def get_generation_mla_test_cases():
                 131072,
             ]:  # [target token s] is equivalent to [in: s-1, step=1]
                 for dtype in dtype_list:
-                    for tp_size in [1, 2, 4, 8, 16, 32, 64]:
+                    for tp_size in [1, 2, 4, 8, 16, 32, 64, 128]:
                         if b * s > 1024 * 4096 * 2 * 2:
                             continue
                         # (input_len, batch_size, output_len, kv_cache_dtype, num_heads, world_size,
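Note: the generation-phase cap is much looser than the context one: `1024 * 4096 * 2 * 2` is 16,777,216 total target tokens. A standalone check of which extreme points of the sweep survive:

```python
# Sketch: the generation sweep skips cases whose b * s exceeds the cap.
cap = 1024 * 4096 * 2 * 2  # 16,777,216 target tokens
for b, s in [(1, 131072), (256, 65536), (1024, 131072)]:
    status = "kept" if b * s <= cap else "skipped"
    print(f"b={b}, s={s}: {status}")
# b=1 is kept; 256 * 65536 == cap, so it is kept; 1024 * 131072 is skipped.
```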
@@ -140,13 +140,11 @@ def run_mla(
     torch.cuda.set_device(device)
     backend_name = "TRTLLM"
     layer_idx = 0
-    num_key_value_heads = num_heads
-
-    assert num_key_value_heads % tp_size == 0, "num_key_value_heads != N * tp_size"
-    num_key_value_heads = int(num_key_value_heads / tp_size)
 
+    assert kv_cache_dtype == tensorrt_llm.bindings.DataType.BF16, "only bfloat16 is supported for trtllm"
     assert num_heads % tp_size == 0, "num_heads != N * tp_size"
-    num_heads = int(num_heads / tp_size)
+    num_heads = num_heads // tp_size
+    num_key_value_heads = num_heads
 
     pos_embd_params = PositionalEmbeddingParams(
         type=PositionEmbeddingType.yarn,
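Note: `num_heads // tp_size` is the idiomatic replacement for `int(num_heads / tp_size)`. With the divisibility assert above both forms agree here, but floor division avoids the float round-trip, which can round incorrectly once values exceed 2**53. A standalone illustration:

```python
# Sketch: floor division is exact for ints; int(x / y) goes through a float.
num_heads, tp_size = 128, 8
assert num_heads % tp_size == 0
assert num_heads // tp_size == int(num_heads / tp_size) == 16  # identical here

n = 10**17 + 1     # not exactly representable as a double
print(int(n / 1))  # 100000000000000000 -- rounded by float division
print(n // 1)      # 100000000000000001 -- exact
```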
@@ -170,7 +168,6 @@ def run_mla(
     )
 
     quant_config = QuantConfig(
-        quant_algo="FP8_BLOCK_SCALES",
         kv_cache_quant_algo=None,
         group_size=None,
         smoothquant_val=0.5,
@@ -391,15 +388,11 @@ def run_mla(
         isl = 1
         step = input_len
 
-    dtype_str = "float16"
-    if kv_cache_dtype == tensorrt_llm.bindings.DataType.FP8:
-        dtype_str = "fp8"
-
     log_perf(
         item_list=[
             {
                 "mla_dtype": "float16",
-                "kv_cache_dtype": dtype_str,
+                "kv_cache_dtype": "float16",
                 "num_heads": num_heads,
                 "batch_size": batch_size,
                 "isl": isl,
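Note: with the FP8 branch gone, every record reports "float16" for both dtype fields. Below is roughly the shape of one logged item after this hunk; the field names come from the diff, but the values are illustrative and `log_perf`'s full signature is not shown here:

```python
# Hypothetical perf record produced by the log_perf call above.
item = {
    "mla_dtype": "float16",
    "kv_cache_dtype": "float16",  # fixed value now that dtype_str is removed
    "num_heads": 128,             # illustrative
    "batch_size": 16,             # illustrative
    "isl": 1024,                  # illustrative
}
```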