Difference in Trainer.test behavior with the ddp strategy #21004

@code4luck

Bug description

I am training a model with the ddp strategy on 4 × RTX 3090 GPUs (pytorch-lightning==2.5.2 or 2.3). The model has 770M parameters, only 77M of which are trainable. First, I use

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    num_nodes=args.num_nodes,
    max_epochs=args.max_epochs,
    accumulate_grad_batches=args.accumulate_grad_batches,
    gradient_clip_val=args.gradient_clip_val,
    gradient_clip_algorithm=args.gradient_clip_algorithm,
    deterministic=True,
    callbacks=[model_checkpoint, early_stop, metrics2file],
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    limit_test_batches=limit_test_batches,
    log_every_n_steps=log_every_n_steps,
)
model = LitModel.load_from_checkpoint(model_checkpoint.best_model_path)
trainer.test(model, test_loader)
# trainer.test(model_checkpoint.best_model_path, test_loader)

which reports an OOM (I find that gpu:0 runs out of memory while the other GPUs are fine). But when I use trainer.test(ckpt_path=model_checkpoint.best_model_path, dataloaders=test_loader) instead, it runs fine. What is the difference between the two test methods?
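
A minimal sketch of a possible workaround, assuming (this is not confirmed anywhere in this issue) that the OOM happens because each DDP rank re-executes the script and torch.load restores the checkpoint tensors to the device they were saved from (typically cuda:0), so all four ranks allocate on GPU 0. Passing map_location keeps the load on CPU and lets the strategy move the weights to each rank's own device:

model = LitModel.load_from_checkpoint(
    model_checkpoint.best_model_path,
    map_location="cpu",  # assumption: load weights on CPU so rank 0's GPU is not hit by every rank
)
trainer.test(model, test_loader)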

What version are you seeing the problem on?

v2.5, v2.3

Reproduced in studio

No response

How to reproduce the bug

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,
    strategy="ddp",
    num_nodes=args.num_nodes,
    max_epochs=args.max_epochs,
    accumulate_grad_batches=args.accumulate_grad_batches,
    gradient_clip_val=args.gradient_clip_val,
    gradient_clip_algorithm=args.gradient_clip_algorithm,
    deterministic=True,
    callbacks=[model_checkpoint, early_stop, metrics2file],
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    limit_test_batches=limit_test_batches,
    log_every_n_steps=log_every_n_steps,
)
model = LitModel.load_from_checkpoint(model_checkpoint.best_model_path)
trainer.test(model, test_loader)
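
For comparison, the call from the description above that does not OOM:

trainer.test(ckpt_path=model_checkpoint.best_model_path, dataloaders=test_loader)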

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

cc @justusschock @lantiga
