
EarlyStopping with list of metrics to monitor #21001

Open
fabriciorsf opened this issue Mar 7, 2025 · 13 comments
Assignees
Labels
type:feature The user is asking for a new feature.

Comments

@fabriciorsf

In addition to this example:
callback = keras.callbacks.EarlyStopping(monitor='val_loss')

Allow monitoring of multiple metrics, as in this example:
callback = keras.callbacks.EarlyStopping(monitor=['val_loss', 'val_accuracy', 'val_f1measure'])

This way, training should continue as long as any of these metrics is still improving, and stop only when none of them improves — not when a single one stalls.
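A rough sketch of the requested semantics, with the patience logic extracted into plain Python so it can be read in isolation (`MultiMetricPatience` and its interface are illustrative, not an existing Keras API):

```python
class MultiMetricPatience:
    """Track several metrics; signal a stop only when *none* of them
    has improved for `patience` consecutive epochs."""

    def __init__(self, monitors, patience=3, modes=None):
        self.monitors = list(monitors)
        self.patience = patience
        # "min" metrics (e.g. val_loss) improve downward,
        # "max" metrics (e.g. val_accuracy) improve upward
        self.modes = modes or ["min"] * len(self.monitors)
        self.best = {m: (float("inf") if mode == "min" else -float("inf"))
                     for m, mode in zip(self.monitors, self.modes)}
        self.wait = 0

    def update(self, logs):
        """Feed one epoch's logs; return True when training should stop."""
        improved = False
        for m, mode in zip(self.monitors, self.modes):
            value = logs[m]
            if (mode == "min" and value < self.best[m]) or \
               (mode == "max" and value > self.best[m]):
                self.best[m] = value
                improved = True
        self.wait = 0 if improved else self.wait + 1
        return self.wait >= self.patience


# Any single improving metric resets the patience counter:
tracker = MultiMetricPatience(["val_loss", "val_acc"],
                              patience=2, modes=["min", "max"])
assert tracker.update({"val_loss": 1.0, "val_acc": 0.5}) is False  # both improve
assert tracker.update({"val_loss": 1.2, "val_acc": 0.6}) is False  # acc improves
assert tracker.update({"val_loss": 1.2, "val_acc": 0.6}) is False  # wait = 1
assert tracker.update({"val_loss": 1.2, "val_acc": 0.6}) is True   # wait = 2: stop
```

Wiring this into a `keras.callbacks.Callback` would just mean calling `update(logs)` from `on_epoch_end` and setting `self.model.stop_training` on a True result.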

@rlcauvin

rlcauvin commented Mar 8, 2025

One factor to consider when using multiple metrics for early stopping: how do you determine which epoch is "best" and should be used for restoring best weights (when restore_best_weights is True) after early stopping?

@fabriciorsf
Author

The best epoch would be the one with the best val_loss value, but patience should consider not just one metric, but the whole list of metrics passed in the monitor parameter.
Another possibility would be to pick the best epoch by the value of the first metric in the list passed to the monitor parameter and, in case of a tie, by the second, and so on.
Thank you for your quick response.
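The lexicographic tie-break proposed above could be sketched like this (`best_epoch` and its signature are hypothetical, not part of Keras):

```python
def best_epoch(history, monitors, modes):
    """Pick the best epoch by comparing the first monitored metric,
    breaking ties with the second, and so on. `history` holds one
    logs dict per epoch; `modes` is "min" or "max" per metric."""
    def key(logs):
        # Negate "max" metrics so a smaller tuple is always better;
        # tuple comparison then gives the lexicographic tie-break.
        return tuple(logs[m] if mode == "min" else -logs[m]
                     for m, mode in zip(monitors, modes))
    return min(range(len(history)), key=lambda i: key(history[i]))


history = [
    {"val_loss": 0.50, "val_accuracy": 0.80},
    {"val_loss": 0.50, "val_accuracy": 0.90},  # ties on loss, wins on accuracy
    {"val_loss": 0.60, "val_accuracy": 0.99},
]
assert best_epoch(history, ["val_loss", "val_accuracy"], ["min", "max"]) == 1
```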

@rlcauvin

rlcauvin commented Mar 10, 2025

Another consideration is that we want to minimize some metrics (e.g. val_loss) but maximize other metrics (e.g. val_accuracy).

It might be instructive to look at this implementation of a composite of early stopping callbacks. The iterable_condition enables the caller to determine whether all, any, or some other combination of early stopping conditions must hold. It assumes the last early stopping callback is the "conductor", which means it determines which epoch is "best".

from typing import Callable, Dict, Iterable

import keras

class CompoundEarlyStopping(keras.callbacks.Callback):
  def __init__(
    self,
    callbacks: Iterable[keras.callbacks.Callback],
    iterable_condition: Callable[[Iterable[bool]], bool] = all):
    
    super().__init__()
    
    self.callbacks = callbacks
    self.stopped_epoch = 0
    self.iterable_condition = iterable_condition
    
  def on_train_begin(
    self,
    logs: Dict = None):
    
    for callback in self.callbacks:
      callback.on_train_begin(logs)

  def on_train_end(
    self,
    logs: Dict = None):
    
    if self.model.stop_training:
      # The last callback is the "conductor": it decides the best epoch
      # and restores weights. next(reversed(...)) avoids subscripting,
      # which not every Iterable (e.g. dict_values) supports.
      conductor = next(reversed(self.callbacks))
      conductor.stopped_epoch = self.stopped_epoch
      conductor.on_train_end(logs)

  def on_epoch_begin(
    self,
    epoch: int,
    logs: Dict = None):
    
    for callback in self.callbacks:
      callback.on_epoch_begin(epoch, logs)

  def on_epoch_end(
    self,
    epoch: int,
    logs: Dict = None):
    
    for callback in self.callbacks:
      callback.on_epoch_end(epoch, logs)
    
    # Stop only when the combination (all, any, ...) of the individual
    # stopping conditions holds for the current epoch.
    if self.iterable_condition(
        [callback.stopped_epoch >= max(1, epoch) for callback in self.callbacks]):
      self.stopped_epoch = epoch
      self.model.stop_training = True
    else:
      self.model.stop_training = False

  def set_model(
    self,
    model: keras.Model):
    
    super().set_model(model)
    
    for callback in self.callbacks:
      callback.set_model(model)

Here is an example of how I've used the CompoundEarlyStopping class:

patience = 3

early_stopping_loss = keras.callbacks.EarlyStopping(monitor="val_loss", patience=patience, mode="min", verbose=1)

early_stopping_accuracy = keras.callbacks.EarlyStopping(
  monitor="val_accuracy",
  patience=patience,
  mode="max",
  verbose=1)

early_stopping_auc = keras.callbacks.EarlyStopping(
  monitor="val_auc",
  patience=patience,
  mode="max",
  restore_best_weights=True,
  verbose=1)

early_stopping = CompoundEarlyStopping(
  callbacks=[early_stopping_loss, early_stopping_accuracy, early_stopping_auc],
  iterable_condition=all)

@dhantule dhantule added type:feature The user is asking for a new feature. keras-team-review-pending Pending review by a Keras team member. labels Mar 11, 2025
@SamanehSaadat
Member

Hi @fabriciorsf

As @rlcauvin mentioned, you can achieve this by creating a custom callback.
Closing this issue. Please feel free to re-open if that's not what you're looking for.

@fabriciorsf
Author

fabriciorsf commented Mar 12, 2025

I tried this solution, but I had some problems:

  1. Intermittently, at the end of the first epoch, the fit method freezes before computing the validation metrics;
  2. When the first problem does not occur, the following error occurs:
...
Epoch 5/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 98ms/step - loss: 1.3509 - sumsqeuc_dist: 1.3499
Epoch 5: saving model to ./saved_models/my_model_epoch_005.keras
23/23 ━━━━━━━━━━━━━━━━━━━━ 5s 185ms/step - loss: 1.3383 - sumsqeuc_dist: 1.3369 - val_loss: 0.7643 - val_sumsqeuc_dist: 0.7621
Traceback (most recent call last):
Traceback (most recent call last):
...  
File "myscript.py", line 293, in train_autonem
    history = self.my_model.fit(self.train_inputs,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "....../python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "..../mypackage/utils/keras_extensions.py", line 268, in on_train_end
    conductor = self.callbacks[-1]
                ~~~~~~~~~~~~~~^^^^
TypeError: 'dict_values' object is not subscriptable

The class CompoundEarlyStopping is in file keras_extensions.py:

class CompoundEarlyStopping(keras.callbacks.Callback):
    def __init__(
        self, callbacks: Iterable[Callback],
        iterable_condition: Callable[[], bool] = all):

        super().__init__()

        self.callbacks = callbacks
        self.stopped_epoch = 0
        self.iterable_condition = iterable_condition
    
    def on_train_begin(self, logs: dict = None):
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: dict = None):
        if self.model.stop_training:
            conductor = self.callbacks[-1] ## intermittent ERROR here
            conductor.stopped_epoch = self.stopped_epoch
            conductor.on_train_end(logs)

    def on_epoch_begin(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)
        if self.iterable_condition([callback.stopped_epoch >= max(1, epoch) \
                                    for callback in self.callbacks]):
            self.stopped_epoch = epoch
            self.model.stop_training = True
        else:
            self.model.stop_training = False

    def set_model(self, model: Model):
        super().set_model(model)
        for callback in self.callbacks:
            callback.set_model(model)

Note: There is an error with the annotation Callable[Iterable, bool], so I replaced it with Callable[[], bool]
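For reference, the annotation that type-checks here is `Callable[[Iterable[bool]], bool]`, since `Callable`'s first subscript must be a list of parameter types (the runtime behavior is the same either way; this is only about the annotation):

```python
from typing import Callable, Iterable

# A predicate taking one iterable of booleans and returning a boolean:
iterable_condition: Callable[[Iterable[bool]], bool] = all

assert iterable_condition([True, True]) is True
assert iterable_condition([True, False]) is False
```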

I instantiate the callbacks like this:

dict_early_stopping = {val_metric_name: EarlyStopping(monitor=val_metric_name, patience=5, 
                                                      start_from_epoch=1, verbose=self.verbose) \
                       for val_metric_name in val_metrics_name}

early_stopping = CompoundEarlyStopping(
    callbacks=dict_early_stopping.values(),
    iterable_condition=all)

As a baseline, doing early_stopping = dict_early_stopping[val_metrics_name[0]] works fine.

@rlcauvin

@fabriciorsf The original implementation of CompoundEarlyStopping.on_train_end incorrectly assumed that self.callbacks is subscriptable. In your code, dict_early_stopping.values() does not produce a subscriptable result.

I have edited the CompoundEarlyStopping code in my earlier comment to work with instances of Iterable that are not subscriptable.

I changed

conductor = self.callbacks[-1]

to

conductor = next(reversed(self.callbacks))

Let us know if it works for you.
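As a plain-Python illustration of the difference (not Keras-specific): dict views reject indexing but, since Python 3.8, support `reversed()`; converting with `list()` is the most general option for arbitrary iterables.

```python
d = {"val_loss": 1, "val_accuracy": 2, "val_auc": 3}
values = d.values()

# dict_values supports iteration but not indexing:
try:
    values[-1]
except TypeError:
    pass  # 'dict_values' object is not subscriptable

# Since Python 3.8, dict views are reversible:
assert next(reversed(values)) == 3

# list() works for any iterable, reversible or not:
assert list(values)[-1] == 3
```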

@fabriciorsf
Author

I tested it, and the second problem doesn't occur, but the first problem remains:

Epoch 1/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 100ms/step - loss: 1.3480 - sumsqeuc_dist: 1.3470

And it stays frozen like that.

@fabriciorsf
Author

fabriciorsf commented Mar 12, 2025

I also tested without multithreading and everything worked fine.

So I suspect there is some problem with the CompoundEarlyStopping when training with multithreading.

For the record: without the CompoundEarlyStopping, the training with multiprocessing works fine.

@rlcauvin

@fabriciorsf What do you mean by "multithreading" in this context?

@fabriciorsf
Author

fabriciorsf commented Mar 12, 2025

It means using a PyDataset instance to load data during fit with use_multiprocessing=True and workers > 1.
Again: without the CompoundEarlyStopping, the training with multiprocessing works fine.

Ctrl+C while freezing:

Epoch 1/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 100ms/step - loss: 1.3480 - sumsqeuc_dist: 1.3470^CProcess Keras_worker_ForkPoolWorker-28:
Process Keras_worker_ForkPoolWorker-36:
Process Keras_worker_ForkPoolWorker-39:
Process Keras_worker_ForkPoolWorker-25:
Process Keras_worker_ForkPoolWorker-35:
Process Keras_worker_ForkPoolWorker-29:
Process Keras_worker_ForkPoolWorker-31:
Process Keras_worker_ForkPoolWorker-26:
Process Keras_worker_ForkPoolWorker-32:
Process Keras_worker_ForkPoolWorker-37:
Traceback (most recent call last):
  File ".........../myscript.py", line 853, in <module>
Process Keras_worker_ForkPoolWorker-23:
Process Keras_worker_ForkPoolWorker-27:
Process Keras_worker_ForkPoolWorker-30:
Process Keras_worker_ForkPoolWorker-40:
Process Keras_worker_ForkPoolWorker-34:
Process Keras_worker_ForkPoolWorker-41:
Process Keras_worker_ForkPoolWorker-33:
Process Keras_worker_ForkPoolWorker-44:
Process Keras_worker_ForkPoolWorker-24:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "....../python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "....../python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "....../python3.12/multiprocessing/pool.py", line 114, in worker
    task = get()
           ^^^^^
  File "....../python3.12/multiprocessing/queues.py", line 386, in get
    with self._rlock:
  File "....../python3.12/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
.....

@rlcauvin

@fabriciorsf Thanks for clarifying how you're doing multiprocessing. I guess the CompoundEarlyStopping class is not thread safe. Perhaps someone else can suggest how to make it thread safe, whether by using threading.Lock(), threading.local(), or some other means.
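One defensive tweak worth trying (an assumption on my part, not a confirmed fix for the hang): materialize the iterable once in `__init__`, so the callbacks can be indexed and re-iterated safely no matter what the caller passes. `CompoundEarlyStoppingSafe` is a hypothetical name; only the `list()` call differs from the code above.

```python
class CompoundEarlyStoppingSafe:
    def __init__(self, callbacks):
        # list() copies the iterable once, so dict.values(), generators,
        # and lists all behave identically from here on
        self.callbacks = list(callbacks)

    def conductor(self):
        # subscripting is now always safe
        return self.callbacks[-1]


# Works even when the caller passes a dict view:
stopper = CompoundEarlyStoppingSafe({"a": "cb1", "b": "cb2"}.values())
assert stopper.conductor() == "cb2"
```

Note this only removes the re-iteration/subscription hazard; whether it helps with the multiprocessing freeze is a separate question.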

@fabriciorsf
Author

> Hi @fabriciorsf
>
> As @rlcauvin mentioned, you can achieve this by creating a custom callback. Closing this issue. Please feel free to re-open if that's not what you're looking for.

Hi @SamanehSaadat, I don't have permissions to reopen this issue.

@SamanehSaadat SamanehSaadat reopened this Mar 12, 2025
@SamanehSaadat
Member

Have you tested switching to the JAX backend? (JAX is thread-safe, but TensorFlow is not.)

@SamanehSaadat SamanehSaadat self-assigned this Mar 17, 2025
@SamanehSaadat SamanehSaadat removed the keras-team-review-pending Pending review by a Keras team member. label Mar 17, 2025