Question regarding gradient checkpointing #74

Description

@Dxyk

Hello,

I am trying to understand gradient checkpointing and found your explanation in gradient-checkpointing-nin.ipynb very helpful. I cloned the repo and tried rerunning the experiments. However, I was unable to reproduce the result mentioned in your conclusion.

When I run the notebook, the vanilla NiN's memory consumption (current, peak) is 413527 and 154049604, with a runtime of 109.1 s.
For the checkpointed version of the model (segments=1), the memory consumption is 402938 and 154064699, with a runtime of 110.14 s.
From these tests, I was not able to observe the significant memory improvement the notebook reports (a 22% memory reduction at a 14% runtime cost).
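For reference, here is a minimal sketch of roughly how I am measuring. The architecture, batch size, and the tracemalloc-based measurement are my assumptions standing in for the notebook's actual setup (the real one lives in gradient-checkpointing-nin.ipynb); checkpoint_sequential is the torch.utils.checkpoint API that the segments parameter suggests:

```python
import time
import tracemalloc

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Placeholder stand-in for the notebook's NiN blocks.
model = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=5, padding=2), nn.ReLU(),
    nn.Conv2d(96, 96, kernel_size=1), nn.ReLU(),
    nn.Conv2d(96, 10, kernel_size=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)

def run(checkpointed, segments=1):
    """Run one forward/backward pass and report (current, peak) memory."""
    x = torch.randn(64, 3, 32, 32, requires_grad=True)
    tracemalloc.start()
    start = time.time()
    if checkpointed:
        # Split the Sequential into `segments` chunks; activations inside
        # each chunk are recomputed during backward instead of being stored.
        out = checkpoint_sequential(model, segments, x)
    else:
        out = model(x)
    out.sum().backward()
    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    model.zero_grad()
    print(f"checkpointed={checkpointed} segments={segments}: "
          f"current={current} peak={peak} time={time.time() - start:.2f}s")

run(checkpointed=False)
run(checkpointed=True, segments=1)
```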

I've also tried running with multiple seeds and checkpoint segment counts, and was not able to see a significant memory improvement either; a sketch of that sweep follows.
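For example, reusing the hypothetical run helper from the sketch above:

```python
for seed in (0, 1, 2):
    torch.manual_seed(seed)
    for segments in (1, 2, 4):
        run(checkpointed=True, segments=segments)
```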

I'm not sure why this is and could use some help. Could it be that the network is relatively small, so the effect is less obvious? Or has PyTorch's checkpointing implementation changed over the years? I would appreciate any insight you could provide.
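On that second point: as far as I can tell, newer PyTorch versions ship two checkpointing implementations, selected by a use_reentrant flag (this is my understanding of the current API, not something the notebook uses), so maybe the behavior differs between them. A sketch of toggling the two, assuming a PyTorch version where the flag exists:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
x = torch.randn(8, 32, requires_grad=True)

# Historical (reentrant) implementation.
out_reentrant = checkpoint(block, x, use_reentrant=True)

# Newer non-reentrant implementation.
out_non_reentrant = checkpoint(block, x, use_reentrant=False)
```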
