diff --git a/segmentation_models_pytorch/losses/focal.py b/segmentation_models_pytorch/losses/focal.py index 0e055162..d26acb52 100644 --- a/segmentation_models_pytorch/losses/focal.py +++ b/segmentation_models_pytorch/losses/focal.py @@ -45,6 +45,7 @@ def __init__( self.mode = mode self.ignore_index = ignore_index + self.reduction = reduction self.focal_loss_fn = partial( focal_loss_with_logits, alpha=alpha,