32 changes: 32 additions & 0 deletions vllm_ascend/ascend_config.py
@@ -41,6 +41,11 @@ def __init__(self, vllm_config):
{})
self.torchair_graph_config = TorchairGraphConfig(torchair_graph_config)

ascend_compilation_config = additional_config.get(
"ascend_compilation_config", {})
self.ascend_compilation_config = AscendCompilationConfig(
**ascend_compilation_config)

ascend_scheduler_config = additional_config.get(
"ascend_scheduler_config", {})
self.ascend_scheduler_config = AscendSchedulerConfig(
@@ -128,6 +133,22 @@ def __init__(self, vllm_config):
"Only support P node tp size lagger then D node tp size")


class AscendCompilationConfig:
"""
Configuration Object for ascend_compilation_config from additional_config
"""

def __init__(self,
enable_graph_fusion: bool = True,
fx_graph_eager: bool = False,
enable_quantization_fusion: bool = True,
**kwargs):
self.enable_graph_fusion = enable_graph_fusion
self.fx_graph_eager = fx_graph_eager
self.enable_quantization_fusion = enable_quantization_fusion
# Add more compilation related configs here as needed


class TorchairGraphConfig:
"""
Configuration Object for torchair_graph_config from additional_config
@@ -279,6 +300,17 @@ def check_ascend_config(vllm_config, enforce_eager):
"it has been disabled automatically.")
# aclgraph case
else:
# This graph fusion actually works in eager mode as well.
if ascend_config.ascend_compilation_config.enable_graph_fusion:
logger.info(
"graph fusion enabled! Automatic kernel fusion is expected."
)

if ascend_config.ascend_compilation_config.enable_quantization_fusion:
logger.info(
"Quantization fusion enabled! op fusion on quantization are expected. "
)

if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "qwen" not in model_type:
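For context, a hedged usage sketch (not part of this diff) of how the new options could be supplied at launch time. The key names come from AscendCompilationConfig above; passing a dict via additional_config follows the existing vllm-ascend convention (as with torchair_graph_config), and the model name is only a placeholder.

from vllm import LLM

# Hypothetical launch-time configuration; the defaults are spelled out explicitly.
# Setting fx_graph_eager=True would run the FX-fused graph eagerly instead of
# capturing it as an ACL graph (see the acl_graph.py change below).
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={
        "ascend_compilation_config": {
            "enable_graph_fusion": True,
            "enable_quantization_fusion": True,
            "fx_graph_eager": False,
        }
    },
)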
7 changes: 6 additions & 1 deletion vllm_ascend/compilation/acl_graph.py
@@ -65,6 +65,10 @@ def __init__(self,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.ascend_compilation_config: dict = vllm_config.additional_config.get(
"ascend_compilation_config", {})
self.fx_graph_eager = self.ascend_compilation_config.get(
"fx_graph_eager", False)
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
@@ -103,7 +107,8 @@ def __call__(self, *args, **kwargs):
aclgraph_runtime_mode = forward_context.cudagraph_runtime_mode

if aclgraph_runtime_mode == CUDAGraphMode.NONE or \
aclgraph_runtime_mode != self.runtime_mode:
aclgraph_runtime_mode != self.runtime_mode or \
self.fx_graph_eager:
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
# running without aclgraphs.
# We do not trigger capture/replay if the runtime mode is not
72 changes: 72 additions & 0 deletions vllm_ascend/compilation/compiler_interface.py
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Optional

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.utils import output_node
from vllm.compilation.compiler_interface import CompilerInterface


def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]:
"""
Extract the dtypes of the tensors in the args list.
"""
dtype_list = []
for value in args:
if isinstance(value, torch.Tensor):
dtype_list.append(value.dtype)
return dtype_list


def get_shapes_from_args(args: list[Any]) -> list[torch.Size]:
"""
Extract the shapes of the tensors in the args list.
"""
shape_list = []
for value in args:
if isinstance(value, torch.Tensor):
shape_list.append(value.shape)
return shape_list


class AscendAdaptor(CompilerInterface):
Reviewer comment: The name AscendAdaptor is too vague; I suggest a more specific one like AscendCompiler.

name = "AscendAdaptor"

def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:

current_pass_manager = compiler_config["graph_fusion_manager"]
arg_dtypes = get_dtype_from_args(example_inputs)
arg_shapes = get_shapes_from_args(example_inputs)
kwargs = {
"runtime_shape": runtime_shape,
"arg_shapes": arg_shapes,
"arg_dtypes": arg_dtypes
}
graph = current_pass_manager(graph, **kwargs)
return graph, None
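As a quick illustration (a standalone snippet, not part of this diff), the two helpers above skip non-tensor arguments and collect per-tensor metadata, which AscendAdaptor.compile then forwards to the pass manager as arg_dtypes/arg_shapes:

import torch

# Hypothetical example inputs: two tensors and a plain int (the int is skipped).
args = [torch.randn(2, 4, dtype=torch.bfloat16), 8, torch.ones(4, dtype=torch.int8)]
print(get_dtype_from_args(args))   # [torch.bfloat16, torch.int8]
print(get_shapes_from_args(args))  # [torch.Size([2, 4]), torch.Size([4])]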
54 changes: 54 additions & 0 deletions vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -0,0 +1,54 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from torch import fx as fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig


class GraphFusionPassManager:
"""
A pass manager for graph rewriting passes.
It handles the configuration and execution of passes.
The counterpart in vllm is PostGradPassManager. Since torch_npu does not
support Inductor and Triton for now, we rewrite the FX graph directly
rather than relying on the Inductor pass manager.
"""

def __init__(self):
self.passes: list[VllmInductorPass] = []

def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph:
for pass_ in self.passes:
if pass_.is_applicable(**kwargs):
pass_(graph)
return graph

def add(self, pass_: VllmInductorPass):
assert isinstance(pass_, VllmInductorPass)
self.passes.append(pass_)

def configure(self, config: VllmConfig):
# By default, we enable the graph rewriter and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
"ascend_compilation_config", {})
if self.ascend_compilation_config.get("enable_quantization_fusion",
True):
from .quant_fusion_pass import AscendQuantFusionPass
self.passes.append(AscendQuantFusionPass(config))
# Add more passes here as needed
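A hedged sketch of the intended wiring (names are taken from this PR; vllm_config, fx_graph, arg_shapes and arg_dtypes are assumed to already exist in the caller, as they do in AscendAdaptor.compile above):

# Configure once, then rewrite each FX graph that reaches the compiler backend.
manager = GraphFusionPassManager()
manager.configure(vllm_config)  # reads additional_config["ascend_compilation_config"]
fused_graph = manager(
    fx_graph,
    runtime_shape=None,   # None corresponds to the general (dynamic-shape) graph
    arg_shapes=arg_shapes,
    arg_dtypes=arg_dtypes,
)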
104 changes: 104 additions & 0 deletions vllm_ascend/compilation/quant_fusion_pass.py
@@ -0,0 +1,104 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass


class AddRMSNormQuantPattern:

def __init__(self, vllm_config):
self.vllm_config = vllm_config

def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
scale = torch.tensor([1.0], device="npu")
offset = torch.tensor([0.0], device="npu")
return [rms_norm_input, residual, rms_norm_weight, scale, offset]

def register(self, pm_pass: PatternMatcherPass):

def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, 1e-6)
Reviewer comment: Instead of being fixed at 1e-6, eps should be defined as a static variable of AddRMSNormQuantPattern, with different eps values corresponding to different pattern objects. Some models might use a different eps, such as 1e-5.

out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_quantize(
out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1

def replacement(rms_norm_input, residual, rms_norm_weight, scale,
offset):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
1. / scale,  # npu_add_rms_norm_quant expects the inverse of scale, unlike the npu_quantize kernel.
offset,
epsilon=1e-6)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1

pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)


class AscendQuantFusionPass(VllmInductorPass):
"""
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
"""

def __init__(self, vllm_config):
super().__init__(vllm_config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
Reviewer comment: The name self.patterns is a bit confusing here; it should be named something like self.pattern_match_pass.

pass_name="rmsnorm_quant_fusion_pass")
AddRMSNormQuantPattern(vllm_config).register(self.patterns)

def __call__(self, graph: torch.fx.Graph):
self.begin()
matched_count = self.patterns.apply(graph)
self.end_and_log()

def is_applicable(self, **kwargs):
"""
Check if the pass is applicable for the current configuration.
"""
arg_dtypes = kwargs.get("arg_dtypes", None)
if arg_dtypes is None:
return False
# We assume the first tensor's dtype reflects the model's data type; revisit this
# once a better solution is available.
dtype = arg_dtypes[0] if isinstance(
arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes
# The npu_add_rms_norm_quant kernel accepts different data formats for different
# dtypes, so we only enable this fusion for bfloat16 here.
return dtype in (torch.bfloat16, )
Reviewer comment: I don't quite understand here. Does the data format also influence pattern matching? Maybe we can define patterns separately for bf16 and fp16 to support them both?
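To make the eps suggestion above (and the per-dtype question) concrete, here is a hedged sketch, not part of this PR, of how AddRMSNormQuantPattern could be parameterized: one pattern object is registered per supported value, and pattern()/replacement() read self.eps through a closure instead of the hard-coded 1e-6.

# Hedged sketch: eps becomes a constructor argument; the listed values are examples.
class AddRMSNormQuantPattern:
    SUPPORTED_EPS = (1e-6, 1e-5)

    def __init__(self, vllm_config, eps: float = 1e-6):
        self.vllm_config = vllm_config
        self.eps = eps

    # get_inputs() stays as above; pattern() and replacement() use self.eps
    # (captured via closure) wherever 1e-6 is currently hard-coded.

# Registration in AscendQuantFusionPass.__init__ would then become, e.g.:
#     for eps in AddRMSNormQuantPattern.SUPPORTED_EPS:
#         AddRMSNormQuantPattern(vllm_config, eps).register(self.patterns)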
