EarlyStopping with list of metrics to monitor #21001
Comments
One factor to consider when using multiple metrics for early stopping: how do you determine which epoch is "best" and should be used for restoring best weights (when restore_best_weights=True)?

The best epoch will be the one with the best value of the monitored metric.
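One way to make that concrete is to designate a single primary metric for choosing the "best" epoch, even when several metrics gate early stopping. A minimal sketch (the helper name, its arguments, and their defaults are mine, not from this thread):

import numpy as np

def best_epoch(history, primary="val_loss", mode="min"):
    # Index of the epoch with the best value of one designated metric,
    # read from a keras History object returned by model.fit().
    values = np.asarray(history.history[primary])
    return int(np.argmin(values)) if mode == "min" else int(np.argmax(values))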
Another consideration is that we want to minimize some metrics (e.g. loss) and maximize others (e.g. accuracy). It might be instructive to look at this implementation of a composite of early stopping callbacks; I've used the composite approach in my own training.
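As a rough illustration of how such a composite might be wired up (a sketch only, not the original example from this comment; it assumes the CompoundEarlyStopping class posted later in this thread and invented metric names):

import keras

monitors = {"val_loss": "min", "val_accuracy": "max"}  # assumed metrics
stoppers = [
    keras.callbacks.EarlyStopping(monitor=name, mode=mode, patience=5)
    for name, mode in monitors.items()
]
composite = CompoundEarlyStopping(callbacks=stoppers, iterable_condition=all)
# model.fit(x_train, y_train, validation_data=(x_val, y_val), callbacks=[composite])

With iterable_condition=all, training stops only once every wrapped callback has stopped improving.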
Hi @fabriciorsf! As @rlcauvin mentioned, you can achieve this by creating a custom callback.
I tried this solution, but I had some problems:
...
Epoch 5/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 98ms/step - loss: 1.3509 - sumsqeuc_dist: 1.3499
Epoch 5: saving model to ./saved_models/my_model_epoch_005.keras
23/23 ━━━━━━━━━━━━━━━━━━━━ 5s 185ms/step - loss: 1.3383 - sumsqeuc_dist: 1.3369 - val_loss: 0.7643 - val_sumsqeuc_dist: 0.7621
Traceback (most recent call last):
Traceback (most recent call last):
...
File "myscript.py", line 293, in train_autonem
history = self.my_model.fit(self.train_inputs,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "....../python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "..../mypackage/utils/keras_extensions.py", line 268, in on_train_end
conductor = self.callbacks[-1]
~~~~~~~~~~~~~~^^^^
TypeError: 'dict_values' object is not subscriptable

The class CompoundEarlyStopping:

from typing import Callable, Iterable

import keras
from keras import Model
from keras.callbacks import Callback, EarlyStopping


class CompoundEarlyStopping(keras.callbacks.Callback):

    def __init__(
            self, callbacks: Iterable[Callback],
            iterable_condition: Callable[[], bool] = all):
        super().__init__()
        self.callbacks = callbacks
        self.stopped_epoch = 0
        self.iterable_condition = iterable_condition

    def on_train_begin(self, logs: dict = None):
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs: dict = None):
        if self.model.stop_training:
            conductor = self.callbacks[-1]  ## intermittent ERROR here
            conductor.stopped_epoch = self.stopped_epoch
            conductor.on_train_end(logs)

    def on_epoch_begin(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)
        if self.iterable_condition([callback.stopped_epoch >= max(1, epoch)
                                    for callback in self.callbacks]):
            self.stopped_epoch = epoch
            self.model.stop_training = True
        else:
            self.model.stop_training = False

    def set_model(self, model: Model):
        super().set_model(model)
        for callback in self.callbacks:
            callback.set_model(model)

Note: there is an intermittent error at self.callbacks[-1] in on_train_end (see the traceback above).

I instantiate the callbacks like this:

dict_early_stopping = {val_metric_name: EarlyStopping(monitor=val_metric_name, patience=5,
                                                      start_from_epoch=1, verbose=self.verbose)
                       for val_metric_name in val_metrics_name}
early_stopping = CompoundEarlyStopping(
    callbacks=dict_early_stopping.values(),
    iterable_condition=all)
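For reference, dict_values is a view object and does not support indexing, which is exactly what the traceback above reports. A small demonstration:

d = {"a": 1, "b": 2}
vals = d.values()
vals[-1]        # TypeError: 'dict_values' object is not subscriptable
list(vals)[-1]  # works: returns 2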
@fabriciorsf The original implementation of the composite callback indexes into self.callbacks (self.callbacks[-1]), which fails when callbacks is a dict_values view. I have edited the code above and changed that part.

Let us know if it works for you.
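The exact edit is not preserved in this thread; one minimal change that would remove the indexing error (an assumption on my part, not necessarily the change that was actually made) is to materialize the iterable into a list in __init__:

    def __init__(self, callbacks: Iterable[Callback],
                 iterable_condition: Callable[[], bool] = all):
        super().__init__()
        self.callbacks = list(callbacks)  # hypothetical change: was self.callbacks = callbacks
        self.stopped_epoch = 0
        self.iterable_condition = iterable_condition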
I tested it, and the second problem doesn't occur, but the first problem remains:

Epoch 1/100
23/23 ━━━━━━━━━━━━━━━━━━━━ 0s 100ms/step - loss: 1.3480 - sumsqeuc_dist: 1.3470

And it stays frozen like that.
I also tested without multithreading and everything worked fine, so I suspect the problem is related to running the training with multithreading.
@fabriciorsf What do you mean by "multithreading" in this context?
It means running the training in a pool of worker processes; I had to press Ctrl+C to interrupt it while it was frozen.
@fabriciorsf Thanks for clarifying how you're doing multiprocessing.
Hi @SamanehSaadat, I don't have permissions to reopen this issue.
Have you tested switching to the JAX backend? (JAX is thread-safe, but TensorFlow is not.)
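For reference, in Keras 3 the backend is selected with the KERAS_BACKEND environment variable, which must be set before Keras is imported:

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras
print(keras.backend.backend())  # prints "jax"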
In addition to this example:
callback = keras.callbacks.EarlyStopping(monitor='val_loss')
Allow monitoring of multiple metrics, as in this example:
callback = keras.callbacks.EarlyStopping(monitor=['val_loss', 'val_accuracy', 'val_f1measure'])
This way, training would not stop as long as at least one of these metrics is still improving, rather than stopping as soon as a single monitored metric stalls.
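For context on why simply passing several independent EarlyStopping callbacks is not enough: training stops as soon as any one of them triggers, which is the opposite of the requested behaviour. A sketch (metric names assumed):

import keras

# Two independent callbacks: training stops when EITHER metric stalls,
# because any callback setting model.stop_training = True ends fit().
callbacks = [
    keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5),
    keras.callbacks.EarlyStopping(monitor="val_accuracy", mode="max", patience=5),
]
# model.fit(x_train, y_train, validation_data=(x_val, y_val), callbacks=callbacks)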