84 changes: 77 additions & 7 deletions vllm_ascend/ops/linear_op.py
@@ -39,11 +39,15 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import split_tensor_along_last_dim
from vllm.distributed import (split_tensor_along_last_dim,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
@@ -375,12 +379,78 @@ def apply_impl(
def matmul_and_reduce(self, input_parallel: torch.Tensor,
                      bias_: Optional[Parameter]) -> torch.Tensor:
    assert self.quant_method is not None
    output_parallel = self.quant_method.apply(self.layer,
                                              input_parallel,
                                              bias=bias_)
    from vllm_ascend.ops.register_custom_ops import \
        _maybe_pad_and_reduce_impl
    output = _maybe_pad_and_reduce_impl(output_parallel)
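    # Probe the forward context for the sequence-parallel (SP) flag;
    # get_forward_context() raises AssertionError outside a forward
    # pass, in which case SP is treated as disabled.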
    try:
        forward_context = get_forward_context()
        sp_enabled = forward_context.sp_enabled
    except AssertionError:
        sp_enabled = False

    x = input_parallel

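    # SP disabled: standard row-parallel flow, a partial matmul on each
    # rank followed by an all-reduce across the TP group.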
    if not sp_enabled:
        output_parallel = self.layer.quant_method.apply(self.layer,
                                                        x,
                                                        bias=bias_)
        return tensor_model_parallel_all_reduce(output_parallel)

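    # SP enabled: the output is reduce-scattered along dim 0, so pad the
    # token dimension until it divides evenly across TP ranks.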
    pad_size = forward_context.pad_size
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))

    from vllm.model_executor.layers.linear import UnquantizedLinearMethod

    from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
                                               quant_per_tensor)

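    # The fused NPU kernels below are addressed by the HCCL communicator
    # name of the TP group rather than a torch.distributed handle.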
    world_size = self.layer.tp_size
    comm_mode = "aiv"
    hcom_name = get_tp_group().device_group._get_backend(
        torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
    # Unquantized path: on CANN 8.3+, fuse the matmul and the
    # reduce-scatter into a single kernel launch.
    if isinstance(self.layer.quant_method, UnquantizedLinearMethod
                  ) and torch.version.cann.startswith("8.3"):
        output = torch_npu.npu_mm_reduce_scatter_base(
            x,
            self.layer.weight.t(),
            hcom_name,
            world_size,
            reduce_op="sum",
            bias=None,
            comm_turn=0,
            comm_mode=comm_mode)
    # W8A8 path: on CANN 8.3+, use the fused int8 matmul + reduce-scatter
    # kernel with on-the-fly dequantization.
    elif isinstance(self.layer.quant_method.quant_method,
                    AscendW8A8LinearMethod
                    ) and torch.version.cann.startswith("8.3"):
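        # Quantize activations per-tensor to int8 unless the caller has
        # already done so.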
        if x.dtype != torch.int8:
            x_quant = quant_per_tensor(
                x, self.layer.aclnn_input_scale_reciprocal,
                self.layer.aclnn_input_offset)
        else:
            x_quant = x
        quant_bias = self.layer.quant_bias
        deq_scale = self.layer.deq_scale
        output_dtype = torch.bfloat16
        output_parallel = torch_npu.npu_mm_reduce_scatter_base(
            x_quant,
            self.layer.weight,
            hcom_name,
            world_size,
            reduce_op="sum",
            bias=None,
            comm_turn=0,
            x2_scale=deq_scale,
            output_dtype=output_dtype,
            comm_mode=comm_mode)
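        # The kernel already applied deq_scale, so the quantization bias
        # is folded in under the same scale:
        # (acc + bias) * deq == acc * deq + bias * deq.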
        output = torch.add(
            output_parallel,
            torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
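    # Fallback for other quant methods or older CANN versions: unfused
    # matmul followed by an explicit reduce-scatter along dim 0.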
    else:
        output_parallel = self.layer.quant_method.apply(self.layer,
                                                        x,
                                                        bias=bias_)
        output = tensor_model_parallel_reduce_scatter(output_parallel, 0)

    return output

def update_attrs(self):
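Note (not part of the PR): npu_mm_reduce_scatter_base combines a row-parallel matmul with a reduce-scatter in one kernel launch, which lets compute overlap communication. Below is a minimal reference of the same pattern in plain torch.distributed; the function name and the [num_tokens, in_dim] x [in_dim, out_dim] shard layout are illustrative assumptions.

import torch
import torch.distributed as dist


def mm_reduce_scatter_reference(x: torch.Tensor, weight: torch.Tensor,
                                group: dist.ProcessGroup) -> torch.Tensor:
    # Each rank multiplies its activation shard by its weight shard,
    # producing a partial sum of the full [num_tokens, out_dim] output.
    partial = torch.matmul(x, weight)
    world_size = dist.get_world_size(group)
    assert partial.shape[0] % world_size == 0, "token dim must be padded"
    out = torch.empty(partial.shape[0] // world_size, partial.shape[1],
                      dtype=partial.dtype, device=partial.device)
    # Sum the partials across ranks; each rank keeps its slice of dim 0.
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM,
                               group=group)
    return out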
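The quant-bias handling in the W8A8 branch relies on distributing deq_scale over a sum: the fused kernel returns acc * deq_scale, and adding quant_bias * deq_scale afterwards equals rescaling (acc + quant_bias) directly. A self-contained check of that identity, with assumed shapes:

import torch

# int32 accumulator of an int8 matmul, per-channel bias and scale.
acc = torch.randint(-128, 127, (4, 8), dtype=torch.int32)
quant_bias = torch.randint(-16, 16, (8,), dtype=torch.int32)
deq_scale = torch.rand(8) * 0.01

fused = acc.float() * deq_scale + quant_bias.float() * deq_scale
direct = (acc + quant_bias).float() * deq_scale
assert torch.allclose(fused, direct)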