@@ -375,12 +375,92 @@ def apply_impl(
     def matmul_and_reduce(self, input_parallel: torch.Tensor,
                           bias_: Optional[Parameter]) -> torch.Tensor:
         assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self.layer,
-                                                  input_parallel,
-                                                  bias=bias_)
-        from vllm_ascend.ops.register_custom_ops import \
-            _maybe_pad_and_reduce_impl
-        output = _maybe_pad_and_reduce_impl(output_parallel)
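+        # Sequence-parallel (SP) state lives on the forward context; if no
+        # context is active, fall back to the non-SP path.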
+        try:
+            forward_context = get_forward_context()
+            sp_enabled = forward_context.sp_enabled
+        except AssertionError:
+            sp_enabled = False
+
+        x = input_parallel
+
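+        # SP disabled: plain matmul via the quant method, then all-reduce
+        # across the tensor-parallel group.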
+        if not sp_enabled:
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            return tensor_model_parallel_all_reduce(output_parallel)
+
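+        # SP enabled: pad the token dimension so it splits evenly across TP
+        # ranks for the reduce-scatter below.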
+        pad_size = forward_context.pad_size
+        if pad_size > 0:
+            x = F.pad(x, (0, 0, 0, pad_size))
+
+        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+
+        from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
+                                                   quant_per_tensor)
+
+        # Unquantized path (CANN 8.3+): fuse the matmul with the
+        # reduce-scatter in a single NPU kernel.
+        if isinstance(self.layer.quant_method, UnquantizedLinearMethod
+                      ) and torch.version.cann.startswith("8.3"):
+            output_parallel = torch.empty(x.shape[0] // self.layer.tp_size,
+                                          self.layer.weight.shape[0],
+                                          dtype=self.layer.params_dtype,
+                                          device=x.device)
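+            # The fused op addresses the TP group by its HCCL communicator
+            # name rather than by the torch process group.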
+            hcom_name = get_tp_group().device_group._get_backend(
+                torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+            world_size = self.layer.tp_size
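+            # "aiv" runs the communication on the AI Vector cores.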
+            comm_mode = "aiv"
+            output = torch_npu.npu_mm_reduce_scatter_base(
+                x,
+                self.layer.weight.t(),
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                comm_mode=comm_mode)
+        # W8A8 path (CANN 8.3+): quantize activations, then fuse the int8
+        # matmul with the reduce-scatter.
+        elif isinstance(self.layer.quant_method.quant_method,
+                        AscendW8A8LinearMethod
+                        ) and torch.version.cann.startswith("8.3"):
+
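+            # Quantize activations per tensor to int8 unless the caller
+            # already passed int8.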
+            if x.dtype != torch.int8:
+                x_quant = quant_per_tensor(
+                    x, self.layer.aclnn_input_scale_reciprocal,
+                    self.layer.aclnn_input_offset)
+            else:
+                x_quant = x
+            output_parallel = torch.empty(x_quant.shape[0] //
+                                          self.layer.tp_size,
+                                          self.layer.weight.shape[1],
+                                          dtype=self.layer.params_dtype,
+                                          device=x.device)
+            quant_bias = self.layer.quant_bias
+            hcom_name = get_tp_group().device_group._get_backend(
+                torch.device('npu')).get_hccl_comm_name(self.layer.tp_rank)
+            world_size = self.layer.tp_size
+            deq_scale = self.layer.deq_scale
+            output_dtype = torch.bfloat16
+            comm_mode = "aiv"
+            output_parallel = torch_npu.npu_mm_reduce_scatter_base(
+                x_quant,
+                self.layer.weight,
+                hcom_name,
+                world_size,
+                reduce_op="sum",
+                bias=None,
+                comm_turn=0,
+                x2_scale=deq_scale,
+                output_dtype=output_dtype,
+                comm_mode=comm_mode)
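+            # Fold in the dequantized bias (quant_bias scaled by deq_scale)
+            # after the reduce-scatter.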
+            output = torch.add(
+                output_parallel,
+                torch.mul(quant_bias, deq_scale).to(self.layer.params_dtype))
+        else:
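+            # Fallback (e.g. CANN < 8.3): regular matmul, then reduce-scatter
+            # along the token dimension.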
+            output_parallel = self.layer.quant_method.apply(self.layer,
+                                                            x,
+                                                            bias=bias_)
+            output = tensor_model_parallel_reduce_scatter(output_parallel, 0)
+
         return output
 
     def update_attrs(self):