MLFlow logger - save top k to server on N epochs #20584

Open
HarryAnkers opened this issue Feb 11, 2025 · 3 comments · May be fixed by #20585
Labels: feature (Is an improvement or enhancement), needs triage (Waiting to be triaged by maintainers)

Comments

@HarryAnkers commented on Feb 11, 2025:

Description & Motivation

Hey,
As far as I can tell, if you want to continuously save .ckpt files to an MLflow server during a training run, the best approach is to use the MLFlowLogger flag log_model="all".
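
For reference, that configuration looks roughly like the snippet below (the experiment name and tracking URI are placeholders, not values from this issue):

from pytorch_lightning.loggers import MLFlowLogger

# log_model="all" uploads every checkpoint that ModelCheckpoint writes during
# training, instead of only logging checkpoints once training finishes.
mlflow_logger = MLFlowLogger(
    experiment_name="my-experiment",       # placeholder
    tracking_uri="http://localhost:5000",  # placeholder
    log_model="all",
)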

However, this comes with two issues:

1. It significantly increases storage use in the artifact store.
2. It contradicts the ModelCheckpoint flag save_top_k:
   - With save_top_k set, only the top-k checkpoints are retained, and only locally.
   - If the experiment crashes mid-run, those local checkpoints can be lost.

Without guaranteed local persistence, this isn't ideal for long-running or cloud-based training workflows.

Pitch

A new feature allowing top-k checkpoints to be upserted in MLflow and other loggers would be incredibly useful.

Proposed behavior:

1. save_top_k=2 is set.
2. A new checkpoint is created → it is upserted in the logger.
3. Another checkpoint is created → it is also upserted in the logger.
4. A new, better checkpoint replaces an old one → it is upserted, and the superseded checkpoint is deleted from the logger.

Proposed Implementation

This logic could be integrated at the point where ModelCheckpoint already removes files locally; the change would add logger-specific removal functionality in addition to the local deletion. A rough sketch of what this could look like follows.
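
As a hypothetical illustration only (not the implementation in #20585): a ModelCheckpoint subclass could mirror each newly saved top-k checkpoint to the MLflow artifact store as soon as it is written. Note that _save_checkpoint and _remove_checkpoint are private Lightning internals that may change between versions, the class name here is made up, and MLflow's client currently has no artifact-deletion API, so the removal half is exactly the gap this issue asks to fill.

import os

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger


class MLFlowTopKCheckpoint(ModelCheckpoint):
    """Hypothetical sketch: mirror top-k checkpoints to the MLflow artifact store."""

    def _save_checkpoint(self, trainer, filepath):
        super()._save_checkpoint(trainer, filepath)
        logger = trainer.logger
        if isinstance(logger, MLFlowLogger) and os.path.exists(filepath):
            # Upload the new top-k checkpoint immediately so a crash mid-run
            # does not lose it; logger.experiment is the underlying MlflowClient.
            logger.experiment.log_artifact(
                logger.run_id, filepath, artifact_path="checkpoints"
            )

    def _remove_checkpoint(self, trainer, filepath):
        super()._remove_checkpoint(trainer, filepath)
        # The local file is gone, but the superseded copy stays on the MLflow
        # server because the client exposes no artifact deletion. Logger-specific
        # removal here is what this feature request proposes to add.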

Would love to hear thoughts on this!

Alternatives

No response

Additional context

Here is a script to demonstrate the pain:

import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Change these as wanted
os.environ["MLFLOW_TRACKING_USERNAME"] = ""
os.environ["MLFLOW_TRACKING_PASSWORD"] = ""


class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("train_loss", loss)
        # Simulate a crash partway through training
        # (epochs are 0-indexed, so epoch index 34 is the 35th epoch)
        if self.current_epoch == 34:
            raise Exception("Forced failure at epoch 35")
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=0.01)


def train():
    # If local use this uri
    URI = "http://localhost:5000"

    # With log_model=True (and save_top_k != -1), checkpoints are only uploaded
    # to MLflow at the end of training, so the forced crash at epoch 35 means
    # they never reach the server.
    mlflow_logger = MLFlowLogger(
        experiment_name="harry-test",
        tracking_uri=URI,
        log_model=True,
    )

    x_train, y_train = torch.randn(100, 10), torch.randn(100, 1)
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32)

    x_val, y_val = torch.randn(20, 10), torch.randn(20, 1)
    val_dataset = TensorDataset(x_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=32)

    model = SimpleModel()

    # save_top_k=2 keeps only the two best checkpoints, and only on local disk
    checkpoint_callback_train = ModelCheckpoint(
        monitor="train_loss",
        filename="best_train_model-{epoch:02d}-{train_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    checkpoint_callback_val = ModelCheckpoint(
        monitor="val_loss",
        filename="best_val_model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        mode="min",
    )

    trainer = pl.Trainer(
        max_epochs=40,
        logger=mlflow_logger,
        callbacks=[checkpoint_callback_train, checkpoint_callback_val],
        val_check_interval=3,
    )

    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    train()

cc @lantiga @Borda

@HarryAnkers added the feature and needs triage labels on Feb 11, 2025
@HarryAnkers (Author) commented:

I think I may be wrong with my diagnosis; my issue may be related to the fact that I have multiple top_k checkpoint callbacks.

@HarryAnkers (Author) commented:

I wrote a PR to fix what I think may be unintended behaviour: #20585

@msdkhairi commented:

> I think I may be wrong with my diagnosis; my issue may be related to the fact that I have multiple top_k checkpoint callbacks.

I don't think the issue is caused by multiple ModelCheckpoint instances. I'm using a single ModelCheckpoint and MLFlowLogger, yet I'm still experiencing the issue you initially mentioned.
