
Conversation

@wxsIcey
Collaborator

@wxsIcey wxsIcey commented Nov 13, 2025

What this PR does / why we need it?

Adopts Inductor fusion and defines a quantization fusion pass.
Refer to vllm-project/vllm#23612.
Depends on vllm-project/vllm#28623.

Does this PR introduce any user-facing change?

Yes, this PR adds a new additional_config option.
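
The exact key added by this PR is not shown in the description; purely as a hypothetical illustration, enabling the fusion through additional_config might look like:

from vllm import LLM

# "enable_quantization_fusion" is a hypothetical key used only to
# illustrate the additional_config mechanism, not the actual name
# introduced by this PR.
llm = LLM(model="vllm-ascend/Qwen3-8B-W8A8",
          quantization="ascend",
          additional_config={"enable_quantization_fusion": True})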

How was this patch tested?

from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The president of the United States is Mr.",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100,
                                     temperature=0.6,
                                     top_k=40,
                                     top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()

Signed-off-by: Icey <[email protected]>
Signed-off-by: wxsIcey <[email protected]>
@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@github-actions

This pull request has conflicts; please resolve them before we can evaluate it.

Signed-off-by: wxsIcey <[email protected]>
@wxsIcey wxsIcey changed the title [wip] Adopt inductor fusion and define quantization fusion pass Adopt inductor fusion and define quantization fusion pass Nov 13, 2025
@wxsIcey wxsIcey marked this pull request as ready for review November 13, 2025 12:58
@wxsIcey
Collaborator Author

wxsIcey commented Nov 13, 2025

Currently, operator fusion is achieved through pattern matching with Inductor. Using aot-autograd could be future work, but we found that it causes accuracy issues. @whx-sjtu Would you be willing to review this?

Collaborator

@whx-sjtu whx-sjtu left a comment


Nice work. We finally managed to utilize Inductor's pattern_matcher to fuse our add_rms_norm_quant kernel into the FX graph. The whole idea looks good to me, with some questions about details in the review below.
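
For context, a minimal self-contained analogue of what Inductor's pattern matcher does (the ops below are CPU stand-ins, not the PR's npu kernels; assumes a recent PyTorch where PatternMatcherPass accepts pass_name):

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

pattern_match_pass = PatternMatcherPass(pass_name="example_fusion")

def search(x: torch.Tensor, y: torch.Tensor):
    # The unfused subgraph to find in the FX graph.
    return torch.relu(x + y)

def replace(x: torch.Tensor, y: torch.Tensor):
    # What each match is rewritten into (a stand-in "fused" op).
    return torch.clamp(x + y, min=0.0)

# Tracing the search function with example inputs builds the pattern;
# pattern_match_pass.apply(fx_graph) then rewrites every match.
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]
register_replacement(search, replace, example_inputs, fwd_only,
                     pattern_match_pass)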

return shape_list


class AscendAdaptor(CompilerInterface):
Collaborator


The name AscendAdaptor is too vague; I suggest a more specific one like AscendCompiler.

Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                        rms_norm_weight, 1e-6)
Collaborator


Instead of being fixed to 1e-6, eps should be defined as a static variable of AddRMSNormQuantPattern, with different eps values corresponding to different pattern objects. Some models may use a different eps, such as 1e-5.
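
A minimal sketch of this suggestion (class shape and names are illustrative; assumes torch_npu provides torch.ops.npu.npu_add_rms_norm as above):

import torch

class AddRMSNormQuantPattern:

    def __init__(self, eps: float):
        self.eps = eps  # e.g. 1e-6 or 1e-5, depending on the model config

    def pattern(self, rms_norm_input, residual, rms_norm_weight):
        # eps comes from the instance instead of a hard-coded literal.
        return torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                              rms_norm_weight, self.eps)

# One pattern object is registered per eps value that models actually use.
patterns = [AddRMSNormQuantPattern(eps) for eps in (1e-6, 1e-5)]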


def __init__(self, vllm_config):
    super().__init__(vllm_config)
    self.patterns: PatternMatcherPass = PatternMatcherPass(
Collaborator


The name self.patterns is a bit confusing here. It should be named something like self.pattern_match_pass.

arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes
# We found that the kernel npu_add_rms_norm_quant accepts varying data
# formats for different dtypes; therefore, we only provide the solution
# for bfloat16 here.
return dtype in (torch.bfloat16, )
Collaborator


I don't quite understand this part. Does the data format also influence pattern matching? Maybe we can define patterns separately for bf16 and fp16 to support them both?
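
A hedged sketch of that idea (names are illustrative, not the PR's code): Inductor specializes a traced pattern on the dtype of its example inputs, so registering the pattern once per dtype would cover both bf16 and fp16:

import torch

def make_example_inputs(dtype: torch.dtype):
    # Shapes only drive the tracing; the dtype distinguishes the patterns.
    return [torch.empty(4, 64, dtype=dtype),  # rms_norm_input
            torch.empty(4, 64, dtype=dtype),  # residual
            torch.empty(64, dtype=dtype)]     # rms_norm_weight

for dtype in (torch.bfloat16, torch.float16):
    example_inputs = make_example_inputs(dtype)
    # register_replacement(pattern, replacement, example_inputs, fwd_only,
    #                      pattern_match_pass)  # as in the analogue above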

Collaborator

@whx-sjtu whx-sjtu left a comment


I have another question here. With the current proposal, can we reuse the ready-made fusion passes defined in vLLM, like the SequenceParallel fusion pass? I'm not very familiar with the current fusion pass stack in vLLM, so I'm confirming it here. Reusability is what we expect.

@whx-sjtu
Collaborator

This feature is very important for vllm-ascend. I also hope @jgong5 can take some time to review this PR. Thanks.

@wxsIcey
Collaborator Author

wxsIcey commented Nov 13, 2025

I have another question here. With the current proposal, can we reuse the ready-made fusion passes defined in vLLM, like the SequenceParallel fusion pass? I'm not very familiar with the current fusion pass stack in vLLM, so I'm confirming it here. Reusability is what we expect.

Thank you for your reply. The current PR aims to define our own compiler backend to implement custom fusion. Reusing vLLM's fusion passes is my next goal; I will submit an RFC once the solution is finalized.

@wxsIcey wxsIcey requested a review from jgong5 November 13, 2025 13:46
