
Conversation

k50112113 commented Jan 9, 2026

This PR is co-authored by @k50112113, @omuhamma (#61) and @farlukas (#116)

This PR provides Triton fusion/GEMM optimizations for DS FP4 and FP8. Please use the following AITER branch for testing for now, as some of the required PRs have not yet been merged to AITER main (a quick import check for the new kernels is sketched after the PR list below):
https://github.com/ROCm/aiter/tree/shaoclee/atom_triton_tmp_0106

The required AITER PRs include:

  1. [Triton] Triton A16WFP4 GEMM prequant aiter#1777
  2. [Triton] Triton a16w8 gemm preshuffle aiter#1778
  3. [Triton] Add Fused GEMM A8W8 + Split + Concat Triton Kernel aiter#1553 (review)
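
To confirm the installed AITER branch actually exposes the new kernels before launching the server, a quick import check like the following can help. This is a minimal sketch that only uses the module paths appearing in the review snippets further down; adjust it if the branch layout differs.

# Minimal sanity check: try importing the Triton kernels this PR relies on.
# Module paths are taken from the review snippets in this PR; adjust if needed.
import importlib

kernels = [
    ("aiter.ops.triton.gemm_afp4wfp4", "gemm_afp4wfp4_preshuffle"),
    ("aiter.ops.triton.gemm_a8w8_blockscale", "gemm_a8w8_blockscale_preshuffle"),
    ("aiter.ops.triton.gemm_a16w8_blockscale", "gemm_a16w8_blockscale_preshuffle"),
    ("aiter.ops.triton.fused_gemm_afp4wfp4_split_cat", "fused_gemm_afp4wfp4_preshuffle_split_cat"),
    ("aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat", "fused_gemm_a8w8_blockscale_preshuffle_split_cat"),
]

for module_name, symbol in kernels:
    try:
        module = importlib.import_module(module_name)
        assert hasattr(module, symbol), f"{symbol} not found in {module_name}"
        print(f"OK       {module_name}.{symbol}")
    except (ImportError, AssertionError) as exc:
        print(f"MISSING  {module_name}.{symbol} ({exc})")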

To activate the optimizations on ATOM, the following env variables are required (a sketch of how these flags might be read follows the block):

# for concurrency > 4, use AR + RMS_Quant + GEMM optimizations:
export ATOM_USE_TRITON_GEMM=1
# note: ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION is turned on automatically when ATOM_USE_TRITON_GEMM is on

# for concurrency = 4, use AR_RMS + Quant_GEMM optimizations:
export ATOM_USE_TRITON_GEMM=1
export ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION=0
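
For context, the sketch below shows roughly how such flags are usually read. use_triton_gemm() appears in the review snippets below, while env_flag() and use_ds_input_rmsnorm_quant_fusion() are hypothetical helper names used here for illustration only, not ATOM's actual implementation.

import os

def env_flag(name: str, default: str) -> bool:
    # Hypothetical helper: treat "0", "false", and empty values as off.
    return os.environ.get(name, default).strip().lower() not in ("0", "false", "")

def use_triton_gemm() -> bool:
    return env_flag("ATOM_USE_TRITON_GEMM", "0")

def use_ds_input_rmsnorm_quant_fusion() -> bool:
    # Per the note above: defaults to on whenever ATOM_USE_TRITON_GEMM is set,
    # but can be disabled explicitly (the concurrency = 4 configuration).
    if not use_triton_gemm():
        return False
    return env_flag("ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1")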

The following commands, along with the above env vars, are used to derive e2e performance results (a quick server readiness check is sketched after the launch commands):

# for DS FP8
python -m atom.entrypoints.openai_server \
    --model /data/deepseek-ai/DeepSeek-R1-0528/ \
    -tp 8 \
    --block-size 1 \
    --server-port 8989 2>&1 | tee server.out

# for DS FP4
export ATOM_USE_TRITON_MXFP4_BMM=1
export AMDGCN_USE_BUFFER_OPS=1
python -m atom.entrypoints.openai_server \
    --model /data/DeepSeek-R1-0528-MXFP4-Preview \
    -tp 8 \
    --block-size 16 \
    --kv_cache_dtype fp8 \
    --server-port 8989 \
    2>&1 | tee server.out
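
Before starting the client sweep, it can be worth confirming the server is serving. Below is a minimal readiness check that assumes the entrypoint exposes the standard OpenAI-compatible /v1/models route on the port above (an assumption, not verified against ATOM's route table).

# Hypothetical readiness check against the OpenAI-compatible /v1/models route.
import json
import urllib.request

PORT = 8989  # matches --server-port above

with urllib.request.urlopen(f"http://localhost:{PORT}/v1/models", timeout=10) as resp:
    print(json.dumps(json.load(resp), indent=2))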

For the client command (a small result-parsing sketch follows the loop):

MODEL=<DS FP4 or FP8 model paths>
ISL=3500
OSL=1500
PORT=8989
for CONC in 4 256 128 64 32 16 8; do
    RESULT_FILENAME=${ISL}_${OSL}_${CONC}
    python /root/ATOM/atom/benchmarks/benchmark_serving.py \
        --model=$MODEL --backend=vllm --base-url=http://localhost:$PORT \
        --dataset-name=random \
        --random-input-len=$ISL --random-output-len=$OSL \
        --random-range-ratio 1.0 \
        --num-prompts=$(( $CONC * 8 )) \
        --max-concurrency=$CONC \
        --request-rate=inf --ignore-eos \
        --save-result --percentile-metrics="ttft,tpot,itl,e2el" \
        --result-dir=./ --result-filename=$RESULT_FILENAME.json 2>&1 | tee -a ${RESULT_FILENAME}.log
done
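
Each ${ISL}_${OSL}_${CONC}.json written by the loop holds the saved metrics. Since the exact key names depend on benchmark_serving.py, the sketch below just surfaces any ttft/tpot/itl/e2el/throughput fields generically rather than assuming a schema.

# Hypothetical post-processing helper: print latency/throughput fields from the
# saved benchmark results without assuming benchmark_serving.py's exact schema.
import glob
import json

for path in sorted(glob.glob("3500_1500_*.json")):
    with open(path) as f:
        result = json.load(f)
    picked = {k: v for k, v in result.items()
              if any(tag in k.lower() for tag in ("ttft", "tpot", "itl", "e2el", "throughput"))}
    print(path)
    print(json.dumps(picked, indent=2))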

DS FP8 performance comparisons and uplift:
[performance chart attached in the PR]

DS FP4 performance comparisons and uplift:
[performance chart attached in the PR]

ChuanLi1101 (Collaborator) left a comment:

Overall LGTM, approved for benchmark testing.

# Since Triton FP8 Blockscale GEMM is mostly slower than AITER GEMM, we turn off Triton FP8 GEMM
# from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle as gemm_a8w8_blockscale_bpreshuffle_triton
except:
    gemm_afp4wfp4_preshuffle = None
Collaborator:

Suggestion: Use specific exceptions and add logging:

if use_triton_gemm():
    try:
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
    except ImportError as e:
        logger.warning(f"Triton FP4 GEMM not available: {e}")
        gemm_afp4wfp4_preshuffle = None

try:
    from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
    from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
except:
Collaborator:

if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
    except ImportError as e:
        logger.debug(f"Triton fused GEMM split_cat not available: {e}")
        fused_gemm_afp4wfp4_preshuffle_split_cat = None
        fused_gemm_a8w8_blockscale_preshuffle_split_cat = None

    from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle
    from aiter.ops.triton.gemm_a16w8_blockscale import gemm_a16w8_blockscale_preshuffle
except:
    gemm_afp4wfp4_preshuffle = None
Collaborator:

Add logger?
logger.warning(f"Triton GEMM kernels not available: {e}. Ensure AITER is up-to-date.")

shuffle=(m >= 32),
)

if m >= 32:
Collaborator:

Use module constant?

In both files, import or define:

from atom.models.deepseek_v2 import MXFP4_QUANT_BLOCK_SIZE

Then use:

if m >= MXFP4_QUANT_BLOCK_SIZE:
    x_scale = x_scale.view(torch.uint8).view(x_scale.shape[0] // MXFP4_QUANT_BLOCK_SIZE, -1)

# return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)


def mxfp4_to_f32(x, is_threed):
Collaborator:

This duplicates the one in aiter.
