[Minor] Enable continuation of training #1605
base: main
Changes from 6 commits
```diff
@@ -978,7 +978,7 @@ def fit(
         pd.DataFrame
             metrics with training and potentially evaluation metrics
         """
-        if self.fitted:
+        if self.fitted and not continue_training:
             raise RuntimeError("Model has been fitted already. Please initialize a new model to fit again.")

         # Configuration
```
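For orientation, this is the call pattern the relaxed guard is meant to unblock; a minimal sketch assuming the `fit()` flags shown in this diff (`checkpointing`, `continue_training`) and a toy DataFrame, not a released API:

```python
# Hypothetical usage sketch for this PR; flag names follow the diff, not a released API.
import pandas as pd
from neuralprophet import NeuralProphet

df = pd.DataFrame({"ds": pd.date_range("2024-01-01", periods=100, freq="D"), "y": range(100)})

m = NeuralProphet(epochs=10)
# First fit: enable checkpointing so there is a checkpoint to resume from later.
metrics_first = m.fit(df, freq="D", checkpointing=True)

# Before this change, a second fit() on a fitted model raised a RuntimeError.
# With continue_training=True the guard is skipped and training resumes.
metrics_more = m.fit(df, freq="D", checkpointing=True, continue_training=True)
```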
```diff
@@ -1067,6 +1067,10 @@ def fit(

         if self.fitted is True and not continue_training:
             log.error("Model has already been fitted. Re-fitting may break or produce different results.")

+        if continue_training and self.metrics_logger.checkpoint_path is None:
+            log.error("Continued training requires checkpointing in model.")

         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
```

Review thread on the added checkpoint requirement:

Comment: Can you please explain what necessitates this (for my understanding)?

Reply: My thinking was that it makes sense to continue from the checkpoint, but it is probably not necessary. All the necessary parameters should still be available in the model itself. I will adapt it.

Reply: I did some testing, and it seems the checkpoint is indeed necessary to correctly continue training with the PyTorch Lightning trainer. I can get it to run without checkpointing, but fitting again always leads to a complete restart of the training. Maybe there is some workaround, but I would suggest keeping it like this, as continued training always goes hand in hand with checkpointing in PyTorch Lightning.
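As background for the thread above: with the PyTorch Lightning trainer, resuming is tied to a saved checkpoint, and calling `fit()` again without one starts training from scratch. A generic Lightning sketch of that behaviour, independent of NeuralProphet (the tiny module, data, and paths are made up for illustration):

```python
# Generic PyTorch Lightning resume sketch (not NeuralProphet code).
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):
    """Minimal LightningModule used only to illustrate checkpoint-based resume."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())


x = torch.randn(64, 1)
data = DataLoader(TensorDataset(x, 2 * x), batch_size=16)

ckpt_cb = ModelCheckpoint(dirpath="checkpoints/", save_last=True)
model = TinyModule()

trainer = pl.Trainer(max_epochs=5, callbacks=[ckpt_cb])
trainer.fit(model, data)  # first round of training; writes checkpoints/last.ckpt

# Without ckpt_path a second fit() restarts training from scratch; resuming from
# the checkpoint restores weights, optimizer state, and the epoch counter.
trainer = pl.Trainer(max_epochs=10, callbacks=[ckpt_cb])
trainer.fit(model, data, ckpt_path="checkpoints/last.ckpt")
```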
```diff
@@ -2666,23 +2670,24 @@ def _init_train_loader(self, df, num_workers=0):
         torch DataLoader
         """
         df, _, _, _ = df_utils.prep_or_copy_df(df)  # TODO: Can this call be avoided?
-        # if not self.fitted:
-        self.config_normalization.init_data_params(
-            df=df,
-            config_lagged_regressors=self.config_lagged_regressors,
-            config_regressors=self.config_regressors,
-            config_events=self.config_events,
-            config_seasonality=self.config_seasonality,
-        )
+        if not self.fitted:
+            self.config_normalization.init_data_params(
+                df=df,
+                config_lagged_regressors=self.config_lagged_regressors,
+                config_regressors=self.config_regressors,
+                config_events=self.config_events,
+                config_seasonality=self.config_seasonality,
+            )

+        print("Changepoints:", self.config_trend.changepoints)
         df = _normalize(df=df, config_normalization=self.config_normalization)
-        # if not self.fitted:
-        if self.config_trend.changepoints is not None:
-            # scale user-specified changepoint times
-            df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})
+        if not self.fitted:
+            if self.config_trend.changepoints is not None:
+                # scale user-specified changepoint times
+                df_aux = pd.DataFrame({"ds": pd.Series(self.config_trend.changepoints)})

-            df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
-            self.config_trend.changepoints = df_normalized["t"].values  # type: ignore
+                df_normalized = _normalize(df=df_aux, config_normalization=self.config_normalization)
+                self.config_trend.changepoints = df_normalized["t"].values  # type: ignore

         # df_merged, _ = df_utils.join_dataframes(df)
         # df_merged = df_merged.sort_values("ds")
```
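The `if not self.fitted:` guards above keep data-dependent quantities (normalization parameters, scaled changepoints) from being recomputed when training continues, so the resumed model sees data scaled exactly as on the first fit. A standalone sketch of that pattern, using hypothetical names rather than NeuralProphet internals:

```python
# Standalone sketch of the "initialize once, reuse on continued fit" pattern.
# SimpleScaler and ContinuableModel are hypothetical, not NeuralProphet API.
import pandas as pd


class SimpleScaler:
    def fit(self, y: pd.Series) -> None:
        self.mean, self.std = float(y.mean()), float(y.std() or 1.0)

    def transform(self, y: pd.Series) -> pd.Series:
        return (y - self.mean) / self.std


class ContinuableModel:
    def __init__(self):
        self.fitted = False
        self.scaler = SimpleScaler()

    def _prepare(self, df: pd.DataFrame) -> pd.DataFrame:
        # Estimate data-dependent parameters only on the first fit; a continued
        # fit must reuse them, otherwise the model sees differently scaled data.
        if not self.fitted:
            self.scaler.fit(df["y"])
        out = df.copy()
        out["y"] = self.scaler.transform(df["y"])
        return out

    def fit(self, df: pd.DataFrame, continue_training: bool = False) -> None:
        df_norm = self._prepare(df)
        # ... training loop would go here ...
        self.fitted = True
```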
```diff
@@ -2770,12 +2775,36 @@ def _train(
         # Internal flag to check if validation is enabled
         validation_enabled = df_val is not None

-        # Init the model, if not continue from checkpoint
+        # Load model and optimizer state from checkpoint if continue_training is True
         if continue_training:
-            raise NotImplementedError(
-                "Continuing training from checkpoint is not implemented yet. This feature is planned for one of the \
-                upcoming releases."
-            )
+            checkpoint_path = self.metrics_logger.checkpoint_path
+            checkpoint = torch.load(checkpoint_path)
+
+            # Load model state
+            self.model.load_state_dict(checkpoint["state_dict"])
+
+            # Set continue_training flag in model to update scheduler correctly
+            self.model.continue_training = True
+
+            previous_epoch = checkpoint["epoch"]
+            # Adjust epochs
+            if self.config_train.epochs:
+                additional_epochs = self.config_train.epochs
+            else:
+                additional_epochs = previous_epoch
+
+            # Get the number of epochs already trained
+            new_total_epochs = previous_epoch + additional_epochs
+            self.config_train.epochs = new_total_epochs
+
+            # Reinitialize optimizer with loaded model parameters
+            optimizer = torch.optim.AdamW(self.model.parameters())
+
+            # Load optimizer state
+            if "optimizer_states" in checkpoint and checkpoint["optimizer_states"]:
+                optimizer.load_state_dict(checkpoint["optimizer_states"][0])
+
+            self.config_train.optimizer = optimizer
+
         else:
             self.model = self._init_model()
```

weberpals marked two review conversations on this block as resolved.
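The keys read above (`state_dict`, `epoch`, `optimizer_states`) correspond to what Lightning stores in a checkpoint file. A self-contained sketch of the same restore-and-extend logic, using an in-memory stand-in for the loaded checkpoint (the tiny Linear model and the epoch numbers are made up):

```python
# Self-contained sketch of the restore logic in the diff above; the dict keys
# mirror what Lightning writes to .ckpt files.
import torch

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.AdamW(model.parameters())

# Stand-in for torch.load(self.metrics_logger.checkpoint_path)
checkpoint = {
    "state_dict": model.state_dict(),
    "epoch": 10,
    "optimizer_states": [optimizer.state_dict()],
}

# Restore model weights and optimizer state from the checkpoint dict.
restored = torch.nn.Linear(1, 1)
restored.load_state_dict(checkpoint["state_dict"])

new_optimizer = torch.optim.AdamW(restored.parameters())
if "optimizer_states" in checkpoint and checkpoint["optimizer_states"]:
    new_optimizer.load_state_dict(checkpoint["optimizer_states"][0])

# Extend the schedule the same way the diff does: train for additional_epochs
# on top of what the checkpoint already covers.
previous_epoch = checkpoint["epoch"]
additional_epochs = 10
new_total_epochs = previous_epoch + additional_epochs
print(new_total_epochs)  # 20
```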
```diff
@@ -2859,8 +2888,12 @@ def _train(

         if not metrics_enabled:
             return None

         # Return metrics collected in logger as dataframe
-        metrics_df = pd.DataFrame(self.metrics_logger.history)
+        if self.metrics_logger.history is not None:
+            metrics_df = pd.DataFrame(self.metrics_logger.history)
+        else:
+            metrics_df = pd.DataFrame()
         return metrics_df

     def restore_trainer(self, accelerator: Optional[str] = None):
```
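The added guard simply returns an empty DataFrame when the logger has no history yet instead of failing; a tiny sketch of the same pattern (the metric values are made up):

```python
# Sketch of the None-guard on the metrics history; example values are made up.
import pandas as pd

history = [
    {"epoch": 0, "Loss": 0.9, "MAE": 1.2},
    {"epoch": 1, "Loss": 0.7, "MAE": 1.0},
]

metrics_df = pd.DataFrame(history) if history is not None else pd.DataFrame()
print(metrics_df)
```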