@@ -817,14 +817,20 @@ def raise_deprecated_error(cls_name, reduce_val, size_avg_val):
817817 )
818818
819819
820- def legacy_reduction_decorator (init_func ):
820+ def legacy_reduction_decorator (
821+ init_func : Callable [_InputT , _RetT ],
822+ ) -> Callable [_InputT , _RetT ]:
821823 """
822824 Function decorator for __init__: intercept deprecated 'reduce' and 'size_average'.
823825 """
824826
825827 @functools .wraps (init_func )
826- def wrapper (self , * args , ** kwargs ):
827- cls_name = self .__class__ .__name__
828+ def wrapper (self , * args : _InputT .args , ** kwargs : _InputT .kwargs ) -> _RetT :
829+ cls_name = init_func .__qualname__ .split (
830+ "."
831+ )[
832+ 0
833+ ] # avoid subclass calling parent class init, causing cls_name to be inaccurate
828834 reduce_val , size_avg_val = get_legacy_reduce_and_size_average (
829835 cls_name , args , kwargs
830836 )
@@ -833,17 +839,21 @@ def wrapper(self, *args, **kwargs):
833839
834840 return init_func (self , * args , ** kwargs )
835841
842+ wrapper .__signature__ = inspect .signature (init_func )
836843 return wrapper
837844
838845
839- def legacy_reduction_special_decorator (init_func ):
846+ def legacy_reduction_special_decorator (
847+ init_func : Callable [_InputT , _RetT ],
848+ ) -> Callable [_InputT , _RetT ]:
840849 """
841- Specialized decorator: add CrossEntropyLoss / KLDivLoss special case judgment based on general logic.
850+ Specialized decorator: add CrossEntropyLoss / KLDivLoss special case judgment
851+ based on the general legacy_reduction_decorator logic.
842852 """
843853
844854 @functools .wraps (init_func )
845- def wrapper (self , * args , ** kwargs ) :
846- cls_name = self . __class__ . __name__
855+ def wrapper (self , * args : _InputT . args , ** kwargs : _InputT . kwargs ) -> _RetT :
856+ cls_name = init_func . __qualname__ . split ( "." )[ 0 ]
847857 reduce_val , size_avg_val = get_legacy_reduce_and_size_average (
848858 cls_name , args , kwargs
849859 )
@@ -863,4 +873,5 @@ def wrapper(self, *args, **kwargs):
863873 raise_deprecated_error (cls_name , reduce_val , size_avg_val )
864874 return init_func (self , * args , ** kwargs )
865875
876+ wrapper .__signature__ = inspect .signature (init_func )
866877 return wrapper
0 commit comments