Overfit batches parameter gives a validation batch #15021

Open
HekpoMaH opened this issue Oct 6, 2022 · 4 comments · May be fixed by #20731
Labels: bug (Something isn't working), help wanted (Open to be worked on), trainer: fit

Comments


HekpoMaH commented Oct 6, 2022

Bug description

When overfitting on a single batch with the dataloaders defined inside the LightningModule, the batch provided to the validation step differs from the batch used in the training step. I was told in the Slack community that this is NOT the intended behaviour.

How to reproduce the bug

import pytorch_lightning as pl
import torch_geometric
import torch

# Train data holds x values 0..9 and val data holds x values 10..19,
# so it is easy to see which loader a batch came from.
dataset = [torch_geometric.data.Data(x=torch.tensor([i])) for i in range(10)]
val_dataset = [torch_geometric.data.Data(x=torch.tensor([j])) for j in range(10, 20)]

class LitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor([0.]))

    def train_dataloader(self):
        return torch_geometric.loader.DataLoader(dataset, batch_size=2)

    def val_dataloader(self):
        return torch_geometric.loader.DataLoader(val_dataset, batch_size=2)

    def training_step(self, batch, batch_idx):
        print('train', batch.x)
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def validation_step(self, batch, batch_idx):
        print('val', batch.x)
        return torch.nn.functional.mse_loss(self.param, torch.tensor([1.]).to(self.param))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

litmod = LitModule()
trainer = pl.Trainer(
    overfit_batches=1,
    accelerator='cuda',
    max_epochs=20,
    check_val_every_n_epoch=10,
)
trainer.fit(litmod)
print(litmod)

The val batch is the tensor [10, 11], while the train batch is the tensor [0, 1], i.e. validation never sees the overfitted training batch.
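As a stopgap, one workaround is to reuse the training dataloader for validation, so that with overfit_batches=1 both loops draw the same first batch. A minimal sketch (plain torch instead of torch_geometric to stay self-contained; OverfitModule is an illustrative name, not part of the Lightning API):

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class OverfitModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.tensor([0.0]))

    def train_dataloader(self):
        # x values 0..9, batch_size=2 -> the first batch is [0., 1.]
        return DataLoader(TensorDataset(torch.arange(10.0)), batch_size=2)

    def val_dataloader(self):
        # Workaround: return the training dataloader so validation sees
        # the overfitted batch instead of a separate val batch.
        return self.train_dataloader()

    def training_step(self, batch, batch_idx):
        print('train', batch[0])
        return torch.nn.functional.mse_loss(self.param, torch.ones_like(self.param))

    def validation_step(self, batch, batch_idx):
        print('val', batch[0])  # now prints the same tensor as training_step

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

trainer = pl.Trainer(overfit_batches=1, max_epochs=2,
                     num_sanity_val_steps=0, logger=False)
trainer.fit(OverfitModule())

This only papers over the issue, of course; it does not change what overfit_batches itself does.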


Environment

  • CUDA:
    • GPU:
      • NVIDIA GeForce RTX 3080 Laptop GPU
    • available: True
    • version: 11.6
  • Lightning:
    • pytorch-lightning: 1.7.7
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
  • Packages:
    • absl-py: 1.2.0
    • aiohttp: 3.8.3
    • aiosignal: 1.2.0
    • anndata: 0.8.0
    • astroid: 2.12.10
    • astunparse: 1.6.3
    • async-timeout: 4.0.2
    • attrs: 22.1.0
    • blinker: 1.4
    • brotlipy: 0.7.0
    • cachetools: 5.2.0
    • certifi: 2022.9.24
    • cffi: 1.15.1
    • charset-normalizer: 2.1.1
    • chex: 0.1.5
    • click: 8.0.4
    • colorama: 0.4.5
    • contourpy: 1.0.5
    • cryptography: 37.0.2
    • cycler: 0.11.0
    • dill: 0.3.5.1
    • distlib: 0.3.6
    • dm-clrs: 1.0.0
    • dm-haiku: 0.0.8
    • dm-tree: 0.1.7
    • etils: 0.8.0
    • filelock: 3.8.0
    • flatbuffers: 1.12
    • fonttools: 4.37.3
    • frozenlist: 1.3.1
    • fsspec: 2022.8.2
    • gast: 0.4.0
    • google-auth: 2.12.0
    • google-auth-oauthlib: 0.4.6
    • google-pasta: 0.2.0
    • googleapis-common-protos: 1.56.4
    • grpcio: 1.49.1
    • h5py: 3.7.0
    • idna: 3.4
    • importlib-metadata: 4.11.4
    • importlib-resources: 5.9.0
    • isort: 5.10.1
    • jax: 0.3.21
    • jaxlib: 0.3.20
    • jinja2: 3.1.2
    • jmp: 0.0.2
    • joblib: 1.2.0
    • jsonschema: 4.16.0
    • keras: 2.9.0
    • keras-preprocessing: 1.1.2
    • kiwisolver: 1.4.4
    • lazy-object-proxy: 1.7.1
    • libclang: 14.0.6
    • llvmlite: 0.39.1
    • markdown: 3.4.1
    • markupsafe: 2.1.1
    • matplotlib: 3.6.0
    • mccabe: 0.7.0
    • mkl-fft: 1.3.1
    • mkl-random: 1.2.2
    • mkl-service: 2.4.0
    • msgpack: 1.0.4
    • multidict: 6.0.2
    • natsort: 8.2.0
    • networkx: 2.8.6
    • numba: 0.56.2
    • numexpr: 2.8.3
    • numpy: 1.23.3
    • oauthlib: 3.2.1
    • opt-einsum: 3.3.0
    • optax: 0.1.3
    • packaging: 21.3
    • pandas: 1.5.0
    • patsy: 0.5.2
    • pillow: 9.2.0
    • pip: 22.1.2
    • platformdirs: 2.5.2
    • promise: 2.3
    • protobuf: 3.19.6
    • pyasn1: 0.4.8
    • pyasn1-modules: 0.2.8
    • pycparser: 2.21
    • pydeprecate: 0.3.2
    • pyjwt: 2.5.0
    • pylint: 2.15.3
    • pynndescent: 0.5.7
    • pyopenssl: 22.0.0
    • pyparsing: 3.0.9
    • pyrsistent: 0.18.1
    • pysocks: 1.7.1
    • python-dateutil: 2.8.2
    • pytorch-lightning: 1.7.7
    • pytz: 2022.2.1
    • pyu2f: 0.1.5
    • pyyaml: 6.0
    • ray: 2.0.0
    • requests: 2.28.1
    • requests-oauthlib: 1.3.1
    • rsa: 4.9
    • scanpy: 1.9.1
    • scikit-learn: 1.1.2
    • scikit-misc: 0.1.4
    • scipy: 1.9.1
    • seaborn: 0.12.0
    • session-info: 1.0.0
    • setuptools: 65.4.1
    • six: 1.16.0
    • statsmodels: 0.13.2
    • stdlib-list: 0.8.0
    • tables: 3.7.0
    • tabulate: 0.8.10
    • tensorboard: 2.9.1
    • tensorboard-data-server: 0.6.1
    • tensorboard-plugin-wit: 1.8.1
    • tensorboardx: 2.5.1
    • tensorflow: 2.9.1
    • tensorflow-estimator: 2.9.0
    • tensorflow-io-gcs-filesystem: 0.27.0
    • tensorflow-metadata: 1.10.0
    • termcolor: 2.0.1
    • tfds-nightly: 4.5.2.dev202204190046
    • threadpoolctl: 3.1.0
    • toml: 0.10.2
    • tomli: 2.0.1
    • tomlkit: 0.11.5
    • toolz: 0.12.0
    • torch: 1.12.1+cu116
    • torch-cluster: 1.6.0
    • torch-geometric: 2.1.0.post1
    • torch-scatter: 2.0.9
    • torch-sparse: 0.6.15
    • torch-spline-conv: 1.2.1
    • torchaudio: 0.12.1+cu116
    • torchmetrics: 0.10.0
    • torchvision: 0.13.1+cu116
    • tqdm: 4.64.1
    • typing-extensions: 4.3.0
    • umap-learn: 0.5.3
    • urllib3: 1.26.12
    • virtualenv: 20.16.5
    • werkzeug: 2.2.2
    • wheel: 0.37.1
    • wrapt: 1.14.1
    • yapf: 0.32.0
    • yarl: 1.8.1
    • zipp: 3.8.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.10.4
    • version: #202203181321-Ubuntu SMP PREEMPT Fri Mar 18 13:28:32 UTC 2022


More info

_No response_

cc @justusschock @awaelchli
@HekpoMaH added the needs triage (Waiting to be triaged by maintainers) label Oct 6, 2022
@awaelchli added the bug (Something isn't working) and trainer: fit labels and removed the needs triage label Oct 9, 2022
@awaelchli added this to the pl:1.7.x milestone Oct 9, 2022
@carmocca modified the milestones: pl:1.7.x, v1.8.x Oct 13, 2022
@Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
@Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023
@awaelchli added the help wanted (Open to be worked on) label Dec 31, 2023
@awaelchli removed this from the v1.9.x milestone Dec 31, 2023
israfelsr commented

I had the same problem. I was going crazy because according to the documentation they're supposed to be the same 😅.


dgcnz commented Jan 11, 2025

Same here


nilsleh commented Mar 13, 2025

@Borda
The latest documentation, including the video snippet, suggests that train_batch and val_batch will be identical, but it seems that overfit_batches actually uses one fixed train_batch and a separate fixed val_batch instead.

@ved1beta linked a pull request (#20731) Apr 20, 2025 that will close this issue
adosar (Contributor) commented Apr 28, 2025

@nilsleh Indeed, the video snippet suggests that they're identical, which IMO they shouldn't be; see #20731 (comment).

With regard to the documentation (version 2.5.1):

# default used by the Trainer
trainer = Trainer(overfit_batches=0.0)

# use only 1% of the train & val set
trainer = Trainer(overfit_batches=0.01)

# overfit on 10 of the same batches   <--- This seems confusing
trainer = Trainer(overfit_batches=10)

I think the last comment is a bit confusing, since it gives the impression that the same 10 batches are used for both training and validation. Maybe it should read as # overfit on 10 (same) train batches & 10 (same) val batches instead.
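For anyone who wants to verify the current semantics on their own version, a minimal probe (a sketch; ProbeModule and the recorded lists are illustrative, not Lightning API) records the batches each loop actually sees:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

class ProbeModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.zeros(1))
        self.train_batches, self.val_batches = [], []

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.arange(10.0)), batch_size=2)

    def val_dataloader(self):
        return DataLoader(TensorDataset(torch.arange(10.0, 20.0)), batch_size=2)

    def training_step(self, batch, batch_idx):
        self.train_batches.append(batch[0].tolist())
        return torch.nn.functional.mse_loss(self.param, torch.ones_like(self.param))

    def validation_step(self, batch, batch_idx):
        self.val_batches.append(batch[0].tolist())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)

model = ProbeModule()
pl.Trainer(overfit_batches=1, max_epochs=2, num_sanity_val_steps=0,
           logger=False, enable_checkpointing=False).fit(model)

# With the behaviour reported in this issue, this prints repeated [0.0, 1.0]
# for training and repeated [10.0, 11.0] for validation: each loop reuses its
# own batch, but the batch is not shared across loops.
print(model.train_batches)
print(model.val_batches)

If the docs comment were reworded as suggested above, this observed behaviour would at least match the documentation.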
