Commit d11c0cf

feat: support gptoss in aiconfigurator sdk (#56)
* rebase gpt oss related changes
* revert change to b_list
* fallback to sol estimation when moe_dict is none
* update sol calculation of context attention
* update test to support variable head_size and window_size
* update context attention test case to query specific head_size and window_size
* update gen attention test case
* update load attention data tests
* update edge test case
* update test_correct_generation_attention_data
1 parent e46d908 commit d11c0cf

11 files changed: +220 −120 lines


collector/trtllm/collect_attn.py

Lines changed: 7 additions & 2 deletions

@@ -272,11 +272,14 @@ def get_context_attention_test_cases():
         #print(f'collecting heads: {n} kv_heads: {num_kv_heads} seq: {s} batchsize: {b}')
         # use fp8 kv cache, fp8 context fmha, is_context_phase. in torch flow, int8 kvcache is not supported yet.
         # fp16 kv cache, fp16 context fmha, is_context_phase
-        if head_dim == 64:
+        if h == 64:
             test_cases.append([b, s, n, num_kv_heads, h, 128, False, False, True, 'context_attention_perf.txt'])
+            test_cases.append([b, s, n, num_kv_heads, h, 0, False, False, True, 'context_attention_perf.txt'])
             if has_fp8:
                 test_cases.append([b, s, n, num_kv_heads, h, 128, True, False, True, 'context_attention_perf.txt'])
                 test_cases.append([b, s, n, num_kv_heads, h, 128, True, True, True, 'context_attention_perf.txt'])
+                test_cases.append([b, s, n, num_kv_heads, h, 0, True, False, True, 'context_attention_perf.txt'])
+                test_cases.append([b, s, n, num_kv_heads, h, 0, True, True, True, 'context_attention_perf.txt'])
         else:
             test_cases.append([b, s, n, num_kv_heads, h, 0, False, False, True, 'context_attention_perf.txt'])
             if has_fp8:
@@ -375,10 +378,12 @@ def get_generation_attention_test_cases():
         maxNumHeadsQPerKvInCta = 32
         if mNumHeadsQPerKv >= maxNumHeadsQPerKvInCta and mNumHeadsQPerKv % maxNumHeadsQPerKvInCta != 0:
             continue
-        if head_dim == 64:
+        if h == 64:
             test_cases.append([b, s, n, n_kv, h, 128, False, False, False, 'generation_attention_perf.txt'])
+            test_cases.append([b, s, n, n_kv, h, 0, False, False, False, 'generation_attention_perf.txt'])
            if has_fp8:
                 test_cases.append([b, s, n, n_kv, h, 128, True, False, False, 'generation_attention_perf.txt'])
+                test_cases.append([b, s, n, n_kv, h, 0, True, False, False, 'generation_attention_perf.txt'])
             # currently, fp8 is not for generation compute
             #test_cases.append([b, s, n, n_kv, 128, True, True, False, 'generation_attention_perf.txt'])
        else:
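
Reading these call sites, the sixth field of each test case appears to be the attention window size (128 for the sliding-window layers gpt-oss uses, 0 for full attention), with `h` now carrying the head size, so the head_size == 64 branch sweeps both attention variants. A minimal sketch of that shape under those assumptions; `sliding_window_cases` is a hypothetical helper, not part of the collector:

```python
# Hypothetical helper sketching the shape of the new head_size == 64 branch.
# Assumed field order: [batch, seq, n_heads, n_kv_heads, head_size,
#                       window_size, fp8_kv, fp8_fmha, is_context, out_file].
def sliding_window_cases(b, s, n, n_kv, head_size, has_fp8, out_file):
    cases = []
    for window in (128, 0):  # 128-token sliding window and full attention
        cases.append([b, s, n, n_kv, head_size, window, False, False, True, out_file])
        if has_fp8:
            cases.append([b, s, n, n_kv, head_size, window, True, False, True, out_file])
    return cases

print(len(sliding_window_cases(1, 2048, 64, 8, 64, True, 'context_attention_perf.txt')))  # 4
```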

collector/trtllm/collect_moe.py

Lines changed: 4 additions & 0 deletions

@@ -242,6 +242,10 @@ def run_moe_torch(moe_type, num_tokens_lists, hidden_size, inter_size, topk, num
             swiglu_limit = torch.tensor(
                 [7.0] * (num_experts // moe_ep_size),
                 dtype=torch.float32).cuda()
+            if 86 < getSMVersion() < 100:
+                model_config.moe_backend = 'triton'
+            else:
+                model_config.moe_backend = 'cutlass' if not min_latency_mode else 'trtllm'
         else:
             model_config.moe_backend = 'cutlass' if not min_latency_mode else 'trtllm'
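
The added branch selects the MoE backend from the GPU's SM version on this code path: compute capabilities strictly between 86 and 100 (e.g. SM 89 Ada, SM 90 Hopper) use the Triton kernels, everything else keeps cutlass, or trtllm in min-latency mode. A standalone sketch of that gate; `pick_moe_backend` is a hypothetical stand-in for the inline assignment to `model_config.moe_backend`:

```python
# Standalone sketch of the SM-version gate added above.
def pick_moe_backend(sm_version: int, min_latency_mode: bool) -> str:
    if 86 < sm_version < 100:    # e.g. SM 89 (Ada) or SM 90 (Hopper) -> Triton MoE kernels
        return 'triton'
    return 'cutlass' if not min_latency_mode else 'trtllm'

assert pick_moe_backend(90, False) == 'triton'
assert pick_moe_backend(100, False) == 'cutlass'   # SM 100 keeps the cutlass/trtllm path
assert pick_moe_backend(100, True) == 'trtllm'
```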

src/aiconfigurator/sdk/common.py

Lines changed: 4 additions & 1 deletion

@@ -70,7 +70,9 @@ class BlockConfig:
         BlockConfig(None, True, 3.28125, False, 1),
         BlockConfig(None, True, 5.25, False, 1)
         ]
-    ]
+    ],
+    'GPT_OSS_120B':['MOE',36,64,8,64,2880,2880,201088,131072,4,128,2880,None],
+    'GPT_OSS_20B':['MOE',24,64,8,64,2880,2880,201088,131072,4,32,2880,None]
 }
 
 """
@@ -186,6 +188,7 @@ class MoEQuantMode(Enum):
     fp8_block = QuantMapping(1, 2, 'fp8_block') # specific for trtllm torch ds fp8
     w4afp8 = QuantMapping(0.5, 2, 'w4afp8') # specific for trtllm torch ds w4a8
     nvfp4 = QuantMapping(0.5, 4, 'nvfp4') # nvfp4 on blackwell
+    w4a16_mxfp4 = QuantMapping(0.5, 1, 'w4a16_mxfp4') # native data format for gpt oss
 
 class FMHAQuantMode(Enum):
     """

src/aiconfigurator/sdk/models.py

Lines changed: 27 additions & 2 deletions

@@ -266,10 +266,33 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
         fmha_quant_mode = self.config.fmha_quant_mode
         workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}"
 
+        if self.model_name in ['GPT_OSS_120B','GPT_OSS_20B']:
+            attn_scale_factor = 2
+            window_size = 128
+            self.context_ops.append(ops.ContextAttention(f'context_attention',
+                                                         self._num_layers/attn_scale_factor,
+                                                         self._num_heads//tp_size,
+                                                         num_kv_heads_per_GPU,
+                                                         kvcache_quant_mode,
+                                                         fmha_quant_mode,
+                                                         window_size,
+                                                         self._head_size))
+            self.generation_ops.append(ops.GenerationAttention(f'generation_attention',
+                                                               self._num_layers/attn_scale_factor,
+                                                               self._num_heads//tp_size,
+                                                               num_kv_heads_per_GPU,
+                                                               kvcache_quant_mode,
+                                                               window_size,
+                                                               self._head_size))
+        else:
+            attn_scale_factor = 1
+
         self.context_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3),
                                  ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8),
                                  ops.GEMM(f'context_qkv_gemm', self._num_layers, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode),
-                                 ops.ContextAttention(f'context_attention', self._num_layers, self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode),
+                                 ops.ContextAttention(f'context_attention', self._num_layers/attn_scale_factor,
+                                                      self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode,
+                                                      fmha_quant_mode, head_size=self._head_size),
                                  ops.GEMM(f'context_proj_gemm', self._num_layers, h, self._num_heads*self._head_size//tp_size, gemm_quant_mode),
                                  ops.ElementWise(f'context_add_norm_2', self._num_layers, 2*h, 2*h, 0.8)])
 
@@ -290,7 +313,9 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
         self.generation_ops.extend([ops.Embedding(f'generation_embedding', 1, self._vocab_size, h, 0.3),
                                     ops.ElementWise(f'generation_add_norm_1', self._num_layers, 2*h, 2*h, 0.8),
                                     ops.GEMM(f'generation_qkv_gemm', self._num_layers, self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode),
-                                    ops.GenerationAttention(f'generation_attention', self._num_layers, self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode),
+                                    ops.GenerationAttention(f'generation_attention', self._num_layers/attn_scale_factor,
+                                                            self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode,
+                                                            head_size=self._head_size),
                                     ops.GEMM(f'generation_proj_gemm', self._num_layers, h, self._num_heads*self._head_size//tp_size, gemm_quant_mode),
                                     ops.ElementWise(f'generation_add_norm_2', self._num_layers, 2*h, 2*h, 0.8)])
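
GPT-OSS interleaves sliding-window and full-attention layers, so the model now splits the attention cost in two: the new branch adds a dedicated op for the windowed half of the layers (num_layers / attn_scale_factor with window_size=128), while the shared path below keeps the full-attention op, also scaled by attn_scale_factor; all other models keep attn_scale_factor = 1 and a single attention op. A small illustrative sketch of that split (hypothetical helper, not SDK code):

```python
# Illustrative split of attention layers between the two ops built above.
def attention_layer_split(model_name: str, num_layers: int) -> dict:
    if model_name in ('GPT_OSS_120B', 'GPT_OSS_20B'):
        scale = 2
        return {'sliding_window_128': num_layers / scale,
                'full_attention': num_layers / scale}
    return {'full_attention': num_layers}

print(attention_layer_split('GPT_OSS_120B', 36))  # 18 windowed + 18 full-attention layers
print(attention_layer_split('GPT_OSS_20B', 24))   # 12 + 12
```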

src/aiconfigurator/sdk/operations.py

Lines changed: 18 additions & 4 deletions

@@ -317,18 +317,25 @@ def __init__(self,
                  n: int,
                  n_kv: int,
                  kvcache_quant_mode: common.KVCacheQuantMode,
-                 fmha_quant_mode: common.FMHAQuantMode) -> None:
+                 fmha_quant_mode: common.FMHAQuantMode,
+                 window_size: int = 0,
+                 head_size: int = 128) -> None:
         super().__init__(name, scale_factor)
         self._n = n
         self._weights = 0.0
         self._n_kv = n_kv
         self._kvcache_quant_mode = kvcache_quant_mode
         self._fmha_quant_mode = fmha_quant_mode
+        self._window_size = window_size
+        self._head_size = head_size
 
     def query(self, database:PerfDatabase, **kwargs):
         batch_size = kwargs.get('batch_size')
         isl = kwargs.get('s')
-        return database.query_context_attention(batch_size, isl, self._n, self._n_kv, self._kvcache_quant_mode, self._fmha_quant_mode)*self._scale_factor
+        return database.query_context_attention(batch_size, isl, self._n, self._n_kv,
+                                                self._kvcache_quant_mode, self._fmha_quant_mode,
+                                                window_size=self._window_size,
+                                                head_size=self._head_size)*self._scale_factor
 
     def get_weights(self, **kwargs):
         return self._weights * self._scale_factor
@@ -342,19 +349,26 @@ def __init__(self,
                  scale_factor: float,
                  n: int,
                  n_kv: int,
-                 kv_cache_dtype: common.KVCacheQuantMode) -> None:
+                 kv_cache_dtype: common.KVCacheQuantMode,
+                 window_size: int = 0,
+                 head_size: int = 128) -> None:
         super().__init__(name, scale_factor)
         self._n = n
         self._weights = 0.0
         self._n_kv = n_kv
         self._kv_cache_dtype = kv_cache_dtype
+        self._window_size = window_size
+        self._head_size = head_size
 
     def query(self, database:PerfDatabase, **kwargs):
         beam_width = kwargs.get('beam_width')
         assert(beam_width == 1), "only support beam_width=1"
         batch_size = kwargs.get('batch_size')
         s = kwargs.get('s')
-        return database.query_generation_attention(batch_size, s, self._n, self._n_kv, self._kv_cache_dtype)*self._scale_factor
+        return database.query_generation_attention(batch_size, s, self._n, self._n_kv,
+                                                   self._kv_cache_dtype,
+                                                   window_size=self._window_size,
+                                                   head_size=self._head_size)*self._scale_factor
 
     def get_weights(self, **kwargs):
         return self._weights * self._scale_factor
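
Both attention ops keep existing call sites working by defaulting window_size to 0 (no sliding window) and head_size to 128; they just store the two values and forward them to the database query as keyword arguments. A self-contained sketch of that plumbing, using stand-in classes rather than the real ContextAttention / PerfDatabase:

```python
# Stand-in classes sketching how window_size / head_size flow into the perf query;
# FakePerfDB only echoes what the real PerfDatabase would receive.
from dataclasses import dataclass

@dataclass
class ContextAttentionSketch:
    n: int                    # query heads per GPU
    n_kv: int                 # kv heads per GPU
    window_size: int = 0      # 0 keeps the previous full-attention behaviour
    head_size: int = 128      # previous implicit default

    def query(self, database, batch_size: int, s: int) -> float:
        return database.query_context_attention(batch_size, s, self.n, self.n_kv,
                                                window_size=self.window_size,
                                                head_size=self.head_size)

class FakePerfDB:
    def query_context_attention(self, b, s, n, n_kv, *, window_size, head_size):
        print(f"query: b={b} s={s} n={n} n_kv={n_kv} window={window_size} head={head_size}")
        return 0.0

ContextAttentionSketch(64, 8, window_size=128, head_size=64).query(FakePerfDB(), 4, 2048)
```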
