-
Notifications
You must be signed in to change notification settings - Fork 562
Adopt inductor fusion and define quantization fusion pass #4168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f96060b
e0eba49
bba9416
a443661
34f1805
38d87cc
92c395e
0cfdf1d
d15fed5
3960ea3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| from typing import Any, Callable, Optional | ||
|
|
||
| import torch | ||
| import torch.fx as fx | ||
| import torch.utils._pytree as pytree | ||
| from torch._dynamo.backends.common import aot_autograd | ||
| from torch._inductor.utils import output_node | ||
| from vllm.compilation.compiler_interface import CompilerInterface | ||
|
|
||
|
|
||
| def get_dtype_from_args(args: list[Any]) -> list[torch.dtype]: | ||
| """ | ||
| Extract the dtype from the kwargs dictionary. | ||
| """ | ||
| dtype_list = [] | ||
| for value in args: | ||
| if isinstance(value, torch.Tensor): | ||
| dtype_list.append(value.dtype) | ||
| return dtype_list | ||
|
|
||
|
|
||
| def get_shapes_from_args(args: list[Any]) -> list[torch.Size]: | ||
| """ | ||
| Extract the shapes from the kwargs dictionary. | ||
| """ | ||
| shape_list = [] | ||
| for value in args: | ||
| if isinstance(value, torch.Tensor): | ||
| shape_list.append(value.shape) | ||
| return shape_list | ||
|
|
||
|
|
||
| class AscendAdaptor(CompilerInterface): | ||
| name = "AscendAdaptor" | ||
|
|
||
| def compile( | ||
| self, | ||
| graph: fx.GraphModule, | ||
| example_inputs: list[Any], | ||
| compiler_config: dict[str, Any], | ||
| runtime_shape: Optional[int] = None, | ||
| key: Optional[str] = None, | ||
| ) -> tuple[Optional[Callable], Optional[Any]]: | ||
|
|
||
| current_pass_manager = compiler_config["graph_fusion_manager"] | ||
| arg_dtypes = get_dtype_from_args(example_inputs) | ||
| arg_shapes = get_shapes_from_args(example_inputs) | ||
| kwargs = { | ||
| "runtime_shape": runtime_shape, | ||
| "arg_shapes": arg_shapes, | ||
| "arg_dtypes": arg_dtypes | ||
| } | ||
| graph = current_pass_manager(graph, **kwargs) | ||
| return graph, None | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| from torch import fx as fx | ||
| from vllm.compilation.vllm_inductor_pass import VllmInductorPass | ||
| from vllm.config import VllmConfig | ||
|
|
||
|
|
||
| class GraphFusionPassManager: | ||
| """ | ||
| A pass manager for graph rewriting passes. | ||
| It handles the configuration and execution of passes. | ||
| The counterpart in vllm is PostGradPassManager. Since torch_npu does not | ||
| support inductor and triton for now, we choose to adopt the graph rewriter on | ||
| fx graph rather than the inductor pass manager. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self.passes: list[VllmInductorPass] = [] | ||
|
|
||
| def __call__(self, graph: fx.Graph, **kwargs) -> fx.Graph: | ||
| for pass_ in self.passes: | ||
| if pass_.is_applicable(**kwargs): | ||
| pass_(graph) | ||
| return graph | ||
|
|
||
| def add(self, pass_: VllmInductorPass): | ||
| assert isinstance(pass_, VllmInductorPass) | ||
| self.passes.append(pass_) | ||
|
|
||
| def configure(self, config: VllmConfig): | ||
| # By default, we enable the graph rewriter and quantization fusion pass. | ||
| self.ascend_compilation_config: dict = config.additional_config.get( | ||
| "ascend_compilation_config", {}) | ||
| if self.ascend_compilation_config.get("enable_quantization_fusion", | ||
| True): | ||
| from .quant_fusion_pass import AscendQuantFusionPass | ||
| self.passes.append(AscendQuantFusionPass(config)) | ||
| # Add more passes here as needed |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # | ||
| # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. | ||
| # This file is a part of the vllm-ascend project. | ||
| # | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
| import torch | ||
| import torch._inductor.pattern_matcher as pm | ||
| from torch._inductor.pattern_matcher import PatternMatcherPass | ||
| from vllm.compilation.vllm_inductor_pass import VllmInductorPass | ||
|
|
||
|
|
||
| class AddRMSNormQuantPattern: | ||
|
|
||
| def __init__(self, vllm_config): | ||
| self.vllm_config = vllm_config | ||
|
|
||
| def get_inputs(self): | ||
| """ | ||
| Generate example inputs for the AddRMSNormQuant fusion pattern. | ||
| """ | ||
| rms_norm_input = torch.randn(2, 4, device="npu") | ||
| residual = torch.randn(2, 4, device="npu") | ||
| rms_norm_weight = torch.randn(4, device="npu") | ||
| scale = torch.tensor([1.0], device="npu") | ||
| offset = torch.tensor([0.0], device="npu") | ||
| return [rms_norm_input, residual, rms_norm_weight, scale, offset] | ||
|
|
||
| def register(self, pm_pass: PatternMatcherPass): | ||
|
|
||
| def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset): | ||
| """ | ||
| Pattern for AddRMSNormQuant fusion. | ||
| """ | ||
| output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, | ||
| rms_norm_weight, 1e-6) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of fixed to |
||
| out0 = output[0] | ||
| out1 = output[2] | ||
| quantized_output = torch.ops.npu.npu_quantize( | ||
| out0, scale, offset, torch.qint8, -1, False) | ||
| return quantized_output, out1 | ||
|
|
||
| def replacement(rms_norm_input, residual, rms_norm_weight, scale, | ||
| offset): | ||
| """ | ||
| Replacement for the AddRMSNormQuant fusion. | ||
| """ | ||
| output = torch.ops.npu.npu_add_rms_norm_quant( | ||
| rms_norm_input, | ||
| residual, | ||
| rms_norm_weight, | ||
| 1. / | ||
| scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. | ||
| offset, | ||
| epsilon=1e-6) | ||
| quantized_output = output[0] | ||
| out1 = output[2] | ||
| return quantized_output, out1 | ||
|
|
||
| pm.register_replacement(pattern, replacement, self.get_inputs(), | ||
| pm.fwd_only, pm_pass) | ||
|
|
||
|
|
||
| class AscendQuantFusionPass(VllmInductorPass): | ||
| """ | ||
| A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. | ||
| """ | ||
|
|
||
| def __init__(self, vllm_config): | ||
| super().__init__(vllm_config) | ||
| self.patterns: PatternMatcherPass = PatternMatcherPass( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The name of |
||
| pass_name="rmsnorm_quant_fusion_pass") | ||
| AddRMSNormQuantPattern(vllm_config).register(self.patterns) | ||
|
|
||
| def __call__(self, graph: torch.fx.Graph): | ||
| self.begin() | ||
| matched_count = self.patterns.apply(graph) | ||
| self.end_and_log() | ||
|
|
||
| def is_applicable(self, **kwargs): | ||
| """ | ||
| Check if the pass is applicable for the current configuration. | ||
| """ | ||
| arg_dtypes = kwargs.get("arg_dtypes", None) | ||
| if arg_dtypes is None: | ||
| return False | ||
| # We assume the first tensor's dtype is the data type of this model, update this solution when there is | ||
| # better solution. | ||
| dtype = arg_dtypes[0] if isinstance( | ||
| arg_dtypes, list) and len(arg_dtypes) > 0 else arg_dtypes | ||
| # We found that the kernel npu_add_rms_norm_quant accept varying data format for different dtypes, therefore, we only | ||
| # provide the solution on bfloat16 here. | ||
| return dtype in (torch.bfloat16, ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't quiet understand here. Does the format of data also influence pattern matching? Maybe we can define patterns separately for bf16 and fp16 to support them both? |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name
AscendAdaptoris too vague; I suggest a more specific one likeAscendCompiler.