Labels: bug, needs triage, ver: 2.5.x
Bug description
I got the following error when trying to restore the training state from a checkpoint saved during a training step. However, when I switch to the default progress bar (tqdm), I can resume training normally from the same checkpoint file (a sketch of the swap follows the traceback).
```
.../python3.10/site-packages/pytorch_lightning/callbacks/progress/rich_progress.py", line 452, in _update
    assert progress_bar_id is not None
AssertionError
```
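For reference, this is the swap that avoids the error for me. A minimal sketch, assuming the rest of the Trainer setup is unchanged; the `refresh_rate` value is illustrative:

```python
# Hedged workaround sketch: resume with the built-in tqdm bar instead of
# the custom Rich bar; this restores from the checkpoint without the assertion.
from pytorch_lightning.callbacks import TQDMProgressBar

progress_bar = TQDMProgressBar(refresh_rate=1)  # value is illustrative
```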
This is my configuration of the training model checkpoint:
```python
from copy import deepcopy
from datetime import timedelta
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor
from typing_extensions import override


class TrainingCheckpoint(ModelCheckpoint):
    def __init__(
        self,
        last_n: int = 1,
        every_n_minutes: Optional[int] = None,
        every_n_iterations: Optional[int] = None,
        last_name: str = "training-last",
    ):
        monitor = "totalsteps"
        super().__init__(
            filename="training-{epoch}-{step}-{monitor}".replace("monitor", monitor),
            monitor=monitor,
            mode="max",
            save_last=True,
            save_top_k=last_n,
            every_n_train_steps=every_n_iterations,
            save_weights_only=False,
            # `timedelta(minutes=None)` raises a TypeError, so only build the
            # interval when a value is given.
            train_time_interval=timedelta(minutes=every_n_minutes) if every_n_minutes is not None else None,
            save_on_train_epoch_end=False,
        )
        self.monitor = None
        self.CHECKPOINT_NAME_LAST = last_name

    @override
    def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]:
        monitor_candidates = deepcopy(trainer.callback_metrics)
        # Cast to int if necessary because `self.log("epoch", 123)` converts it
        # to float. If it's not a tensor or does not exist, we overwrite it, as
        # that is likely an error.
        epoch = monitor_candidates.get("epoch")
        monitor_candidates["epoch"] = epoch.int() if isinstance(epoch, Tensor) else torch.tensor(trainer.current_epoch)
        step = monitor_candidates.get("step")
        monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
        # `self.monitor` was reset to None above, so key the synthetic metric by
        # its name explicitly; this fills the `{totalsteps}` filename field.
        monitor_candidates["totalsteps"] = torch.tensor(trainer.global_step)
        return monitor_candidates

    @override
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        pass
```
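For context, this is how I instantiate it; the values here are illustrative, not my real settings:

```python
# Illustrative values: keep the two most recent checkpoints, saving every
# 500 optimizer steps or every 30 minutes, whichever comes first.
checkpoint_callback = TrainingCheckpoint(
    last_n=2,
    every_n_minutes=30,
    every_n_iterations=500,
)
```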
This is my configuration of the Rich progress bar:
```python
from typing import Any, Optional

from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar, RichProgressBarTheme


class CustomRichProgressBar(RichProgressBar):
    def __init__(self):
        refresh_rate: int = 1
        leave: bool = False
        theme: RichProgressBarTheme = RichProgressBarTheme(
            description="green_yellow",
            progress_bar="green1",
            progress_bar_finished="green1",
            progress_bar_pulse="#6206E0",
            batch_progress="green_yellow",
            time="grey82",
            processing_speed="grey82",
            metrics="grey82",
            metrics_text_delimiter="\n",
            metrics_format=".4f",
        )
        console_kwargs: Optional[dict[str, Any]] = None
        super().__init__(refresh_rate, leave, theme, console_kwargs)
```
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
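A minimal sketch of the resume flow that triggers the error; `BoringModel` and the checkpoint path are illustrative stand-ins for my real model and checkpoint:

```python
# Hedged reproduction sketch: BoringModel and ckpt_path stand in for the
# real training setup.
import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = pl.Trainer(
    max_steps=1000,
    callbacks=[TrainingCheckpoint(every_n_iterations=100), CustomRichProgressBar()],
)
# Resuming from a checkpoint saved mid-training raises the AssertionError in
# rich_progress.py; with TQDMProgressBar the same call completes normally.
trainer.fit(model, ckpt_path="checkpoints/training-last.ckpt")
```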
Error messages and logs
See the traceback in the bug description above.
Environment
Current environment
```
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
```
More info
No response