Skip to content

Commit 73a14d5

Browse files
zty-kingLittleHeroZZZX
authored and committed
[API Compatibility] optimize size_average & reduce error message (PaddlePaddle#76221)
* fix_loss_args_legacy * adapt positional vars * optimize code * optimze code * optimize code/root/miniconda3/envs/qwen_DCP_test/bin/python /home/paddle/test/legacy_test/test_legacy_loss_args.py * fix the bug * fix code * add type hint * fix code
1 parent c2c69dc commit 73a14d5

File tree

4 files changed

+346
-1
lines changed

4 files changed

+346
-1
lines changed

python/paddle/nn/layer/loss.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
import paddle
2020
from paddle import base, in_dynamic_mode
2121
from paddle.base.framework import in_dynamic_or_pir_mode
22-
from paddle.utils.decorator_utils import ParamAliasDecorator
22+
from paddle.utils.decorator_utils import (
23+
ParamAliasDecorator,
24+
legacy_reduction_decorator,
25+
legacy_reduction_special_decorator,
26+
)
2327

2428
from .. import functional as F
2529
from .layers import Layer
@@ -121,6 +125,7 @@ class BCEWithLogitsLoss(Layer):
121125
pos_weight: Tensor | None
122126
name: str | None
123127

128+
@legacy_reduction_decorator
124129
def __init__(
125130
self,
126131
weight: Tensor | None = None,
@@ -418,6 +423,7 @@ class CrossEntropyLoss(Layer):
418423
label_smoothing: float
419424
name: str | None
420425

426+
@legacy_reduction_special_decorator
421427
def __init__(
422428
self,
423429
weight: Tensor | None = None,
@@ -656,6 +662,7 @@ class MSELoss(Layer):
656662

657663
reduction: _ReduceMode
658664

665+
@legacy_reduction_decorator
659666
def __init__(self, reduction: _ReduceMode = 'mean'):
660667
super().__init__()
661668
if reduction not in ['sum', 'mean', 'none']:
@@ -759,6 +766,7 @@ class L1Loss(Layer):
759766
reduction: _ReduceMode
760767
name: str | None
761768

769+
@legacy_reduction_decorator
762770
def __init__(
763771
self, reduction: _ReduceMode = 'mean', name: str | None = None
764772
) -> None:
@@ -849,6 +857,7 @@ class BCELoss(Layer):
849857
reduction: _ReduceMode
850858
name: str | None
851859

860+
@legacy_reduction_decorator
852861
def __init__(
853862
self,
854863
weight: Tensor | None = None,
@@ -961,6 +970,7 @@ class NLLLoss(Layer):
961970
962971
"""
963972

973+
@legacy_reduction_decorator
964974
def __init__(
965975
self,
966976
weight: Tensor | None = None,
@@ -1049,6 +1059,7 @@ class PoissonNLLLoss(Layer):
10491059
10501060
"""
10511061

1062+
@legacy_reduction_decorator
10521063
def __init__(
10531064
self,
10541065
log_input: bool = True,
@@ -1180,6 +1191,7 @@ class KLDivLoss(Layer):
11801191
reduction: _ReduceMode
11811192
log_target: bool
11821193

1194+
@legacy_reduction_special_decorator
11831195
def __init__(
11841196
self, reduction: _ReduceMode = 'mean', log_target: bool = False
11851197
) -> None:
@@ -1252,6 +1264,7 @@ class MarginRankingLoss(Layer):
12521264
reduction: _ReduceMode
12531265
name: str | None
12541266

1267+
@legacy_reduction_decorator
12551268
def __init__(
12561269
self,
12571270
margin: float = 0.0,
@@ -1525,6 +1538,7 @@ class SmoothL1Loss(Layer):
15251538
delta: float
15261539
name: str | None
15271540

1541+
@legacy_reduction_decorator
15281542
def __init__(
15291543
self,
15301544
reduction: _ReduceMode = 'mean',
@@ -1614,6 +1628,7 @@ class MultiLabelSoftMarginLoss(Layer):
16141628
reduction: _ReduceMode
16151629
name: str | None
16161630

1631+
@legacy_reduction_decorator
16171632
def __init__(
16181633
self,
16191634
weight: Tensor | None = None,
@@ -1726,6 +1741,7 @@ class HingeEmbeddingLoss(Layer):
17261741
reduction: _ReduceMode
17271742
name: str | None
17281743

1744+
@legacy_reduction_decorator
17291745
def __init__(
17301746
self,
17311747
margin: float = 1.0,
@@ -1824,6 +1840,7 @@ class CosineEmbeddingLoss(Layer):
18241840
reduction: _ReduceMode
18251841
name: str | None
18261842

1843+
@legacy_reduction_decorator
18271844
def __init__(
18281845
self,
18291846
margin: float = 0,
@@ -2061,6 +2078,7 @@ class TripletMarginLoss(Layer):
20612078
reduction: _ReduceMode
20622079
name: str | None
20632080

2081+
@legacy_reduction_decorator
20642082
def __init__(
20652083
self,
20662084
margin: float = 1.0,
@@ -2177,6 +2195,7 @@ class MultiMarginLoss(Layer):
21772195
reduction: _ReduceMode
21782196
name: str | None
21792197

2198+
@legacy_reduction_decorator
21802199
def __init__(
21812200
self,
21822201
p: int = 1,
@@ -2272,6 +2291,7 @@ class MultiLabelMarginLoss(Layer):
22722291
reduction: _ReduceMode
22732292
name: str | None
22742293

2294+
@legacy_reduction_decorator
22752295
def __init__(
22762296
self,
22772297
reduction: _ReduceMode = 'mean',
@@ -2360,6 +2380,7 @@ class SoftMarginLoss(Layer):
23602380
reduction: _ReduceMode
23612381
name: str | None
23622382

2383+
@legacy_reduction_decorator
23632384
def __init__(
23642385
self, reduction: _ReduceMode = 'mean', name: str | None = None
23652386
) -> None:

python/paddle/utils/decorator_utils.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
_InputT = ParamSpec("_InputT")
3030
_RetT = TypeVar("_RetT")
31+
_SENTINEL = object()
3132

3233

3334
def _is_int_or_scalar_tensor(x):
@@ -736,3 +737,139 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
736737
return wrapper
737738

738739
return decorator
740+
741+
742+
# Positional index (within the non-self arguments of each loss __init__)
# at which the deprecated 'size_average' and 'reduce' parameters would
# land if a caller still passed them positionally.
_SA0_RD1 = {'size_average': 0, 'reduce': 1}
_SA1_RD2 = {'size_average': 1, 'reduce': 2}
_SA1_RD3 = {'size_average': 1, 'reduce': 3}
_SA3_RD4 = {'size_average': 3, 'reduce': 4}
_SA4_RD5 = {'size_average': 4, 'reduce': 5}
_SA2_RD4 = {'size_average': 2, 'reduce': 4}

# Per-loss-class lookup table mapping class name -> legacy positional slots.
LEGACY_POS: dict[str, dict[str, int]] = {
    'L1Loss': _SA0_RD1,
    'MSELoss': _SA0_RD1,
    'KLDivLoss': _SA0_RD1,
    'SmoothL1Loss': _SA0_RD1,
    'SoftMarginLoss': _SA0_RD1,
    'MultiLabelMarginLoss': _SA0_RD1,
    'BCELoss': _SA1_RD2,
    'BCEWithLogitsLoss': _SA1_RD2,
    'MultiLabelSoftMarginLoss': _SA1_RD2,
    'HingeEmbeddingLoss': _SA1_RD2,
    'CosineEmbeddingLoss': _SA1_RD2,
    'MarginRankingLoss': _SA1_RD2,
    'CrossEntropyLoss': _SA1_RD3,
    'NLLLoss': _SA1_RD3,
    'PoissonNLLLoss': _SA2_RD4,
    'MultiMarginLoss': _SA3_RD4,
    'TripletMarginLoss': _SA4_RD5,
}
778+
779+
780+
def compute_legacy_reduction(reduce_val, size_average_val):
    """Map the deprecated (reduce, size_average) pair to a 'reduction' string.

    Mirrors the legacy semantics: ``reduce=False`` means no reduction at
    all; otherwise ``size_average=False`` selects 'sum' and any other value
    (True, or the '' not-supplied sentinel) selects 'mean'.

    Args:
        reduce_val: Legacy 'reduce' flag (bool, or '' when not supplied).
        size_average_val: Legacy 'size_average' flag (bool, or '' when not
            supplied).

    Returns:
        str: The equivalent value for the modern 'reduction' argument.
    """
    if reduce_val is False:
        return 'none'
    # Whether reduce is True or unset, the result depends only on
    # size_average. The original duplicated this expression in two
    # identical branches; they are collapsed into one here.
    return 'sum' if size_average_val is False else 'mean'
786+
787+
788+
def get_legacy_reduce_and_size_average(cls_name, args, kwargs):
    """Pull deprecated 'reduce'/'size_average' values out of a call.

    Looks for each legacy parameter first in ``kwargs`` (popping it so it
    does not reach the real ``__init__``) and then at its historical
    positional slot per ``LEGACY_POS``. A positional value only counts when
    it is strictly a bool, so a legitimate argument occupying that slot is
    not misread as a legacy flag.

    Args:
        cls_name: Name of the loss class being constructed.
        args: Positional arguments after ``self``.
        kwargs: Keyword arguments (mutated: legacy keys are removed).

    Returns:
        tuple: ``(reduce_val, size_avg_val)``; '' marks "not supplied".
    """
    positions = LEGACY_POS.get(cls_name)

    def _extract(name):
        # Keyword use wins and is stripped from kwargs.
        if name in kwargs:
            return kwargs.pop(name)
        slot = positions.get(name)
        # type(...) is bool (not isinstance) so bool subclasses / ints at
        # this slot are not treated as a legacy flag.
        if len(args) > slot and type(args[slot]) is bool:
            return args[slot]
        return ''

    return _extract('reduce'), _extract('size_average')
807+
808+
809+
def raise_deprecated_error(cls_name, reduce_val, size_avg_val):
    """Raise a ValueError telling the caller how to migrate to 'reduction'.

    Computes the equivalent 'reduction' value from the detected legacy
    flags and reports it in the message; the '' not-supplied sentinel is
    rendered as None for readability.

    Raises:
        ValueError: Always.
    """
    suggested = compute_legacy_reduction(reduce_val, size_avg_val)
    shown_reduce = None if reduce_val == '' else reduce_val
    shown_size_average = None if size_avg_val == '' else size_avg_val
    message = (
        f"[Deprecated] '{cls_name}' no longer supports 'reduce' or 'size_average'."
        f"\nDetected: reduce={shown_reduce}, size_average={shown_size_average}"
        f"\nPlease use: reduction='{suggested}' instead."
    )
    raise ValueError(message)
818+
819+
820+
def legacy_reduction_decorator(
    init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
    """Decorator for loss ``__init__``: reject deprecated 'reduce'/'size_average'.

    If either legacy argument is detected (by keyword, or positionally at
    its historical slot), a ValueError with a migration hint is raised
    before the wrapped ``__init__`` runs.
    """

    @functools.wraps(init_func)
    def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
        # Take the class name from the decorated __init__'s qualname rather
        # than type(args[0]), so a subclass calling super().__init__ still
        # reports the class that actually owns this __init__.
        owner_name = init_func.__qualname__.split(".")[0]
        reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
            owner_name, args[1:], kwargs
        )
        legacy_used = reduce_val != '' or size_avg_val != ''
        if legacy_used:
            raise_deprecated_error(owner_name, reduce_val, size_avg_val)
        return init_func(*args, **kwargs)

    # Preserve the original signature for introspection / docs tooling.
    wrapper.__signature__ = inspect.signature(init_func)
    return wrapper
841+
842+
843+
def legacy_reduction_special_decorator(
    init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
    """Variant of ``legacy_reduction_decorator`` for CrossEntropyLoss/KLDivLoss.

    Same legacy-argument interception, except that when a valid 'reduction'
    string is already present at its own positional slot the bool found
    elsewhere is assumed to be a legitimate argument (e.g. CrossEntropyLoss's
    ``soft_label``), and no error is raised.
    """

    @functools.wraps(init_func)
    def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
        owner_name = init_func.__qualname__.split(".")[0]
        positional = args[1:]
        reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
            owner_name, positional, kwargs
        )
        if reduce_val != '' or size_avg_val != '':
            # 'reduction' is the 3rd positional arg of CrossEntropyLoss and
            # the 1st of KLDivLoss; a valid string there means the call is
            # using the modern API, not the legacy flags.
            ce_has_reduction = (
                owner_name == 'CrossEntropyLoss'
                and len(positional) > 2
                and positional[2] in {'mean', 'sum', 'none'}
            )
            kl_has_reduction = (
                owner_name == 'KLDivLoss'
                and len(positional) > 0
                and positional[0] in {'mean', 'sum', 'none', 'batchmean'}
            )
            if not (ce_has_reduction or kl_has_reduction):
                raise_deprecated_error(owner_name, reduce_val, size_avg_val)
        return init_func(*args, **kwargs)

    # Preserve the original signature for introspection / docs tooling.
    wrapper.__signature__ = inspect.signature(init_func)
    return wrapper

test/legacy_test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,7 @@ endif()
874874

875875
set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
876876
set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 180)
877+
set_tests_properties(test_legacy_loss_args PROPERTIES TIMEOUT 10)
877878
set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 250)
878879
set_tests_properties(test_empty_op PROPERTIES TIMEOUT 120)
879880
set_tests_properties(test_elementwise_div_op PROPERTIES TIMEOUT 120)

0 commit comments

Comments
 (0)