Skip to content

Commit d1a2ea0

Browse files
wip
1 parent 1a12821 commit d1a2ea0

File tree

3 files changed

+3
-4
lines changed

3 files changed

+3
-4
lines changed

tpu_inference/layers/vllm/fused_moe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tpu_inference.layers.vllm.linear_common import \
1010
slice_sharded_tensor_for_concatenation
11-
from tpu_inference.layers.common.sharding import ShardingAxisName
11+
1212
P = PartitionSpec
1313

1414

@@ -374,6 +374,7 @@ def fused_moe_func(
374374
assert (num_tokens * topk) % 16 == 0, (
375375
"The kernel requires num_tokens * topk to be a multiple of "
376376
f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
377+
377378
hidden_states = hidden_states.reshape(num_tokens, hidden_size)
378379
gating_output = gating_output.reshape(num_tokens, global_num_experts)
379380

@@ -425,6 +426,7 @@ def _process_tokens_locally(hidden_states_local, topk_indices_local):
425426
)
426427

427428
x = activation_fn(activation, x1, x2)
429+
428430
if use_ep:
429431
x = expert_sharded_gmm(
430432
x,

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
2626
from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
2727
get_tpu_quant_method)
28-
from tpu_inference.layers.common.sharding import ShardingAxisName
29-
3028
from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
3129
from tpu_inference.layers.vllm.linear_common import (
3230
reorder_concatenated_tensor_for_sharding,

tpu_inference/layers/vllm/sharding.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from vllm.model_executor.layers.vocab_parallel_embedding import (
2020
ParallelLMHead, VocabParallelEmbedding)
2121

22-
from tpu_inference.layers.common.sharding import ShardingAxisName
2322
from tpu_inference import envs
2423
from tpu_inference.logger import init_logger
2524

0 commit comments

Comments
 (0)