Merged
23 changes: 22 additions & 1 deletion python/paddle/nn/layer/loss.py
@@ -19,7 +19,11 @@
import paddle
from paddle import base, in_dynamic_mode
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.utils.decorator_utils import ParamAliasDecorator
from paddle.utils.decorator_utils import (
ParamAliasDecorator,
legacy_reduction_decorator,
legacy_reduction_special_decorator,
)

from .. import functional as F
from .layers import Layer
@@ -121,6 +125,7 @@ class BCEWithLogitsLoss(Layer):
pos_weight: Tensor | None
name: str | None

@legacy_reduction_decorator
def __init__(
self,
weight: Tensor | None = None,
@@ -418,6 +423,7 @@ class CrossEntropyLoss(Layer):
label_smoothing: float
name: str | None

@legacy_reduction_special_decorator
def __init__(
self,
weight: Tensor | None = None,
@@ -656,6 +662,7 @@ class MSELoss(Layer):

reduction: _ReduceMode

@legacy_reduction_decorator
def __init__(self, reduction: _ReduceMode = 'mean'):
super().__init__()
if reduction not in ['sum', 'mean', 'none']:
@@ -759,6 +766,7 @@ class L1Loss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self, reduction: _ReduceMode = 'mean', name: str | None = None
) -> None:
@@ -849,6 +857,7 @@ class BCELoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
weight: Tensor | None = None,
@@ -961,6 +970,7 @@ class NLLLoss(Layer):

"""

@legacy_reduction_decorator
def __init__(
self,
weight: Tensor | None = None,
@@ -1049,6 +1059,7 @@ class PoissonNLLLoss(Layer):

"""

@legacy_reduction_decorator
def __init__(
self,
log_input: bool = True,
@@ -1180,6 +1191,7 @@ class KLDivLoss(Layer):
reduction: _ReduceMode
log_target: bool

@legacy_reduction_special_decorator
def __init__(
self, reduction: _ReduceMode = 'mean', log_target: bool = False
) -> None:
@@ -1252,6 +1264,7 @@ class MarginRankingLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
margin: float = 0.0,
@@ -1525,6 +1538,7 @@ class SmoothL1Loss(Layer):
delta: float
name: str | None

@legacy_reduction_decorator
def __init__(
self,
reduction: _ReduceMode = 'mean',
@@ -1614,6 +1628,7 @@ class MultiLabelSoftMarginLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
weight: Tensor | None = None,
@@ -1726,6 +1741,7 @@ class HingeEmbeddingLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
margin: float = 1.0,
@@ -1824,6 +1840,7 @@ class CosineEmbeddingLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
margin: float = 0,
@@ -2061,6 +2078,7 @@ class TripletMarginLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
margin: float = 1.0,
@@ -2177,6 +2195,7 @@ class MultiMarginLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
p: int = 1,
@@ -2272,6 +2291,7 @@ class MultiLabelMarginLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self,
reduction: _ReduceMode = 'mean',
@@ -2360,6 +2380,7 @@ class SoftMarginLoss(Layer):
reduction: _ReduceMode
name: str | None

@legacy_reduction_decorator
def __init__(
self, reduction: _ReduceMode = 'mean', name: str | None = None
) -> None:
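Every decorated `__init__` above routes through one of the two decorators added to `decorator_utils.py` below. A minimal sketch of the intended user-facing behavior, assuming the error text defined there:

```python
import paddle

# New-style argument: accepted as before.
loss = paddle.nn.L1Loss(reduction='sum')

# Legacy argument: rejected with a pointer to the equivalent reduction.
try:
    paddle.nn.L1Loss(size_average=False)
except ValueError as e:
    print(e)
# [Deprecated] 'L1Loss' no longer supports 'reduce' or 'size_average'.
# Detected: reduce=None, size_average=False
# Please use: reduction='sum' instead.
```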
137 changes: 137 additions & 0 deletions python/paddle/utils/decorator_utils.py
@@ -28,6 +28,7 @@

_InputT = ParamSpec("_InputT")
_RetT = TypeVar("_RetT")
_SENTINEL = object()


def _is_int_or_scalar_tensor(x):
@@ -736,3 +737,139 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
return wrapper

return decorator


_SA0_RD1 = {'size_average': 0, 'reduce': 1}
_SA1_RD2 = {'size_average': 1, 'reduce': 2}
_SA1_RD3 = {'size_average': 1, 'reduce': 3}
_SA3_RD4 = {'size_average': 3, 'reduce': 4}
_SA4_RD5 = {'size_average': 4, 'reduce': 5}
_SA2_RD4 = {'size_average': 2, 'reduce': 4}

LEGACY_POS: dict[str, dict[str, int]] = {
**dict.fromkeys(
(
'L1Loss',
'MSELoss',
'KLDivLoss',
'SmoothL1Loss',
'SoftMarginLoss',
'MultiLabelMarginLoss',
),
_SA0_RD1,
),
**dict.fromkeys(
(
'BCELoss',
'BCEWithLogitsLoss',
'MultiLabelSoftMarginLoss',
'HingeEmbeddingLoss',
'CosineEmbeddingLoss',
'MarginRankingLoss',
),
_SA1_RD2,
),
'CrossEntropyLoss': _SA1_RD3,
'NLLLoss': _SA1_RD3,
'PoissonNLLLoss': _SA2_RD4,
'MultiMarginLoss': _SA3_RD4,
'TripletMarginLoss': _SA4_RD5,
}
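The index pairs appear to mirror the legacy PyTorch-style positional orders of each signature, counted from the first argument after `self`. A small illustration of the lookup (the legacy orderings are inferred, not shown in this diff):

```python
# For a legacy-style call BCELoss(weight, size_average, reduce, ...):
pos = LEGACY_POS['BCELoss']    # {'size_average': 1, 'reduce': 2}
# A bool found at positional index 1 is read as size_average,
# and a bool at index 2 as reduce.
```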


def compute_legacy_reduction(reduce_val, size_average_val):
    # Legacy semantics: reduce=False forces 'none'; otherwise size_average=False
    # means 'sum' and anything else (True or not provided) means 'mean'.
    if reduce_val is False:
        return 'none'
    return 'sum' if size_average_val is False else 'mean'
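A few spot checks of the mapping; the assertions are illustrative, not part of the diff:

```python
assert compute_legacy_reduction(False, None) == 'none'      # reduce=False wins
assert compute_legacy_reduction(True, False) == 'sum'
assert compute_legacy_reduction(True, None) == 'mean'
assert compute_legacy_reduction(_SENTINEL, False) == 'sum'  # only size_average given
```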


def get_legacy_reduce_and_size_average(cls_name, args, kwargs):
    # _SENTINEL marks "not provided": a 'reduce'/'size_average' keyword, or a
    # bool sitting in the legacy positional slot, counts as legacy usage.
    # cls_name is assumed to be a key of LEGACY_POS.
    reduce_val = _SENTINEL
    size_avg_val = _SENTINEL
pos = LEGACY_POS.get(cls_name)
idx = pos.get('size_average')
if 'size_average' in kwargs:
size_avg_val = kwargs.pop('size_average')
elif len(args) > idx:
v = args[idx]
if type(v) is bool:
size_avg_val = v
idx = pos.get('reduce')
if 'reduce' in kwargs:
reduce_val = kwargs.pop('reduce')
elif len(args) > idx:
v = args[idx]
if type(v) is bool:
reduce_val = v
return reduce_val, size_avg_val
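For example, a mixed positional/keyword legacy call is picked apart like this (assuming `BCELoss(weight, size_average, reduce, ...)` as the legacy order):

```python
args = (None, False)         # weight=None, size_average=False (positional)
kwargs = {'reduce': True}
r, s = get_legacy_reduce_and_size_average('BCELoss', args, kwargs)
# r == True, s == False; 'reduce' has been popped from kwargs,
# and _SENTINEL would have signalled "not provided".
```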


def raise_deprecated_error(cls_name, reduce_val, size_avg_val):
suggested = compute_legacy_reduction(reduce_val, size_avg_val)
    reduce_val = None if reduce_val is _SENTINEL else reduce_val
    size_avg_val = None if size_avg_val is _SENTINEL else size_avg_val
raise ValueError(
f"[Deprecated] '{cls_name}' no longer supports 'reduce' or 'size_average'."
f"\nDetected: reduce={reduce_val}, size_average={size_avg_val}"
f"\nPlease use: reduction='{suggested}' instead."
)


def legacy_reduction_decorator(
init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
"""
Function decorator for __init__: intercept deprecated 'reduce' and 'size_average'.
"""

@functools.wraps(init_func)
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
        # Take the class name from the decorated __init__'s __qualname__ rather
        # than type(args[0]), so a subclass calling the parent's __init__ still
        # resolves to the class the decorator was applied to.
cls_name = init_func.__qualname__.split(".")[0]
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
cls_name, args[1:], kwargs
)
        if reduce_val is not _SENTINEL or size_avg_val is not _SENTINEL:
raise_deprecated_error(cls_name, reduce_val, size_avg_val)

return init_func(*args, **kwargs)

wrapper.__signature__ = inspect.signature(init_func)
return wrapper
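Applied to a loss `__init__`, the wrapper inspects `args[1:]` (skipping `self`) plus kwargs and rejects legacy usage before construction. A minimal sketch, assuming the legacy order `MSELoss(size_average, reduce)`:

```python
import paddle

# The keyword form is caught ...
try:
    paddle.nn.MSELoss(reduce=False)
except ValueError:
    pass  # suggests reduction='none'

# ... and so is a bool in the legacy positional slot.
try:
    paddle.nn.MSELoss(False)
except ValueError:
    pass  # suggests reduction='sum'
```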


def legacy_reduction_special_decorator(
init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
"""
Specialized decorator: add CrossEntropyLoss / KLDivLoss special case judgment
based on the general legacy_reduction_decorator logic.
"""

@functools.wraps(init_func)
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
cls_name = init_func.__qualname__.split(".")[0]
use_args = args[1:]
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
cls_name, use_args, kwargs
)
        if reduce_val is not _SENTINEL or size_avg_val is not _SENTINEL:
if not (
(
cls_name == 'CrossEntropyLoss'
and len(use_args) > 2
and use_args[2] in {'mean', 'sum', 'none'}
)
or (
cls_name == 'KLDivLoss'
and len(use_args) > 0
and use_args[0] in {'mean', 'sum', 'none', 'batchmean'}
)
):
raise_deprecated_error(cls_name, reduce_val, size_avg_val)
return init_func(*args, **kwargs)

wrapper.__signature__ = inspect.signature(init_func)
return wrapper
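The extra guard exists because a bool in these two signatures can be a legitimate new-style positional argument (`soft_label` for `CrossEntropyLoss`, `log_target` for `KLDivLoss`), which the generic index check would misread as a legacy `reduce`. A valid reduction string at the expected position disambiguates; a sketch of both outcomes:

```python
import paddle

# New-style positional call: the bool at index 1 is log_target, not `reduce`;
# the 'none' string at index 0 marks the call as new-style, so no error.
paddle.nn.KLDivLoss('none', True)

# A genuinely legacy call still raises.
try:
    paddle.nn.KLDivLoss(reduce=False)
except ValueError:
    pass  # suggests reduction='none'
```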
1 change: 1 addition & 0 deletions test/legacy_test/CMakeLists.txt
@@ -873,6 +873,7 @@ endif()

set_tests_properties(test_profiler PROPERTIES TIMEOUT 120)
set_tests_properties(test_cross_entropy_loss PROPERTIES TIMEOUT 180)
set_tests_properties(test_legacy_loss_args PROPERTIES TIMEOUT 10)
set_tests_properties(test_activation_nn_grad PROPERTIES TIMEOUT 250)
set_tests_properties(test_empty_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_elementwise_div_op PROPERTIES TIMEOUT 120)