|
39 | 39 |
|
40 | 40 | import torch |
41 | 41 | import torch.distributed as dist |
| 42 | +import torch.nn.functional as F |
42 | 43 | import torch_npu |
43 | 44 | from torch.distributed import ProcessGroup |
44 | 45 | from torch.nn.parameter import Parameter |
45 | | -from vllm.distributed import split_tensor_along_last_dim |
| 46 | +from vllm.distributed import (split_tensor_along_last_dim, |
| 47 | + tensor_model_parallel_all_reduce, |
| 48 | + tensor_model_parallel_reduce_scatter) |
46 | 49 | from vllm.distributed.parallel_state import get_tp_group |
| 50 | +from vllm.forward_context import get_forward_context |
47 | 51 |
|
48 | 52 | from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group, |
49 | 53 | get_otp_group) |
@@ -375,12 +379,82 @@ def apply_impl( |
375 | 379 | def matmul_and_reduce(self, input_parallel: torch.Tensor, |
376 | 380 | bias_: Optional[Parameter]) -> torch.Tensor: |
377 | 381 | assert self.quant_method is not None |
378 | | - output_parallel = self.quant_method.apply(self.layer, |
379 | | - input_parallel, |
380 | | - bias=bias_) |
381 | | - from vllm_ascend.ops.register_custom_ops import \ |
382 | | - _maybe_pad_and_reduce_impl |
383 | | - output = _maybe_pad_and_reduce_impl(output_parallel) |
| 382 | + try: |
| 383 | + forward_context = get_forward_context() |
| 384 | + sp_enabled = forward_context.sp_enabled |
| 385 | + except AssertionError: |
| 386 | + sp_enabled = False |
| 387 | + |
| 388 | + x = input_parallel |
| 389 | + |
| 390 | + if not sp_enabled: |
| 391 | + output_parallel = self.layer.quant_method.apply(self.layer, |
| 392 | + x, |
| 393 | + bias=bias_) |
| 394 | + return tensor_model_parallel_all_reduce(output_parallel) |
| 395 | + |
| 396 | + pad_size = forward_context.pad_size |
| 397 | + if pad_size > 0: |
| 398 | + x = F.pad(x, (0, 0, 0, pad_size)) |
| 399 | + |
| 400 | + from vllm.model_executor.layers.linear import UnquantizedLinearMethod |
| 401 | + |
| 402 | + from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod, |
| 403 | + quant_per_tensor) |
| 404 | + |
| 405 | + # For unquant |
| 406 | + if isinstance(self.layer.quant_method, UnquantizedLinearMethod |
| 407 | + ) and torch.version.cann.startswith("8.3"): |
| 408 | + hcom_name = get_tp_group().device_group._get_backend( |
| 409 | + torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank) |
| 410 | + world_size = self.layer.tp_size |
| 411 | + comm_mode = "aiv" |
| 412 | + output = torch_npu.npu_mm_reduce_scatter_base( |
| 413 | + x, |
| 414 | + self.layer.weight.t(), |
| 415 | + hcom_name, |
| 416 | + world_size, |
| 417 | + reduce_op="sum", |
| 418 | + bias=None, |
| 419 | + comm_turn=0, |
| 420 | + comm_mode=comm_mode) |
| 421 | + # For w8a8 quant |
| 422 | + elif isinstance(self.layer.quant_method.quant_method, |
| 423 | + AscendW8A8LinearMethod |
| 424 | + ) and torch.version.cann.startswith("8.3"): |
| 425 | + if x.dtype != torch.int8: |
| 426 | + x_quant = quant_per_tensor( |
| 427 | + x, self.layer.aclnn_input_scale_reciprocal, |
| 428 | + self.layer.aclnn_input_offset) |
| 429 | + else: |
| 430 | + x_quant = x |
| 431 | + quant_bias = self.layer.quant_bias |
| 432 | + hcom_name = get_tp_group().device_group._get_backend( |
| 433 | + torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank) |
| 434 | + world_size = self.layer.tp_size |
| 435 | + deq_scale = self.layer.deq_scale |
| 436 | + output_dtype = torch.bfloat16 |
| 437 | + comm_mode = "aiv" |
| 438 | + output_parallel = torch_npu.npu_mm_reduce_scatter_base( |
| 439 | + x_quant, |
| 440 | + self.layer.weight, |
| 441 | + hcom_name, |
| 442 | + world_size, |
| 443 | + reduce_op="sum", |
| 444 | + bias=None, |
| 445 | + comm_turn=0, |
| 446 | + x2_scale=deq_scale, |
| 447 | + output_dtype=output_dtype, |
| 448 | + comm_mode=comm_mode) |
| 449 | + output = torch.add( |
| 450 | + output_parallel, |
| 451 | + torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype)) |
| 452 | + else: |
| 453 | + output_parallel = self.layer.quant_method.apply(self.layer, |
| 454 | + x, |
| 455 | + bias=bias_) |
| 456 | + output = tensor_model_parallel_reduce_scatter(output_parallel, 0) |
| 457 | + |
384 | 458 | return output |
385 | 459 |
|
386 | 460 | def update_attrs(self): |
|
0 commit comments