@@ -790,56 +790,72 @@ def check_deprecated_params_on_init(init_func):
790790
791791 @functools .wraps (init_func )
792792 def wrapper (self , * args , ** kwargs ):
793+ # Step 1: Extract legacy params from positional args (convert to kwargs-like values)
794+ reduce_from_pos = _SENTINEL
795+ size_avg_from_pos = _SENTINEL
796+
797+ cls_name = self .__class__ .__name__
798+ pos = LEGACY_POS .get (cls_name )
799+ if pos :
800+ # Check size_average in positional args
801+ idx = pos .get ('size_average' )
802+ if idx is not None and len (args ) > idx :
803+ v = args [idx ]
804+ if type (v ) is bool :
805+ size_avg_from_pos = v
806+
807+ # Check reduce in positional args
808+ idx = pos .get ('reduce' )
809+ if idx is not None and len (args ) > idx :
810+ v = args [idx ]
811+ if type (v ) is bool :
812+ # Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
813+ # treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
814+ if (
815+ cls_name == 'CrossEntropyLoss'
816+ and len (args ) > 2
817+ and args [2 ]
818+ in {
819+ 'mean' ,
820+ 'sum' ,
821+ 'none' ,
822+ }
823+ ):
824+ # 视为 soft_label,跳过
825+ pass
826+ elif (
827+ cls_name == 'KLDivLoss'
828+ and len (args ) > 0
829+ and args [0 ]
830+ in {
831+ 'mean' ,
832+ 'sum' ,
833+ 'none' ,
834+ 'batchmean' ,
835+ }
836+ ):
837+ # 视为 log_target,跳过
838+ pass
839+ else :
840+ reduce_from_pos = v
841+
842+ # Step 2: Extract legacy params from kwargs (kwargs take priority)
793843 reduce_raw = kwargs .pop ('reduce' , _SENTINEL )
794844 size_avg_raw = kwargs .pop ('size_average' , _SENTINEL )
795845
796- has_reduce = reduce_raw is not _SENTINEL
797- has_size_avg = size_avg_raw is not _SENTINEL
798-
799- # If not provided via kwargs, try positional indices per class mapping
800- if not (has_reduce and has_size_avg ):
801- cls_name = self .__class__ .__name__
802- pos = LEGACY_POS .get (cls_name )
803- if pos :
804- if not has_size_avg :
805- idx = pos .get ('size_average' )
806- if idx is not None and len (args ) > idx :
807- v = args [idx ]
808- if type (v ) is bool :
809- size_avg_raw = v
810- has_size_avg = True
811- if not has_reduce :
812- idx = pos .get ('reduce' )
813- if idx is not None and len (args ) > idx :
814- v = args [idx ]
815- if type (v ) is bool :
816- # Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
817- # treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
818- if cls_name == 'CrossEntropyLoss' and args [2 ] in {
819- 'mean' ,
820- 'sum' ,
821- 'none' ,
822- }:
823- # 视为 soft_label,跳过
824- pass
825- elif cls_name == 'KLDivLoss' and args [0 ] in {
826- 'mean' ,
827- 'sum' ,
828- 'none' ,
829- 'batchmean' ,
830- }:
831- # 视为 log_target,跳过
832- pass
833- else :
834- reduce_raw = v
835- has_reduce = True
836-
837- if has_reduce or has_size_avg :
846+ # Step 3: Use kwargs if present, otherwise use positional args
847+ if reduce_raw is _SENTINEL :
848+ reduce_raw = reduce_from_pos
849+ if size_avg_raw is _SENTINEL :
850+ size_avg_raw = size_avg_from_pos
851+
852+ # Step 4: Check if any legacy params were found and raise error
853+ if reduce_raw is not _SENTINEL or size_avg_raw is not _SENTINEL :
838854 reduce_val = None if reduce_raw is _SENTINEL else reduce_raw
839855 size_avg_val = None if size_avg_raw is _SENTINEL else size_avg_raw
840856 suggested = compute_legacy_reduction (reduce_val , size_avg_val )
841857 raise ValueError (
842- f"[Deprecated] '{ self . __class__ . __name__ } ' no longer supports 'reduce' or 'size_average'."
858+ f"[Deprecated] '{ cls_name } ' no longer supports 'reduce' or 'size_average'."
843859 f"\n Detected: reduce={ reduce_val } , size_average={ size_avg_val } "
844860 f"\n Please use: reduction='{ suggested } ' instead."
845861 )
0 commit comments