Skip to content

Commit 512cac4

Browse files
committed
optimize code
1 parent c0393c1 commit 512cac4

File tree

1 file changed

+59
-43
lines changed

1 file changed

+59
-43
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -790,56 +790,72 @@ def check_deprecated_params_on_init(init_func):
790790

791791
@functools.wraps(init_func)
792792
def wrapper(self, *args, **kwargs):
793+
# Step 1: Extract legacy params from positional args (convert to kwargs-like values)
794+
reduce_from_pos = _SENTINEL
795+
size_avg_from_pos = _SENTINEL
796+
797+
cls_name = self.__class__.__name__
798+
pos = LEGACY_POS.get(cls_name)
799+
if pos:
800+
# Check size_average in positional args
801+
idx = pos.get('size_average')
802+
if idx is not None and len(args) > idx:
803+
v = args[idx]
804+
if type(v) is bool:
805+
size_avg_from_pos = v
806+
807+
# Check reduce in positional args
808+
idx = pos.get('reduce')
809+
if idx is not None and len(args) > idx:
810+
v = args[idx]
811+
if type(v) is bool:
812+
# Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
813+
# treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
814+
if (
815+
cls_name == 'CrossEntropyLoss'
816+
and len(args) > 2
817+
and args[2]
818+
in {
819+
'mean',
820+
'sum',
821+
'none',
822+
}
823+
):
824+
# 视为 soft_label,跳过
825+
pass
826+
elif (
827+
cls_name == 'KLDivLoss'
828+
and len(args) > 0
829+
and args[0]
830+
in {
831+
'mean',
832+
'sum',
833+
'none',
834+
'batchmean',
835+
}
836+
):
837+
# 视为 log_target,跳过
838+
pass
839+
else:
840+
reduce_from_pos = v
841+
842+
# Step 2: Extract legacy params from kwargs (kwargs take priority)
793843
reduce_raw = kwargs.pop('reduce', _SENTINEL)
794844
size_avg_raw = kwargs.pop('size_average', _SENTINEL)
795845

796-
has_reduce = reduce_raw is not _SENTINEL
797-
has_size_avg = size_avg_raw is not _SENTINEL
798-
799-
# If not provided via kwargs, try positional indices per class mapping
800-
if not (has_reduce and has_size_avg):
801-
cls_name = self.__class__.__name__
802-
pos = LEGACY_POS.get(cls_name)
803-
if pos:
804-
if not has_size_avg:
805-
idx = pos.get('size_average')
806-
if idx is not None and len(args) > idx:
807-
v = args[idx]
808-
if type(v) is bool:
809-
size_avg_raw = v
810-
has_size_avg = True
811-
if not has_reduce:
812-
idx = pos.get('reduce')
813-
if idx is not None and len(args) > idx:
814-
v = args[idx]
815-
if type(v) is bool:
816-
# Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
817-
# treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
818-
if cls_name == 'CrossEntropyLoss' and args[2] in {
819-
'mean',
820-
'sum',
821-
'none',
822-
}:
823-
# 视为 soft_label,跳过
824-
pass
825-
elif cls_name == 'KLDivLoss' and args[0] in {
826-
'mean',
827-
'sum',
828-
'none',
829-
'batchmean',
830-
}:
831-
# 视为 log_target,跳过
832-
pass
833-
else:
834-
reduce_raw = v
835-
has_reduce = True
836-
837-
if has_reduce or has_size_avg:
846+
# Step 3: Use kwargs if present, otherwise use positional args
847+
if reduce_raw is _SENTINEL:
848+
reduce_raw = reduce_from_pos
849+
if size_avg_raw is _SENTINEL:
850+
size_avg_raw = size_avg_from_pos
851+
852+
# Step 4: Check if any legacy params were found and raise error
853+
if reduce_raw is not _SENTINEL or size_avg_raw is not _SENTINEL:
838854
reduce_val = None if reduce_raw is _SENTINEL else reduce_raw
839855
size_avg_val = None if size_avg_raw is _SENTINEL else size_avg_raw
840856
suggested = compute_legacy_reduction(reduce_val, size_avg_val)
841857
raise ValueError(
842-
f"[Deprecated] '{self.__class__.__name__}' no longer supports 'reduce' or 'size_average'."
858+
f"[Deprecated] '{cls_name}' no longer supports 'reduce' or 'size_average'."
843859
f"\nDetected: reduce={reduce_val}, size_average={size_avg_val}"
844860
f"\nPlease use: reduction='{suggested}' instead."
845861
)

0 commit comments

Comments
 (0)