Commit b04e5d3

lint
Signed-off-by: Bill Nell <[email protected]>
1 parent 0f711d8 commit b04e5d3

11 files changed: +89 −79 lines

requirements/test.txt

+19-2
@@ -27,6 +27,10 @@ argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
     # via isoduration
+async-timeout==5.0.1
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -126,6 +130,11 @@ encodec==0.1.1
     # via vocos
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.2.2
+    # via
+    #   anyio
+    #   hypothesis
+    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -623,7 +632,6 @@ setuptools==77.0.3
     # via
     #   mamba-ssm
     #   pytablewriter
-    #   torch
     #   triton
 shellingham==1.5.4
     # via typer
@@ -683,8 +691,13 @@ tokenizers==0.21.1
     # via
     #   -r requirements/test.in
     #   transformers
+toml==0.10.2
+    # via datamodel-code-generator
 tomli==2.2.1
-    # via schemathesis
+    # via
+    #   black
+    #   pytest
+    #   schemathesis
 tomli-w==1.2.0
     # via schemathesis
 torch==2.7.0+cu128
@@ -756,12 +769,16 @@ types-python-dateutil==2.9.0.20241206
     # via arrow
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   rich
     #   torch
     #   typer
 tzdata==2024.2

tests/kernels/moe/test_pplx_moe.py

-3
@@ -522,13 +522,10 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
 def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     assert torch.cuda.current_device() == pgi.local_rank

-    hidden_dim = a.shape[1]
     num_experts = w1.shape[0]
-    block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    topk = topk_ids.shape[1]
     max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

     dispatch_combine = BatchedDispatchCombine(

vllm/compilation/compiler_interface.py

+2-1
@@ -328,7 +328,8 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         assert hash_str is not None, (
             f"failed to get the hash of the compiled graph: {file_path}")
         assert file_path is not None, (
-            "failed to get the file path of the compiled graph: {file_path}")
+            "failed to get the file path of the compiled graph: {file_path}"
+        )
         return compiled_graph, (hash_str, file_path)

     def load(self,

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

+18-39
@@ -514,31 +514,18 @@ def dispatch(
             dtype=torch.int,
             device=a1.device)

-        rem_experts = num_experts % self.world_size
-        num_local_experts = ((num_experts // self.world_size) +
-                             (1 if self.rank < rem_experts else 0))
+        assert num_experts % self.world_size == 0
+
+        num_local_experts = num_experts // self.world_size

         b_a1 = torch.zeros(
             (num_local_experts, self.max_num_tokens, hidden_dim),
             dtype=a1.dtype,
             device=a1.device)

-        first_expert = (((num_experts // self.world_size) * self.rank) +
-                        rem_experts - self.rank)
+        first_expert = num_local_experts * self.rank
         last_expert = first_expert + num_local_experts

-        # rhs = torch.empty((self.max_num_tokens, hidden_dim),
-        #                   dtype=a1.dtype, device=a1.device)
-
-        # for expert_id in range(first_expert, last_expert):
-        #     topks = torch.any(topk_ids == expert_id, dim=1).flatten()
-        #     rows = torch.count_nonzero(topks.flatten())
-        #     #rhs[:rows] = a1[:topks.numel()][topks]
-        #     topks_idx = topks.nonzero()
-        #     torch.index_select(a1, dim=0, index=topks_idx.flatten(), out=rhs[:rows])
-        #     b_a1[expert_id - first_expert, :rows, :] = rhs[:rows]
-        #     tokens_per_expert[expert_id - first_expert] = rows
-
         for expert_id in range(first_expert, last_expert):
             topks = torch.any(topk_ids == expert_id, dim=1).flatten()
             rows = torch.count_nonzero(topks.flatten())
@@ -558,24 +545,14 @@ def combine(
     ) -> None:
         num_tokens = topk_ids.shape[0]
         num_local_experts = fused_expert_output.shape[0]
-        topk = topk_weights.shape[1]
         K = fused_expert_output.shape[-1]
         assert output.shape[0] == num_tokens and output.shape[1] == K

         output.fill_(0)

-        first_expert = num_local_experts * self.rank  # NOT QUITE RIGHT
+        first_expert = num_local_experts * self.rank
         last_expert = first_expert + num_local_experts

-        # for expert_id in range(first_expert, last_expert):
-        #     topkws = topk_ids == expert_id
-        #     topks = torch.any(topkws, dim=1).flatten()
-        #     outrhs = output[topks]
-        #     rhs = fused_expert_output[expert_id - first_expert, :outrhs.shape[0], :]
-        #     if not apply_router_weight_on_input:
-        #         rhs.mul_(topk_weights[topkws].view(rhs.shape[0], 1))
-        #     output[topks] = outrhs + rhs
-
         for expert_id in range(first_expert, last_expert):
             topkws = topk_ids == expert_id
             topks = torch.any(topkws, dim=1).flatten()
@@ -661,20 +638,20 @@ def apply(
         num_experts = global_num_experts
         out = _resize_cache(workspace13,
                             (num_experts, max_num_tokens * num_dp, hidden_dim))
-        num_local_experts = w1.shape[0]  #expert_num_tokens.numel()
+        num_local_experts = w1.shape[0]
         assert num_local_experts == w1.shape[
             0], f"{num_local_experts} == {w1.shape[0]}"

         N = w1.shape[1] // 2

         # Not cudagraph friendly
-        # assert (torch.cuda.is_current_stream_capturing() or
-        #         torch.all(expert_num_tokens <= max_num_tokens)), (
-        #     f"{expert_num_tokens} <= {max_num_tokens}")
+        assert (torch.cuda.is_current_stream_capturing()
+                or torch.all(expert_num_tokens <= max_num_tokens)), (
+                    f"{expert_num_tokens} <= {max_num_tokens}")

         for expert in range(num_local_experts):
             # Indexing expert_num_tokens doesn't work w/cudagraphs
-            if True or torch.cuda.is_current_stream_capturing():
+            if torch.cuda.is_current_stream_capturing():
                 num = max_num_tokens * num_dp
             else:
                 num = int(expert_num_tokens[expert].item())
@@ -821,12 +798,14 @@ def apply(
             block_shape=self.block_shape)

         # Fix activations
-        # assert activation == "silu"
-        # invoke_batched_silu_and_mul(output=intermediate_cache2,
-        #                             input=intermediate_cache1,
-        #                             expert_num_tokens=expert_num_tokens)
-        self.activation(activation, intermediate_cache2.view(-1, N // 2),
-                        intermediate_cache1.view(-1, N))
+        if True:
+            assert activation == "silu"
+            invoke_batched_silu_and_mul(output=intermediate_cache2,
+                                        input=intermediate_cache1,
+                                        expert_num_tokens=expert_num_tokens)
+        else:
+            self.activation(activation, intermediate_cache2.view(-1, N // 2),
+                            intermediate_cache1.view(-1, N))

         #qintermediate_cache2 = intermediate_cache2
         a2q_scale = a2_scale
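
Note: the dispatch/combine hunks above replace the uneven-split bookkeeping with a hard assumption that experts divide evenly across ranks, so each rank owns a contiguous block of experts and gathers its tokens with a boolean mask. A minimal, self-contained sketch of that partitioning and gather follows; the shapes, rank and world size are made-up example values, not the class's actual state.

# Illustrative sketch (not vLLM code): even expert partitioning plus a
# mask-based per-expert token gather, mirroring the simplified dispatch above.
import torch

world_size, rank = 2, 1
num_experts, max_num_tokens, hidden_dim, topk = 8, 16, 32, 2

assert num_experts % world_size == 0
num_local_experts = num_experts // world_size
first_expert = num_local_experts * rank
last_expert = first_expert + num_local_experts

a1 = torch.randn(max_num_tokens, hidden_dim)
topk_ids = torch.randint(0, num_experts, (max_num_tokens, topk))

b_a1 = torch.zeros(num_local_experts, max_num_tokens, hidden_dim)
tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int)

for expert_id in range(first_expert, last_expert):
    # Tokens that routed any of their top-k slots to this expert.
    topks = torch.any(topk_ids == expert_id, dim=1)
    rows = int(torch.count_nonzero(topks))
    b_a1[expert_id - first_expert, :rows, :] = a1[topks]
    tokens_per_expert[expert_id - first_expert] = rows

combine then reverses this gather: for each local expert it scatters the expert's output rows back to the owning tokens, scaling by the matching topk weight unless the router weight was already applied on the input.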

vllm/model_executor/layers/fused_moe/layer.py

+43-27
@@ -68,55 +68,68 @@ def use_pplx_kernels(self):
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
         """
-        Determine MoE parallel configuration. Based on the input tp_size_, dp_size_,
-        ep_size_ and vllm's parallel config, determine what level's of parallelism
-        to use in the fused moe layer.
+        Determine MoE parallel configuration. Based on the input tp_size_,
+        dp_size_, ep_size_ and vllm's parallel config, determine what
+        level's of parallelism to use in the fused moe layer.

         Args:
             tp_size_ (int): tp_size passed into the FusedMoE constructor.
             dp_size_ (int): dp_size passed into the FusedMoE constructor.
             ep_size_ (int): ep_size passed into the FusedMoE constructor.
-            vllm_parallel_config (ParallelConfig): vllm's parallel config object.
+            vllm_parallel_config (ParallelConfig): vllm's parallel config
+                object.

         Examples:
             When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
             we simply return the sizes unaltered and the ranks set to 0.

-            Expert Parallelism is considered only when either dp_size_ or tp_size_ is non trivial.
+            Expert Parallelism is considered only when either dp_size_ or tp_size_
+            is non trivial.

-            When TP = 2, DP = 1 and EP = False, the configuration on different devices,
-                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // legend : {size, rank}
+            When TP = 2, DP = 1 and EP = False, the configuration on different
+            devices,
+                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
+                    legend : {size, rank}
                 - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
                 - Comment : Tensors are sharded across 2 devices.

-            When TP = 1, DP = 2 and EP = False, the configuration on different devices,
+            When TP = 1, DP = 2 and EP = False, the configuration on different
+            devices,
                 - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
                 - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
-                - Comment: There are 2 engine instances and the tensors are sharded across 2 decvices.
+                - Comment: There are 2 engine instances and the tensors are sharded
+                    across 2 decvices.

-            When TP = 2, DP = 2 and EP = False, the configuration on different devices,
+            When TP = 2, DP = 2 and EP = False, the configuration on different
+            devices,
                 - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
                 - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
                 - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
                 - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
-                - Comment: There are 2 engine instances and the tensors are sharded across 4 devices.
+                - Comment: There are 2 engine instances and the tensors are sharded
+                    across 4 devices.

-            When, TP = 2, DP = 1 and EP = True, the configuration on different devices,
+            When, TP = 2, DP = 1 and EP = True, the configuration on different
+            devices,
                 - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
                 - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
                 - Comment: The experts are split between the 2 devices.

-            When, TP = 1, DP = 2 and EP = True, the configuration on different devices,
+            When, TP = 1, DP = 2 and EP = True, the configuration on different
+            devices,
                 - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
                 - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
-                - Comment: There are 2 engine instances and the experts are split between the 2 devices.
+                - Comment: There are 2 engine instances and the experts are split
+                    between the 2 devices.

-            When TP = 2, DP = 2 and EP = True, the configuration on different devices,
+            When TP = 2, DP = 2 and EP = True, the configuration on different
+            devices,
                 - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
                 - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
                 - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
                 - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
-                - Comment: There are 2 engine instances and the experts are split between the 4 devices.
+                - Comment: There are 2 engine instances and the experts are split
+                    between the 4 devices.
         """

         def flatten_tp_across_dp(dp_rank: int):
@@ -127,7 +140,8 @@ def flatten_tp_across_dp(dp_rank: int):
             tp_rank = dp_rank * tp_size_ + tp_rank
             return tp_size, tp_rank

-        use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
+        use_ep = (dp_size_ * tp_size_ > 1
+                  and vllm_parallel_config.enable_expert_parallel)

         dp_size = dp_size_
         dp_rank = get_dp_group().rank_in_group
@@ -143,8 +157,8 @@ def flatten_tp_across_dp(dp_rank: int):
                                          use_ep=False)
        # DP + EP / TP + EP / DP + TP + EP
        assert use_ep
-        # In EP, each device owns a set of experts fully. There is no tensor parallel.
-        # Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
+        # In EP, each device owns a set of experts fully. There is no tensor
+        # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
         ep_size = tp_size
         ep_rank = tp_rank
         return FusedMoEParallelConfig(tp_size=1,
@@ -719,12 +733,13 @@ def __init__(
         self.params_dtype = params_dtype

         vllm_config = get_current_vllm_config()
-        self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
-            tp_size_=(tp_size if tp_size is not None else
-                      get_tensor_model_parallel_world_size()),
-            dp_size_=(dp_size
-                      if dp_size is not None else get_dp_group().world_size),
-            vllm_parallel_config=vllm_config.parallel_config)
+        self.moe_parallel_config: FusedMoEParallelConfig = (
+            FusedMoEParallelConfig.make(
+                tp_size_=(tp_size if tp_size is not None else
+                          get_tensor_model_parallel_world_size()),
+                dp_size_=(dp_size if dp_size is not None else
+                          get_dp_group().world_size),
+                vllm_parallel_config=vllm_config.parallel_config))

         self.global_num_experts = num_experts

@@ -1184,8 +1199,9 @@ def must_reduce_shared_outputs(self) -> bool:
     def maybe_all_reduce_tensor_model_parallel(
             self, final_hidden_states: torch.Tensor):
         """
-        The pplx combine kernel reduce across GPU ranks by default. The pplx kernels are
-        used when EP is enabled. In that case, this function is a no-op.
+        The pplx combine kernel reduce across GPU ranks by default. The pplx
+        kernels are used when EP is enabled. In that case, this function is a
+        no-op.
         """
         if self.dp_size > 1 and self.use_ep and has_pplx:
             return final_hidden_states
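
Note: the reflowed docstring above encodes a specific size/rank mapping: TP ranks are flattened across DP groups, and when expert parallelism is on, the flattened TP group is reinterpreted as the EP group with TP collapsed to 1. A hedged, standalone sketch of that arithmetic follows; it is illustrative only, not the vLLM implementation, and simply reproduces the tables in the docstring.

# Standalone illustration of the size/rank mapping described in the
# docstring above (assumed simplification; not the vLLM implementation).
def moe_parallel_sizes(tp_size: int, dp_size: int, tp_rank: int,
                       dp_rank: int, enable_expert_parallel: bool):
    # Flatten TP across DP: with DP = 2 and TP = 2 there are 4 "flat" TP ranks.
    flat_tp_size = dp_size * tp_size
    flat_tp_rank = dp_rank * tp_size + tp_rank

    use_ep = flat_tp_size > 1 and enable_expert_parallel
    if not use_ep:
        # TP and/or DP only: EP stays trivial.
        return dict(tp=(flat_tp_size, flat_tp_rank),
                    dp=(dp_size, dp_rank), ep=(1, 0))
    # EP: each device owns whole experts, so TP collapses to 1 and the
    # flattened TP group becomes the EP group.
    return dict(tp=(1, 0), dp=(dp_size, dp_rank),
                ep=(flat_tp_size, flat_tp_rank))

# Matches the "TP = 2, DP = 2, EP = True" example: device 3 gets EP {4, 3}.
assert moe_parallel_sizes(2, 2, 1, 1, True)["ep"] == (4, 3)

With TP = 2 and DP = 2 the four devices get flat ranks 0..3; with EP off those are the TP ranks of a 4-way shard, and with EP on they become EP ranks while TP drops to {1, 0} on every device, matching the docstring's last table.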

vllm/model_executor/models/deepseek_v2.py

+3-2
@@ -145,7 +145,8 @@ def __init__(
                 # to reduce the shared_output result. Instead we reduce
                 # at the end of the forward pass.
                 # With EP and the pplx kernels - this is no longer viable
-                # as all GPU ranks in DP, produce the complete set of hidden_states.
+                # as all GPU ranks in DP, produce the complete set of
+                # hidden_states.
                 # Therefore reduce the shared experts early.
                 reduce_results=self.experts.must_reduce_shared_outputs(),
                 prefix=f"{prefix}.shared_experts",
@@ -178,7 +179,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 * (1. / self.routed_scaling_factor)

         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
                 final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)

vllm/model_executor/models/granitemoe.py

+1-1
@@ -100,7 +100,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states, router_logits)

         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
                 final_hidden_states)

         return final_hidden_states.view(orig_shape)

vllm/model_executor/models/qwen2_moe.py

+1-1
@@ -154,7 +154,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
                 final_hidden_states)

         return final_hidden_states.view(orig_shape)

vllm/model_executor/models/qwen3_moe.py

+1-1
@@ -135,7 +135,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                                            router_logits=router_logits)
         final_hidden_states = final_hidden_states
         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
                 final_hidden_states)

         return final_hidden_states.view(orig_shape)

vllm/platforms/cuda.py

+1-1
@@ -157,7 +157,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             logger.info(
                 "Forcing kv cache block size to 64 for FlashMLA backend.")

-        if (False and parallel_config.data_parallel_size > 1
+        if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
             logger.info(
                 "Data Parallel: Forcing enforce eager to be True since DP is "

vllm/v1/worker/gpu_model_runner.py

-1
@@ -1542,7 +1542,6 @@ def _dummy_run(
                 self.drafter.dummy_run(num_tokens)

         logit_indices = np.cumsum(num_scheduled_tokens) - 1
-        #logit_indices = torch.from_numpy(logit_indices).to(hidden_states.device)
         return hidden_states[logit_indices]

     @torch.inference_mode()
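
Note: the surviving line in the hunk above computes, per request, the index of its last scheduled token, so the dummy run keeps only the hidden states that would feed the logits. A tiny worked example with made-up token counts:

# Worked example (made-up counts): three requests scheduling 3, 1 and 4
# tokens pack into 8 rows of hidden_states; cumsum - 1 yields the index
# of each request's final token, which is where logits are computed.
import numpy as np

num_scheduled_tokens = np.array([3, 1, 4])
logit_indices = np.cumsum(num_scheduled_tokens) - 1
assert logit_indices.tolist() == [2, 3, 7]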
