
RichProgressBar: progress_bar_id is None when restoring training state from a step checkpoint #21015

@5o1


Bug description

I get the following error when I try to restore the training state from a checkpoint that was saved during a training step.

However, when I switch to the default progress bar (tqdm), I can resume training from the same checkpoint file normally.

.../python3.10/site-packages/pytorch_lightning/callbacks/progress/rich_progress.py", line 452, in _update
    assert progress_bar_id is not None
AssertionError
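The assertion fires when _update receives a progress bar task id that is still None, meaning an update was issued for a task that was never created (or had been reset). The toy model below is not Lightning's implementation: ToyProgress is a hypothetical class, assuming task ids are only assigned by an epoch-start hook, that reproduces the same kind of guard.

```python
from typing import Dict, Optional


class ToyProgress:
    """Hypothetical stand-in for a progress bar whose tasks are created lazily."""

    def __init__(self) -> None:
        self.train_task_id: Optional[int] = None
        self.completed: Dict[int, int] = {}
        self._next_id = 0

    def start_epoch(self) -> None:
        # Task ids are only assigned here; until this runs, train_task_id stays None.
        if self.train_task_id is None:
            self.train_task_id = self._next_id
            self._next_id += 1
            self.completed[self.train_task_id] = 0

    def update(self, current: int) -> None:
        # Same kind of guard as in the traceback: updating a task that does not exist fails.
        assert self.train_task_id is not None
        self.completed[self.train_task_id] = current


bar = ToyProgress()
try:
    bar.update(1)  # update issued before any task was created
except AssertionError:
    print("AssertionError: task id is None")

bar.start_epoch()
bar.update(1)  # succeeds once the task id has been assigned
print(bar.completed)  # {0: 1}
```

This only illustrates what the assertion guards against; it does not identify which hook is skipped or reset in RichProgressBar when resuming from a step checkpoint.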

This is my training checkpoint callback (a ModelCheckpoint subclass):

from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from datetime import timedelta
from typing import Optional
import torch
from copy import deepcopy
from typing_extensions import override

class TrainingCheckpoint(ModelCheckpoint):
    def __init__(self, last_n: int = 1, every_n_minutes: Optional[int] = None, every_n_iterations: Optional[int] = None, last_name: str = 'training-last'):
        monitor = "totalsteps"
        super().__init__(
            filename='training-{epoch}-{step}-{monitor}'.replace("monitor", monitor),
            monitor=monitor,
            mode="max",
            save_last=True,
            save_top_k=last_n,
            every_n_train_steps=every_n_iterations,
            save_weights_only=False,
            # timedelta(minutes=None) raises a TypeError, so only build the interval when it is set
            train_time_interval=timedelta(minutes=every_n_minutes) if every_n_minutes is not None else None,
            save_on_train_epoch_end=False,
        )
        self.monitor = None
        self.CHECKPOINT_NAME_LAST = last_name

    @override
    def _monitor_candidates(self, trainer: "pl.Trainer") -> dict[str, Tensor]:
        monitor_candidates = deepcopy(trainer.callback_metrics)
        # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
        # or does not exist we overwrite it as it's likely an error
        epoch = monitor_candidates.get("epoch")
        monitor_candidates["epoch"] = epoch.int() if isinstance(epoch, Tensor) else torch.tensor(trainer.current_epoch)
        step = monitor_candidates.get("step")
        monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
        monitor_candidates[self.monitor] = torch.tensor(trainer.global_step)
        return monitor_candidates
    
    @override
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        pass
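The filename template passed to super().__init__ is built with a plain string replace, so with monitor = "totalsteps" it becomes training-{epoch}-{step}-{totalsteps}. The hypothetical helper format_checkpoint_name below approximates how ModelCheckpoint then fills that template in. It is a sketch, not Lightning's code, assuming the default auto_insert_metric_name=True (which prefixes each value with name=) and that a placeholder with no matching metric is filled with 0; the metric values are made up for illustration.

```python
import re
from typing import Dict


def format_checkpoint_name(template: str, metrics: Dict[str, int]) -> str:
    """Hypothetical approximation of ModelCheckpoint's filename formatting."""
    values = dict(metrics)
    # Collect each "{name}" or "{name:fmt}" placeholder in the template.
    for name in re.findall(r"\{(\w+)[:}]", template):
        # Emulate auto_insert_metric_name=True: "{epoch}" -> "epoch={epoch}".
        template = template.replace("{" + name, name + "={" + name, 1)
        # Assumed behaviour: placeholders with no matching metric are filled with 0.
        values.setdefault(name, 0)
    return template.format(**values)


template = "training-{epoch}-{step}-{monitor}".replace("monitor", "totalsteps")
print(template)  # training-{epoch}-{step}-{totalsteps}
print(format_checkpoint_name(template, {"epoch": 3, "step": 1200, "totalsteps": 1200}))
# training-epoch=3-step=1200-totalsteps=1200
```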

This is my Rich progress bar configuration:

from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar, RichProgressBarTheme
from typing import Optional, Any

class CustomRichProgressBar(RichProgressBar):
    def __init__(self):
        refresh_rate: int = 1
        leave: bool = False
        theme: RichProgressBarTheme = RichProgressBarTheme(
            description="green_yellow",
            progress_bar="green1",
            progress_bar_finished="green1",
            progress_bar_pulse="#6206E0",
            batch_progress="green_yellow",
            time="grey82",
            processing_speed="grey82",
            metrics="grey82",
            metrics_text_delimiter="\n",
            metrics_format=".4f",
        )
        console_kwargs: Optional[dict[str, Any]] = None

        super().__init__(refresh_rate=refresh_rate, leave=leave, theme=theme, console_kwargs=console_kwargs)

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Error messages and logs


Environment


More info

No response


Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
