Skip to content

Commit 750f294

Browse files
committed
optimize code
1 parent c0393c1 commit 750f294

File tree

1 file changed

+49
-43
lines changed

1 file changed

+49
-43
lines changed

python/paddle/utils/decorator_utils.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -790,56 +790,62 @@ def check_deprecated_params_on_init(init_func):
790790

791791
@functools.wraps(init_func)
792792
def wrapper(self, *args, **kwargs):
793+
# Step 1: Extract legacy params from positional args (convert to kwargs-like values)
794+
reduce_from_pos = _SENTINEL
795+
size_avg_from_pos = _SENTINEL
796+
797+
cls_name = self.__class__.__name__
798+
pos = LEGACY_POS.get(cls_name)
799+
if pos:
800+
# Check size_average in positional args
801+
idx = pos.get('size_average')
802+
if idx is not None and len(args) > idx:
803+
v = args[idx]
804+
if type(v) is bool:
805+
size_avg_from_pos = v
806+
807+
# Check reduce in positional args
808+
idx = pos.get('reduce')
809+
if idx is not None and len(args) > idx:
810+
v = args[idx]
811+
if type(v) is bool:
812+
# Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
813+
# treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
814+
if cls_name == 'CrossEntropyLoss' and args[2] in {
815+
'mean',
816+
'sum',
817+
'none',
818+
}:
819+
# 视为 soft_label,跳过
820+
pass
821+
elif cls_name == 'KLDivLoss' and args[0] in {
822+
'mean',
823+
'sum',
824+
'none',
825+
'batchmean',
826+
}:
827+
# 视为 log_target,跳过
828+
pass
829+
else:
830+
reduce_from_pos = v
831+
832+
# Step 2: Extract legacy params from kwargs (kwargs take priority)
793833
reduce_raw = kwargs.pop('reduce', _SENTINEL)
794834
size_avg_raw = kwargs.pop('size_average', _SENTINEL)
795835

796-
has_reduce = reduce_raw is not _SENTINEL
797-
has_size_avg = size_avg_raw is not _SENTINEL
798-
799-
# If not provided via kwargs, try positional indices per class mapping
800-
if not (has_reduce and has_size_avg):
801-
cls_name = self.__class__.__name__
802-
pos = LEGACY_POS.get(cls_name)
803-
if pos:
804-
if not has_size_avg:
805-
idx = pos.get('size_average')
806-
if idx is not None and len(args) > idx:
807-
v = args[idx]
808-
if type(v) is bool:
809-
size_avg_raw = v
810-
has_size_avg = True
811-
if not has_reduce:
812-
idx = pos.get('reduce')
813-
if idx is not None and len(args) > idx:
814-
v = args[idx]
815-
if type(v) is bool:
816-
# Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
817-
# treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
818-
if cls_name == 'CrossEntropyLoss' and args[2] in {
819-
'mean',
820-
'sum',
821-
'none',
822-
}:
823-
# 视为 soft_label,跳过
824-
pass
825-
elif cls_name == 'KLDivLoss' and args[0] in {
826-
'mean',
827-
'sum',
828-
'none',
829-
'batchmean',
830-
}:
831-
# 视为 log_target,跳过
832-
pass
833-
else:
834-
reduce_raw = v
835-
has_reduce = True
836-
837-
if has_reduce or has_size_avg:
836+
# Step 3: Use kwargs if present, otherwise use positional args
837+
if reduce_raw is _SENTINEL:
838+
reduce_raw = reduce_from_pos
839+
if size_avg_raw is _SENTINEL:
840+
size_avg_raw = size_avg_from_pos
841+
842+
# Step 4: Check if any legacy params were found and raise error
843+
if reduce_raw is not _SENTINEL or size_avg_raw is not _SENTINEL:
838844
reduce_val = None if reduce_raw is _SENTINEL else reduce_raw
839845
size_avg_val = None if size_avg_raw is _SENTINEL else size_avg_raw
840846
suggested = compute_legacy_reduction(reduce_val, size_avg_val)
841847
raise ValueError(
842-
f"[Deprecated] '{self.__class__.__name__}' no longer supports 'reduce' or 'size_average'."
848+
f"[Deprecated] '{cls_name}' no longer supports 'reduce' or 'size_average'."
843849
f"\nDetected: reduce={reduce_val}, size_average={size_avg_val}"
844850
f"\nPlease use: reduction='{suggested}' instead."
845851
)

0 commit comments

Comments
 (0)