[Bug] Run DeepSeek-v3-0324 on 2xH20 encountered cuda illegal memory access #4856

Open

cscyuge opened this issue Mar 28, 2025 · 7 comments

@cscyuge

cscyuge commented Mar 28, 2025

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Running DeepSeek-V3-0324 on 2 x H20 nodes (16 GPUs, TP=16) encountered "RuntimeError: CUDA error: an illegal memory access was encountered".

Output logs:

[2025-03-28 09:41:59] INFO:     127.0.0.1:53082 - "POST /generate HTTP/1.1" 200 OK
[2025-03-28 09:41:59] INFO:     127.0.0.1:53084 - "POST /generate HTTP/1.1" 200 OK
[2025-03-28 09:41:59 TP0] Prefill batch. #new-seq: 5, #new-token: 59995, #cached-token: 10, token usage: 0.05, #running-req: 1, #queue-req: 58,

  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker.py", line 172, in forward_batch_generation
    logits_output = self.model_runner.forward(forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 989, in forward
    return self.forward_extend(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 950, in forward_extend
    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1208, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1168, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1102, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 252, in forward
    return self.forward_normal(hidden_states)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 262, in forward_normal
    self.experts(hidden_states=hidden_states, router_logits=router_logits)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/layer.py", line 620, in forward
    final_hidden_states = self.quant_method.apply(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/quantization/fp8.py", line 970, in apply
    return fused_experts(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 915, in fused_experts
    torch.ops.sglang.inplace_fused_experts(
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 784, in inplace_fused_experts
    fused_experts_impl(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 1108, in fused_experts_impl
    invoke_fused_moe_kernel(
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py", line 567, in invoke_fused_moe_kernel
    fused_moe_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 111, in forward_thread_func
    with torch.get_device_module(self.device).stream(self.forward_stream):
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 595, in __exit__
    torch.cuda.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 962, in _patched_set_stream
    prev_set_stream(stream)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 636, in set_stream
    _set_stream_by_id(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 618, in _set_stream_by_id
    torch._C._cuda_setStream(
RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


[2025-03-28 09:42:00] Received sigquit from a child process. It usually means the child failed.

Reproduction

model: DeepSeek-V3-0324 (https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/tree/main)

running params:
node 0:

#!/bin/bash
export SERVER_IP="0.0.0.0"
export SERVER_PORT=30066
export MODEL_PATH="/root/model-DeepSeek/V3_0324/DeepSeek-V3-0324/"
export TP_SIZE=16
export NCCL_IB_GID_INDEX=3
export GLOO_SOCKET_IFNAME=eth0
export NCCL_SOCKET_IFNAME=eth0
export MASTER_ADDR=MY_MASTER_ADDR
export CUDA_LAUNCH_BLOCKING=1

python3 -m sglang.launch_server \
        --host ${SERVER_IP} \
        --port ${SERVER_PORT} \
        --model-path ${MODEL_PATH} \
        --tp ${TP_SIZE} \
        --dist-init-addr ${MASTER_ADDR} \
        --nnodes 2 \
        --node-rank 0 \
        --enable-metrics \
        --chunked-prefill-size -1 \
        --max-prefill-tokens 65536 \
        --context-length 65536 \
        --mem-fraction-static 0.6 \
        --trust-remote-code

node 1:

#!/bin/bash
export SERVER_IP="0.0.0.0"
export SERVER_PORT=30066
export MODEL_PATH="/root/model-DeepSeek/V3_0324/DeepSeek-V3-0324/"
export TP_SIZE=16
export NCCL_IB_GID_INDEX=3
export GLOO_SOCKET_IFNAME=eth0
export NCCL_SOCKET_IFNAME=eth0
export MASTER_ADDR=MY_MASTER_ADDR
export CUDA_LAUNCH_BLOCKING=1

python3 -m sglang.launch_server \
        --host ${SERVER_IP} \
        --port ${SERVER_PORT} \
        --model-path ${MODEL_PATH} \
        --tp ${TP_SIZE} \
        --dist-init-addr ${MASTER_ADDR} \
        --nnodes 2 \
        --node-rank 1 \
        --enable-metrics \
        --chunked-prefill-size -1 \
        --max-prefill-tokens 65536 \
        --context-length 65536 \
        --mem-fraction-static 0.6 \
        --trust-remote-code

benchmark script:

INPUT_LEN=12000
OUTPUT_LEN=1000
RATE=32
CONCURRENCY=64
PROMPTS=128
python3 -m sglang.bench_serving \
        --host 0.0.0.0 \
        --port 30066 \
        --backend sglang \
        --dataset-name random \
        --random-input $INPUT_LEN \
        --random-output $OUTPUT_LEN \
        --random-range-ratio 1 \
        --request-rate $RATE \
        --max-concurrency $CONCURRENCY \
        --num-prompts $PROMPTS \
        --dataset-path /root/data/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json

Environment

python3 -m sglang.check_env outputs:

INFO 03-28 09:34:36 __init__.py:190] Automatically detected platform cuda.
Python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H20
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 535.161.07
PyTorch: 2.5.1+cu124
sglang: 0.4.5.post1
sgl_kernel: 0.0.5.post3
flashinfer: 0.2.3+cu124torch2.5
triton: 3.1.0
transformers: 4.50.0
torchao: 0.9.0
numpy: 1.26.4
aiohttp: 3.11.13
fastapi: 0.115.11
hf_transfer: 0.1.9
huggingface_hub: 0.29.3
interegular: 0.3.3
modelscope: 1.23.2
orjson: 3.10.15
packaging: 24.2
psutil: 7.0.0
pydantic: 2.10.6
multipart: 0.0.20
zmq: 26.3.0
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: 0.7.2
openai: 1.66.3
tiktoken: 0.9.0
anthropic: 0.49.0
decord: 0.6.0
NVIDIA Topology: 
	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	NIC0	NIC1	NIC2	NIC3	NIC4	NIC5	NIC6	NIC7	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NV18	PIX	NODE	NODE	NODE	SYS	SYS	SYS	SYS	0-95,192-287	0		N/A
GPU1	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NODE	PIX	PHB	NODE	SYS	SYS	SYS	SYS	0-95,192-287	0		N/A
GPU2	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NODE	PHB	PIX	NODE	SYS	SYS	SYS	SYS	0-95,192-287	0		N/A
GPU3	NV18	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NODE	NODE	NODE	PIX	SYS	SYS	SYS	SYS	0-95,192-287	0		N/A
GPU4	NV18	NV18	NV18	NV18	 X 	NV18	NV18	NV18	SYS	SYS	SYS	SYS	PIX	NODE	NODE	NODE	96-191,288-383	1		N/A
GPU5	NV18	NV18	NV18	NV18	NV18	 X 	NV18	NV18	SYS	SYS	SYS	SYS	NODE	PIX	NODE	NODE	96-191,288-383	1		N/A
GPU6	NV18	NV18	NV18	NV18	NV18	NV18	 X 	NV18	SYS	SYS	SYS	SYS	NODE	NODE	PIX	PHB	96-191,288-383	1		N/A
GPU7	NV18	NV18	NV18	NV18	NV18	NV18	NV18	 X 	SYS	SYS	SYS	SYS	NODE	NODE	PHB	PIX	96-191,288-383	1		N/A
NIC0	PIX	NODE	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	NODE	NODE	SYS	SYS	SYS	SYS				
NIC1	NODE	PIX	PHB	NODE	SYS	SYS	SYS	SYS	NODE	 X 	PHB	NODE	SYS	SYS	SYS	SYS				
NIC2	NODE	PHB	PIX	NODE	SYS	SYS	SYS	SYS	NODE	PHB	 X 	NODE	SYS	SYS	SYS	SYS				
NIC3	NODE	NODE	NODE	PIX	SYS	SYS	SYS	SYS	NODE	NODE	NODE	 X 	SYS	SYS	SYS	SYS				
NIC4	SYS	SYS	SYS	SYS	PIX	NODE	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	NODE	NODE				
NIC5	SYS	SYS	SYS	SYS	NODE	PIX	NODE	NODE	SYS	SYS	SYS	SYS	NODE	 X 	NODE	NODE				
NIC6	SYS	SYS	SYS	SYS	NODE	NODE	PIX	PHB	SYS	SYS	SYS	SYS	NODE	NODE	 X 	PHB				
NIC7	SYS	SYS	SYS	SYS	NODE	NODE	PHB	PIX	SYS	SYS	SYS	SYS	NODE	NODE	PHB	 X 				

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1
  NIC2: mlx5_bond_2
  NIC3: mlx5_bond_3
  NIC4: mlx5_bond_4
  NIC5: mlx5_bond_5
  NIC6: mlx5_bond_6
  NIC7: mlx5_bond_7


ulimit soft: 1048576

sglang version: installed from source (commit 52029bd, Tue Mar 25 17:01:21 2025 +0800)

@cscyuge
Author

cscyuge commented Mar 28, 2025

This bug is fixed by #4727, but it still occurs when --enable-ep-moe is enabled.

Output logs:

[2025-03-28 10:02:01] INFO:     127.0.0.1:36118 - "POST /generate HTTP/1.1" 200 OK
[2025-03-28 10:02:01 TP0] Prefill batch. #new-seq: 5, #new-token: 59995, #cached-token: 10, token usage: 0.05, #running-req: 1, #queue-req: 58,

    logits_output = self.model_runner.forward(forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 989, in forward
    return self.forward_extend(
  File "/sgl-workspace/sglang/python/sglang/srt/model_executor/model_runner.py", line 950, in forward_extend
    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1208, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1168, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1102, in forward
    hidden_states = self.mlp(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 252, in forward
    return self.forward_normal(hidden_states)
  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 262, in forward_normal
    self.experts(hidden_states=hidden_states, router_logits=router_logits)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/moe/ep_moe/layer.py", line 244, in forward
    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/sgl-workspace/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 111, in forward_thread_func
    with torch.get_device_module(self.device).stream(self.forward_stream):
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 595, in __exit__
    torch.cuda.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 962, in _patched_set_stream
    prev_set_stream(stream)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 636, in set_stream
    _set_stream_by_id(
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 618, in _set_stream_by_id
    torch._C._cuda_setStream(
RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


[2025-03-28 10:02:02] Received sigquit from a child process. It usually means the child failed.

@yudian0504
Contributor

+1

@cscyuge
Author

cscyuge commented Mar 28, 2025

When --chunked-prefill-size is set to 8192, the bug does not occur. Could there be another int32 overflow issue during prefill?

@snippetzero

+1

@saltyfish66
Contributor

saltyfish66 commented Mar 31, 2025

When --chunked-prefill-size is set to 8192, the bug does not occur. Could there be another int32 overflow issue during prefill?

I'm the author of #4727.
This bug occurs only when the number of tokens that fused_moe_kernel needs to process exceeds roughly 37k (DeepSeek 671B). So if you set --chunked-prefill-size to 64k and the server receives several prompts whose combined token count exceeds 37k, this bug occurs.
It's right that setting --chunked-prefill-size to 8192 avoids the bug. But in my experience, a larger chunk size gives higher throughput during the prefill phase; at least that's the case in my production environment.

@cscyuge
Author

cscyuge commented Mar 31, 2025

When --chunked-prefill-size is set to 8192, the bug does not occur. Could there be another int32 overflow issue during prefill?

I'm the author of #4727. This bug occurs only when the number of tokens that fused_moe_kernel needs to process exceeds roughly 37k (DeepSeek 671B). So if you set --chunked-prefill-size to 64k and the server receives several prompts whose combined token count exceeds 37k, this bug occurs. It's right that setting --chunked-prefill-size to 8192 avoids the bug. But in my experience, a larger chunk size gives higher throughput during the prefill phase; at least that's the case in my production environment.

Thank you for proposing the solution in #4727. I have tried it by adding

offs_token = offs_token.to(tl.int64)

after https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L154
but unfortunately, the issue still persists when using the --enable-ep-moe option.
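
For illustration, below is a minimal, self-contained Triton sketch (a hypothetical gather kernel, not sglang's fused_moe_kernel or pre_reorder_triton_kernel) of the pattern this kind of fix follows: a row index loaded as int32 is widened to int64 before it is multiplied by a large row stride, so the byte offset cannot wrap. The kernels on the --enable-ep-moe path would presumably need an equivalent cast at their own index computations.

import torch
import triton
import triton.language as tl


@triton.jit
def gather_rows(src_ptr, idx_ptr, out_ptr, stride_src, n_cols, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    row = tl.load(idx_ptr + pid)   # row index arrives as int32 from global memory
    row = row.to(tl.int64)         # widen BEFORE the multiply; without this,
                                   # row * stride_src is computed in int32 and
                                   # wraps once it exceeds 2**31 - 1
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    vals = tl.load(src_ptr + row * stride_src + offs, mask=mask)
    tl.store(out_ptr + pid.to(tl.int64) * n_cols + offs, vals, mask=mask)


def gather(src: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # src: contiguous (num_rows, n_cols) CUDA tensor; idx: int32 row indices
    out = torch.empty((idx.numel(), src.shape[1]), device=src.device, dtype=src.dtype)
    BLOCK = triton.next_power_of_2(src.shape[1])
    gather_rows[(idx.numel(),)](src, idx, out, src.stride(0), src.shape[1], BLOCK=BLOCK)
    return out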

@saltyfish66
Contributor

saltyfish66 commented Mar 31, 2025

When --chunked-prefill-size is set to 8192, the bug does not occur. Could there be another int32 overflow issue during prefill?

I'm the author of #4727. This bug occurs only when the number of tokens that fused_moe_kernel needs to process exceeds roughly 37k (DeepSeek 671B). So if you set --chunked-prefill-size to 64k and the server receives several prompts whose combined token count exceeds 37k, this bug occurs. It's right that setting --chunked-prefill-size to 8192 avoids the bug. But in my experience, a larger chunk size gives higher throughput during the prefill phase; at least that's the case in my production environment.

Thank you for proposing the solution in #4727. I have tried it by adding

offs_token = offs_token.to(tl.int64)

after https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py#L154 but unfortunately, the issue still persists when using the --enable-ep-moe option.

"but unfortunately, the issue still persists when using the --enable-ep-moe option."
I knew that, and I found that it's because serveral kernels involved in --enable-ep-moe have int32 overflow problems.
But unfortunately, I have no much time to fix them.
By the way, in my performance tests, --enable-ep-moe doesn't give me any performance benefit.
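
Until those kernels are fixed, the same arithmetic as above gives a rough upper bound on a chunked-prefill size that stays inside int32 (again assuming top_k = 8 and hidden_size = 7168; these are assumptions, not values read from the ep_moe kernels):

INT32_MAX = 2**31 - 1
TOP_K, HIDDEN_SIZE = 8, 7168  # assumed DeepSeek-V3 shapes

# largest prefill chunk whose tokens * top_k * hidden_size offset still fits
# in a signed 32-bit integer
max_safe_chunk = INT32_MAX // (TOP_K * HIDDEN_SIZE)
print(max_safe_chunk)  # 37449, consistent with the ~37k threshold above

Under those assumptions, --chunked-prefill-size 8192 sits well below the bound, which matches the observation that the crash disappears with that setting.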
