Skip to content

Commit def5065

Browse files
committed
add type hint
1 parent 48cabd1 commit def5065

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -817,14 +817,20 @@ def raise_deprecated_error(cls_name, reduce_val, size_avg_val):
817817
)
818818

819819

820-
def legacy_reduction_decorator(init_func):
820+
def legacy_reduction_decorator(
821+
init_func: Callable[_InputT, _RetT],
822+
) -> Callable[_InputT, _RetT]:
821823
"""
822824
Function decorator for __init__: intercept deprecated 'reduce' and 'size_average'.
823825
"""
824826

825827
@functools.wraps(init_func)
826-
def wrapper(self, *args, **kwargs):
827-
cls_name = self.__class__.__name__
828+
def wrapper(self, *args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
829+
cls_name = init_func.__qualname__.split(
830+
"."
831+
)[
832+
0
833+
] # avoid subclass calling parent class init, causing cls_name to be inaccurate
828834
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
829835
cls_name, args, kwargs
830836
)
@@ -833,17 +839,21 @@ def wrapper(self, *args, **kwargs):
833839

834840
return init_func(self, *args, **kwargs)
835841

842+
wrapper.__signature__ = inspect.signature(init_func)
836843
return wrapper
837844

838845

839-
def legacy_reduction_special_decorator(init_func):
846+
def legacy_reduction_special_decorator(
847+
init_func: Callable[_InputT, _RetT],
848+
) -> Callable[_InputT, _RetT]:
840849
"""
841-
Specialized decorator: add CrossEntropyLoss / KLDivLoss special case judgment based on general logic.
850+
Specialized decorator: add CrossEntropyLoss / KLDivLoss special case judgment
851+
based on the general legacy_reduction_decorator logic.
842852
"""
843853

844854
@functools.wraps(init_func)
845-
def wrapper(self, *args, **kwargs):
846-
cls_name = self.__class__.__name__
855+
def wrapper(self, *args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
856+
cls_name = init_func.__qualname__.split(".")[0]
847857
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
848858
cls_name, args, kwargs
849859
)
@@ -863,4 +873,5 @@ def wrapper(self, *args, **kwargs):
863873
raise_deprecated_error(cls_name, reduce_val, size_avg_val)
864874
return init_func(self, *args, **kwargs)
865875

876+
wrapper.__signature__ = inspect.signature(init_func)
866877
return wrapper

0 commit comments

Comments
 (0)