@@ -790,56 +790,62 @@ def check_deprecated_params_on_init(init_func):
790790
791791 @functools .wraps (init_func )
792792 def wrapper (self , * args , ** kwargs ):
793+ # Step 1: Extract legacy params from positional args (convert to kwargs-like values)
794+ reduce_from_pos = _SENTINEL
795+ size_avg_from_pos = _SENTINEL
796+
797+ cls_name = self .__class__ .__name__
798+ pos = LEGACY_POS .get (cls_name )
799+ if pos :
800+ # Check size_average in positional args
801+ idx = pos .get ('size_average' )
802+ if idx is not None and len (args ) > idx :
803+ v = args [idx ]
804+ if type (v ) is bool :
805+ size_avg_from_pos = v
806+
807+ # Check reduce in positional args
808+ idx = pos .get ('reduce' )
809+ if idx is not None and len (args ) > idx :
810+ v = args [idx ]
811+ if type (v ) is bool :
812+ # Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
813+ # treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
814+ if cls_name == 'CrossEntropyLoss' and args [2 ] in {
815+ 'mean' ,
816+ 'sum' ,
817+ 'none' ,
818+ }:
819+ # 视为 soft_label,跳过
820+ pass
821+ elif cls_name == 'KLDivLoss' and args [0 ] in {
822+ 'mean' ,
823+ 'sum' ,
824+ 'none' ,
825+ 'batchmean' ,
826+ }:
827+ # 视为 log_target,跳过
828+ pass
829+ else :
830+ reduce_from_pos = v
831+
832+ # Step 2: Extract legacy params from kwargs (kwargs take priority)
793833 reduce_raw = kwargs .pop ('reduce' , _SENTINEL )
794834 size_avg_raw = kwargs .pop ('size_average' , _SENTINEL )
795835
796- has_reduce = reduce_raw is not _SENTINEL
797- has_size_avg = size_avg_raw is not _SENTINEL
798-
799- # If not provided via kwargs, try positional indices per class mapping
800- if not (has_reduce and has_size_avg ):
801- cls_name = self .__class__ .__name__
802- pos = LEGACY_POS .get (cls_name )
803- if pos :
804- if not has_size_avg :
805- idx = pos .get ('size_average' )
806- if idx is not None and len (args ) > idx :
807- v = args [idx ]
808- if type (v ) is bool :
809- size_avg_raw = v
810- has_size_avg = True
811- if not has_reduce :
812- idx = pos .get ('reduce' )
813- if idx is not None and len (args ) > idx :
814- v = args [idx ]
815- if type (v ) is bool :
816- # Special guard: CrossEntropyLoss idx=3 is soft_label (bool) in Paddle;
817- # treat it as legacy 'reduce' ONLY if arg at idx=2 is NOT a valid reduction string
818- if cls_name == 'CrossEntropyLoss' and args [2 ] in {
819- 'mean' ,
820- 'sum' ,
821- 'none' ,
822- }:
823- # 视为 soft_label,跳过
824- pass
825- elif cls_name == 'KLDivLoss' and args [0 ] in {
826- 'mean' ,
827- 'sum' ,
828- 'none' ,
829- 'batchmean' ,
830- }:
831- # 视为 log_target,跳过
832- pass
833- else :
834- reduce_raw = v
835- has_reduce = True
836-
837- if has_reduce or has_size_avg :
836+ # Step 3: Use kwargs if present, otherwise use positional args
837+ if reduce_raw is _SENTINEL :
838+ reduce_raw = reduce_from_pos
839+ if size_avg_raw is _SENTINEL :
840+ size_avg_raw = size_avg_from_pos
841+
842+ # Step 4: Check if any legacy params were found and raise error
843+ if reduce_raw is not _SENTINEL or size_avg_raw is not _SENTINEL :
838844 reduce_val = None if reduce_raw is _SENTINEL else reduce_raw
839845 size_avg_val = None if size_avg_raw is _SENTINEL else size_avg_raw
840846 suggested = compute_legacy_reduction (reduce_val , size_avg_val )
841847 raise ValueError (
842- f"[Deprecated] '{ self . __class__ . __name__ } ' no longer supports 'reduce' or 'size_average'."
848+ f"[Deprecated] '{ cls_name } ' no longer supports 'reduce' or 'size_average'."
843849 f"\n Detected: reduce={ reduce_val } , size_average={ size_avg_val } "
844850 f"\n Please use: reduction='{ suggested } ' instead."
845851 )
0 commit comments