diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index ff0462e244..ee424c4bd3 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -39,11 +39,15 @@
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 import torch_npu
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from vllm.distributed import split_tensor_along_last_dim
+from vllm.distributed import (split_tensor_along_last_dim,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
@@ -375,12 +379,78 @@ def apply_impl(
     def matmul_and_reduce(self, input_parallel: torch.Tensor,
                           bias_: Optional[Parameter]) -> torch.Tensor:
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
-        from vllm_ascend.ops.register_custom_ops import \
-            _maybe_pad_and_reduce_impl
-        output = _maybe_pad_and_reduce_impl(output_parallel)
+        try:
+            forward_context = get_forward_context()
+            sp_enabled = forward_context.sp_enabled
+        except AssertionError:
+            sp_enabled = False
+
+        x = input_parallel
+
+        if not sp_enabled:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            return tensor_model_parallel_all_reduce(output_parallel)
+
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = F.pad(x, (0, 0, 0, pad_size))
+
+        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+
+        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
+                                                   quant_per_tensor)
+
+        world_size = self.layer.tp_size
+        comm_mode = "aiv"
+        hcom_name = get_tp_group().device_group._get_backend(
+            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+        # For unquant
+        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
+                      ) and torch.version.cann.startswith("8.3"):
+            output = torch_npu.npu_mm_reduce_scatter_base(
+                x,
+                self.layer.weight.t(),
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                comm_mode=comm_mode)
+        # For w8a8 quant
+        elif isinstance(self.layer.quant_method.quant_method,
+                        AscendW8A8LinearMethod
+                        ) and torch.version.cann.startswith("8.3"):
+            if x.dtype != torch.int8:
+                x_quant = quant_per_tensor(
+                    x, self.layer.aclnn_input_scale_reciprocal,
+                    self.layer.aclnn_input_offset)
+            else:
+                x_quant = x
+            quant_bias = self.layer.quant_bias
+            deq_scale = self.layer.deq_scale
+            output_dtype = torch.bfloat16
+            output_parallel = torch_npu.npu_mm_reduce_scatter_base(
+                x_quant,
+                self.layer.weight,
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                x2_scale=deq_scale,
+                output_dtype=output_dtype,
+                comm_mode=comm_mode)
+            output = torch.add(
+                output_parallel,
+                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
+        else:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
+
         return output
 
     def update_attrs(self):
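
Reviewer note (not part of the patch): when sequence parallelism is enabled, both fused `npu_mm_reduce_scatter_base` branches above compute the same result as an unfused matmul followed by a reduce-scatter over the token dimension; the diff's final `else` branch is exactly that unfused form via `tensor_model_parallel_reduce_scatter`. The sketch below expresses the semantics with stock `torch.distributed` collectives. The `group` handle and the `[in_per_rank, out]` weight layout are assumptions for illustration only; the layer actually stores the weight as `[out, in_per_rank]`, which is why the unquantized path transposes `self.layer.weight`.

# Reference sketch (illustration only): the unfused equivalent of the
# fused matmul + reduce-scatter path. `group` and the weight layout are
# assumptions for this sketch, not the patch's API.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def matmul_then_reduce_scatter(x: torch.Tensor, weight: torch.Tensor,
                               pad_size: int,
                               group: dist.ProcessGroup) -> torch.Tensor:
    # Pad the token dimension so it splits evenly across TP ranks,
    # mirroring F.pad(x, (0, 0, 0, pad_size)) in the diff.
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))
    # Each rank holds a shard of the input features, so the local matmul
    # yields a full-shape but partial-sum output.
    partial = torch.matmul(x, weight)  # weight: [in_per_rank, out]
    # Reduce-scatter sums the partials across ranks and leaves each rank
    # with its 1/world_size slice of the token dimension -- the result
    # the fused NPU kernel returns in a single call.
    world_size = dist.get_world_size(group)
    output = torch.empty((partial.shape[0] // world_size, partial.shape[1]),
                         dtype=partial.dtype,
                         device=partial.device)
    dist.reduce_scatter_tensor(output, partial, group=group)
    return output

The fused kernel is intended to overlap the matmul with HCCL communication instead of materializing the full partial-sum tensor first, and the w8a8 branch additionally folds per-tensor dequantization into the same call via `x2_scale=deq_scale`, with the `quant_bias * deq_scale` correction applied afterwards.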