
Commit 9e68503

Matmul_RS_Fusion in SP
Co-authored-by: ZhaoJiangJiang <[email protected]>
Co-authored-by: rjg-lyh <[email protected]>
Co-authored-by: raintBN-91 <[email protected]>
Signed-off-by: ZYang6263 <[email protected]>

Changes to be committed:
  modified: vllm_ascend/ops/linear_op.py
  modified: vllm_ascend/ops/register_custom_ops.py

Signed-off-by: ZYang6263 <[email protected]>
1 parent 292e213 commit 9e68503

File tree: 1 file changed (+77, -7 lines)


vllm_ascend/ops/linear_op.py

Lines changed: 77 additions & 7 deletions
@@ -39,11 +39,15 @@
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 import torch_npu
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from vllm.distributed import split_tensor_along_last_dim
+from vllm.distributed import (split_tensor_along_last_dim,
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 from vllm.distributed.parallel_state import get_tp_group
+from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
@@ -375,12 +379,78 @@ def apply_impl(
     def matmul_and_reduce(self, input_parallel: torch.Tensor,
                           bias_: Optional[Parameter]) -> torch.Tensor:
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
-        from vllm_ascend.ops.register_custom_ops import \
-            _maybe_pad_and_reduce_impl
-        output = _maybe_pad_and_reduce_impl(output_parallel)
+        try:
+            forward_context = get_forward_context()
+            sp_enabled = forward_context.sp_enabled
+        except AssertionError:
+            sp_enabled = False
+
+        x = input_parallel
+
+        if not sp_enabled:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            return tensor_model_parallel_all_reduce(output_parallel)
+
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = F.pad(x, (0, 0, 0, pad_size))
+
+        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+
+        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
+                                                   quant_per_tensor)
+
+        world_size = self.layer.tp_size
+        comm_mode = "aiv"
+        hcom_name = get_tp_group().device_group._get_backend(
+            torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+        # For unquant
+        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
+                      ) and torch.version.cann.startswith("8.3"):
+            output = torch_npu.npu_mm_reduce_scatter_base(
+                x,
+                self.layer.weight.t(),
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                comm_mode=comm_mode)
+        # For w8a8 quant
+        elif isinstance(self.layer.quant_method.quant_method,
+                        AscendW8A8LinearMethod
+                        ) and torch.version.cann.startswith("8.3"):
+            if x.dtype != torch.int8:
+                x_quant = quant_per_tensor(
+                    x, self.layer.aclnn_input_scale_reciprocal,
+                    self.layer.aclnn_input_offset)
+            else:
+                x_quant = x
+            quant_bias = self.layer.quant_bias
+            deq_scale = self.layer.deq_scale
+            output_dtype = torch.bfloat16
+            output_parallel = torch_npu.npu_mm_reduce_scatter_base(
+                x_quant,
+                self.layer.weight,
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                x2_scale=deq_scale,
+                output_dtype=output_dtype,
+                comm_mode=comm_mode)
+            output = torch.add(
+                output_parallel,
+                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
+        else:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
+
         return output
 
     def update_attrs(self):
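
Note: the fused torch_npu.npu_mm_reduce_scatter_base call added above performs the matmul and the reduce-scatter in one kernel. As a point of reference, a minimal unfused sketch of the same data flow with plain torch.distributed is shown below (assumptions, not part of the commit: the process group is initialized, the token count is already a multiple of the world size, and weight is this rank's [output_size, input_size] shard; the function name is illustrative).

# Unfused reference for what the fused op does in a single kernel: each rank
# multiplies the full (padded) activation by its weight shard, then the
# partial results are reduce-scattered along the token dimension (dim 0),
# matching the tensor_model_parallel_reduce_scatter(output_parallel, 0)
# fallback in the diff.
import torch
import torch.distributed as dist

def matmul_then_reduce_scatter(x: torch.Tensor,
                               weight: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    partial = torch.matmul(x, weight.t())  # [num_tokens, output_size]
    out = torch.empty(partial.shape[0] // world_size,
                      partial.shape[1],
                      dtype=partial.dtype,
                      device=partial.device)
    dist.reduce_scatter_tensor(out, partial, op=dist.ReduceOp.SUM)
    return out  # [num_tokens // world_size, output_size]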
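
The F.pad(x, (0, 0, 0, pad_size)) step pads the token dimension before the fused call; a reduce-scatter splits dim 0 evenly across ranks, so the padded token count has to divide by the TP world size. A standalone sketch of that padding (assumption: forward_context.pad_size is chosen to round the token count up to such a multiple):

# Illustrative only: pad_size rounds num_tokens up to a multiple of tp_size.
import torch
import torch.nn.functional as F

num_tokens, hidden_size, tp_size = 10, 4, 4
x = torch.randn(num_tokens, hidden_size)
pad_size = (-num_tokens) % tp_size  # 2 in this example
if pad_size > 0:
    # (0, 0, 0, pad_size): nothing added on the last dim, pad_size zero rows
    # appended to the token dim.
    x = F.pad(x, (0, 0, 0, pad_size))
assert x.shape[0] % tp_size == 0  # 12 rows, i.e. 3 per rank after reduce-scatter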
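
In the w8a8 branch the activation is quantized to int8 before the fused matmul, and quant_bias is folded back in afterwards. The exact semantics of quant_per_tensor, aclnn_input_scale_reciprocal, and deq_scale live in vllm_ascend.quantization.w8a8; the helpers below are simplified stand-ins that only illustrate the arithmetic shape of that pre/post processing.

# Simplified stand-ins, not the actual vllm_ascend kernels.
import torch

def fake_quant_per_tensor(x: torch.Tensor, scale_reciprocal: torch.Tensor,
                          offset: torch.Tensor) -> torch.Tensor:
    # Per-tensor int8 quantization sketch: scale, shift, clamp, cast.
    return torch.clamp(torch.round(x * scale_reciprocal + offset),
                       -128, 127).to(torch.int8)

def fold_bias(output_parallel: torch.Tensor, quant_bias: torch.Tensor,
              deq_scale: torch.Tensor,
              params_dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the diff's post-processing: quant_bias is dequantized with
    # deq_scale and added to the fused matmul output.
    return output_parallel + (quant_bias * deq_scale).to(params_dtype)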
