Skip to content

Commit 8a304b6

Browse files
committed
Matumul_RS_Fusion in SP
Co-authored-by: ZhaoJiangJiang <[email protected]> Co-authored-by: rjg-lyh <[email protected]> Co-authored-by: ZhaoJiangJiang <[email protected]> Signed-off-by: ZYang6263 <[email protected]> Changes to be committed: modified: vllm_ascend/ops/linear_op.py modified: vllm_ascend/ops/register_custom_ops.py
1 parent 292e213 commit 8a304b6

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

vllm_ascend/ops/linear_op.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,92 @@ def apply_impl(
375375
def matmul_and_reduce(self, input_parallel: torch.Tensor,
376376
bias_: Optional[Parameter]) -> torch.Tensor:
377377
assert self.quant_method is not None
378-
output_parallel = self.quant_method.apply(self.layer,
379-
input_parallel,
380-
bias=bias_)
381-
from vllm_ascend.ops.register_custom_ops import \
382-
_maybe_pad_and_reduce_impl
383-
output = _maybe_pad_and_reduce_impl(output_parallel)
378+
try:
379+
forward_context = get_forward_context()
380+
sp_enabled = forward_context.sp_enabled
381+
except AssertionError:
382+
sp_enabled = False
383+
384+
x = input_parallel
385+
386+
if not sp_enabled:
387+
output_parallel = self.layer.quant_method.apply(self.layer,
388+
x,
389+
bias=bias_)
390+
return tensor_model_parallel_all_reduce(output_parallel)
391+
392+
pad_size = forward_context.pad_size
393+
if pad_size > 0:
394+
x = F.pad(x, (0, 0, 0, pad_size))
395+
396+
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
397+
398+
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
399+
quant_per_tensor)
400+
401+
# unquant
402+
if isinstance(self.layer.quant_method, UnquantizedLinearMethod
403+
) and torch.version.cann.startswith("8.3"):
404+
output_parallel = torch.empty(x.shape[0] // self.layer.tp_size,
405+
self.layer.weight.shape[0],
406+
dtype=self.layer.params_dtype,
407+
device=x.device)
408+
hcom_name = get_tp_group().device_group._get_backend(
409+
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
410+
world_size = self.layer.tp_size
411+
comm_mode = "aiv"
412+
output = torch_npu.npu_mm_reduce_scatter_base(
413+
x,
414+
self.layer.weight.t(),
415+
hcom_name,
416+
world_size,
417+
reduce_op="sum",
418+
bias=None,
419+
comm_turn=0,
420+
comm_mode=comm_mode)
421+
# w8a8 quant
422+
elif isinstance(self.layer.quant_method.quant_method,
423+
AscendW8A8LinearMethod
424+
) and torch.version.cann.startswith("8.3"):
425+
426+
if x.dtype != torch.int8:
427+
x_quant = quant_per_tensor(
428+
x, self.layer.aclnn_input_scale_reciprocal,
429+
self.layer.aclnn_input_offset)
430+
else:
431+
x_quant = x
432+
output_parallel = torch.empty(x_quant.shape[0] //
433+
self.layer.tp_size,
434+
self.layer.weight.shape[1],
435+
dtype=self.layer.params_dtype,
436+
device=x.device)
437+
quant_bias = self.layer.quant_bias
438+
hcom_name = get_tp_group().device_group._get_backend(
439+
torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
440+
world_size = self.layer.tp_size
441+
deq_scale = self.layer.deq_scale
442+
output_dtype = torch.bfloat16
443+
comm_mode = "aiv"
444+
output_parallel = torch_npu.npu_mm_reduce_scatter_base(
445+
x_quant,
446+
self.layer.weight,
447+
hcom_name,
448+
world_size,
449+
reduce_op="sum",
450+
bias=None,
451+
comm_turn=0,
452+
x2_scale=deq_scale,
453+
output_dtype=output_dtype,
454+
comm_mode=comm_mode)
455+
output = torch.add(
456+
output_parallel,
457+
torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
458+
else:
459+
output_parallel = self.layer.quant_method.apply(self.layer,
460+
x,
461+
bias=bias_)
462+
output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
463+
384464
return output
385465

386466
def update_attrs(self):

0 commit comments

Comments
 (0)