@@ -21,10 +21,14 @@ def __init__(self,
21
21
block_m : Optional [int ] = None ,
22
22
allow_deep_gemm : bool = False ):
23
23
super ().__init__ ()
24
- self .triton_expert : TritonExperts = TritonExperts (
25
- use_fp8_w8a8 , use_int8_w8a8 , use_int4_w4a16 , use_int8_w8a16 ,
26
- per_channel_quant , block_shape , block_m )
27
- self .deep_gemm_expert : DeepGemmExperts = DeepGemmExperts ()
24
+ self .triton_expert = TritonExperts (use_fp8_w8a8 = use_fp8_w8a8 ,
25
+ use_int8_w8a8 = use_int8_w8a8 ,
26
+ use_int4_w4a16 = use_int4_w4a16 ,
27
+ use_int8_w8a16 = use_int8_w8a16 ,
28
+ per_channel_quant = per_channel_quant ,
29
+ block_shape = block_shape ,
30
+ block_m = block_m )
31
+ self .deep_gemm_expert = DeepGemmExperts ()
28
32
self .allow_deep_gemm = allow_deep_gemm
29
33
self .use_fp8_w8a8 = use_fp8_w8a8
30
34
@@ -69,7 +73,7 @@ def apply(
69
73
N = w1 .shape [1 ]
70
74
if (self .allow_deep_gemm and self .use_fp8_w8a8 and N > 512
71
75
and _valid_deep_gemm (hidden_states , w1 , w2 , expert_map )):
72
- return self .deep_gemm_expert (
76
+ return self .deep_gemm_expert . apply (
73
77
hidden_states ,
74
78
w1 ,
75
79
w2 ,
@@ -88,7 +92,7 @@ def apply(
88
92
expert_num_tokens ,
89
93
)
90
94
else :
91
- return self .triton_expert (
95
+ return self .triton_expert . apply (
92
96
hidden_states ,
93
97
w1 ,
94
98
w2 ,
0 commit comments