
Commit f4cdbbc

Matmul_RS_Fusion in SP

Changes to be committed:
  modified: vllm_ascend/ops/linear_op.py
  modified: vllm_ascend/ops/register_custom_ops.py

Co-authored-by: ZhaoJiangJiang <[email protected]>
Co-authored-by: rjg-lyh <[email protected]>
Co-authored-by: raintBN-91 <[email protected]>
Signed-off-by: ZYang6263 <[email protected]>

1 parent 292e213 commit f4cdbbc

1 file changed: +81 −7

vllm_ascend/ops/linear_op.py (81 additions, 7 deletions)
```diff
@@ -39,11 +39,15 @@
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 import torch_npu
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from vllm.distributed import split_tensor_along_last_dim
+from vllm.distributed import (split_tensor_along_last_dim,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
@@ -375,12 +379,82 @@ def apply_impl(
     def matmul_and_reduce(self, input_parallel: torch.Tensor,
                           bias_: Optional[Parameter]) -> torch.Tensor:
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
-        from vllm_ascend.ops.register_custom_ops import \
-            _maybe_pad_and_reduce_impl
-        output = _maybe_pad_and_reduce_impl(output_parallel)
+        try:
+            forward_context = get_forward_context()
+            sp_enabled = forward_context.sp_enabled
+        except AssertionError:
+            sp_enabled = False
+
+        x = input_parallel
+
+        if not sp_enabled:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            return tensor_model_parallel_all_reduce(output_parallel)
+
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = F.pad(x, (0, 0, 0, pad_size))
+
+        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+
+        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
+                                                   quant_per_tensor)
+
+        # For unquant
+        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
+                      ) and torch.version.cann.startswith("8.3"):
+            hcom_name = get_tp_group().device_group._get_backend(
+                torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+            world_size = self.layer.tp_size
+            comm_mode = "aiv"
+            output = torch_npu.npu_mm_reduce_scatter_base(
+                x,
+                self.layer.weight.t(),
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                comm_mode=comm_mode)
+        # For w8a8 quant
+        elif isinstance(self.layer.quant_method.quant_method,
+                        AscendW8A8LinearMethod
+                        ) and torch.version.cann.startswith("8.3"):
+            if x.dtype != torch.int8:
+                x_quant = quant_per_tensor(
+                    x, self.layer.aclnn_input_scale_reciprocal,
+                    self.layer.aclnn_input_offset)
+            else:
+                x_quant = x
+            quant_bias = self.layer.quant_bias
+            hcom_name = get_tp_group().device_group._get_backend(
+                torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+            world_size = self.layer.tp_size
+            deq_scale = self.layer.deq_scale
+            output_dtype = torch.bfloat16
+            comm_mode = "aiv"
+            output_parallel = torch_npu.npu_mm_reduce_scatter_base(
+                x_quant,
+                self.layer.weight,
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                x2_scale=deq_scale,
+                output_dtype=output_dtype,
+                comm_mode=comm_mode)
+            output = torch.add(
+                output_parallel,
+                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
+        else:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
+
         return output
 
     def update_attrs(self):
```
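A note on the padding step above: `tensor_model_parallel_reduce_scatter` (and the fused kernel) split dim 0 across the TP group, so the token count must be divisible by the TP world size; `F.pad(x, (0, 0, 0, pad_size))` appends `pad_size` zero rows to make that true. Below is a minimal sketch of how such a pad size can be derived for a `[tokens, hidden]` activation; the helper name is illustrative, and in this commit `pad_size` arrives precomputed on the forward context:

```python
import torch
import torch.nn.functional as F

def pad_for_reduce_scatter(x: torch.Tensor, tp_size: int):
    """Pad dim 0 of a [tokens, hidden] tensor so it splits evenly by tp_size."""
    pad_size = (-x.shape[0]) % tp_size  # 0 when already divisible
    if pad_size > 0:
        # (0, 0, 0, pad_size): no padding on the last dim,
        # pad_size zero rows appended at the end of dim 0.
        x = F.pad(x, (0, 0, 0, pad_size))
    return x, pad_size
```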
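On the unquantized path, `torch_npu.npu_mm_reduce_scatter_base` fuses the matmul and the HCCL reduce-scatter into a single kernel, avoiding a round trip through memory for the full `[padded_tokens, out_features]` partial product; the `torch.version.cann.startswith("8.3")` check presumably gates the fused path to toolkits where the op supports these arguments. As a rough unfused reference for its semantics, written with stock `torch.distributed` (a sketch of what the fused call computes, not of the kernel itself; `weight_t` stands for the `[in_features, out_features]` weight, i.e. `self.layer.weight.t()` above):

```python
import torch
import torch.distributed as dist

def mm_reduce_scatter_reference(x: torch.Tensor, weight_t: torch.Tensor,
                                group: dist.ProcessGroup) -> torch.Tensor:
    partial = x @ weight_t  # per-rank partial result: [padded_tokens, out]
    world_size = dist.get_world_size(group)
    out = torch.empty(partial.shape[0] // world_size, partial.shape[1],
                      dtype=partial.dtype, device=partial.device)
    # Sum the partials across all TP ranks, then hand rank r the r-th
    # row chunk of the sum (reduce-scatter along dim 0, op=SUM).
    dist.reduce_scatter_tensor(out, partial, group=group)
    return out
```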
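On the w8a8 path, the activation is quantized per-tensor unless it is already `int8`, the int8 GEMM output is dequantized to `output_dtype` inside the kernel via `x2_scale=deq_scale`, and `quant_bias` is added back at the end. Adding it after the reduce-scatter applies it exactly once, rather than letting the sum across ranks apply it `tp_size` times. A sketch of that final correction, assuming `quant_bias` is a per-output-channel `int32` bias still in quantized units and `deq_scale` its matching dequantization scale (both assumptions; the actual layouts live in `vllm_ascend/quantization/w8a8.py`):

```python
import torch

def add_dequantized_bias(y: torch.Tensor, quant_bias: torch.Tensor,
                         deq_scale: torch.Tensor,
                         params_dtype: torch.dtype) -> torch.Tensor:
    # quant_bias was not folded into the int8 GEMM, so it is still in
    # quantized units; rescale it by deq_scale before adding it to the
    # already-dequantized reduce-scatter output.
    return y + (quant_bias * deq_scale).to(params_dtype)
```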
