Support context manager when toggling optimizers #17294

Open
kaparoo opened this issue Apr 6, 2023 · 6 comments · May be fixed by #20771
Labels
feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · lightningmodule (pl.LightningModule) · optimization

Comments


kaparoo commented Apr 6, 2023

Description & Motivation

TL;DR

Replace this:

self.toggle_optimizer(optimizer)
optimizer.zero_grad()
DO_SOMETHING()
optimizer.step()
self.untoggle_optimizer(optimizer)

with this:

with self.toggle_optimizer(optimizer):
    DO_SOMETHING()

Description

Since PyTorch Lightning 2.0, manual optimization has been possible; it is typically used when working with multiple optimizers in a single LightningModule instance.

However, if you want to use n optimizers in sequence, you have to keep writing the same boilerplate code:

# assume that self.automatic_optimization = False
def training_step(self, ...):
    opt1, opt2, opt3, ... = self.optimizers()

    # the following calls are repeated for each optimizer:
    # - self.toggle_optimizer(opt{i})
    # - opt{i}.zero_grad()
    # - opt{i}.step()
    # - self.untoggle_optimizer(opt{i})

    # for opt1
    self.toggle_optimizer(opt1)
    opt1.zero_grad()
    
    loss1 = self.compute_loss1(...)
    self.manual_backward(loss1)
    
    opt1.step()
    self.untoggle_optimizer(opt1)

    # for opt2
    self.toggle_optimizer(opt2)
    opt2.zero_grad()
    
    loss2 = self.compute_loss2(...)
    self.manual_backward(loss2)
    
    opt2.step()
    self.untoggle_optimizer(opt2)

    # for opt3
    self.toggle_optimizer(opt3)
    opt3.zero_grad()
    
    loss3 = self.compute_loss3(...)
    self.manual_backward(loss3)
    
    opt3.step()
    self.untoggle_optimizer(opt3)

    # ... repeat up to the n-th optimizer

The above approach makes it difficult to focus on the actual logic and often leads to mistakes. Fortunately, Python offers the with statement to remove this repetition and the mistakes that come with it.

For example, when you open a file in Python, you may use:

file = open("foo.txt", mode="r")
some_related_func(file)
file.close()  # <- you must close the file manually

But you can also use the simpler form:

with open("foo.txt", mode="r") as file:
    some_related_func(file)
    # when the block is ended, file.close() is called automatically

Likewise, if LightningModule manages this optimizer boilerplate automatically and implicitly, the readability and quality of the surrounding code will improve considerably.

Pitch

Suggestion

There are two major changes in LightningModule.toggle_optimizer():

  1. let toggle_optimizer() return a context manager that implements the two magic methods __enter__ and __exit__.
  2. let toggle_optimizer() take a boolean parameter zero_grad (default: True).
    This is because a model can be linked to multiple optimizers and vice versa.
    For such cases, automatic zero_grad should be left as an option for the programmer.

NOTE Defining a new method might be a better approach than modifying toggle_optimizer. See the Alternatives section.

Solution

CAUTION: I have not yet checked whether this solution has any problems.

def toggle_optimizer(self, optimizer, zero_grad: bool = True) -> OptimizerToggle:
    return OptimizerToggle(self, optimizer, zero_grad)

# The name is just one suggestion :)
class OptimizerToggle: 
    def __init__(self, module, optimizer, zero_grad: bool = True):
        self.module = module
        self.optimizer = optimizer
        self.zero_grad = zero_grad

    # Executed when entering `with` statement
    def __enter__(self):

        ORIGINAL_TOGGLE_OPTIMIZER_LOGIC(self.module, self.optimizer)  # just an example

        if self.zero_grad:
            self.optimizer.zero_grad()

    # Executed when exiting block of the `with` statement
    # Parameters are set for any Exception occurred in the block
    def __exit__(self, type, value, traceback):

        EXCEPTION_RELATED_LOGIC(type, value, traceback)  # also, just an example

        self.optimizer.step()
        self.module.untoggle_optimizer(self.optimizer)

Then, we can use the with statement as follows:

# The below line is equivalent to `with self.toggle_optimizer(optimizer, zero_grad=True):`
with self.toggle_optimizer(optimizer):
    # 1. self.toggle_optimizer(optimizer) creates a context manager
    #     with parameters `optimizer` and `zero_grad` (+ LightningModule instance: `self`).
    #     Also, any process defined in the original `toggle_optimizer()` is executed.
    # 2. If the `zero_grad` is True, the context manager calls
    #     `optimizer.zero_grad()` automatically.

    loss = LOSS_RELATED_FUNC(...)
    self.manual_backward(loss)

    # 3. When the context manager is closed, (i.e., calls `__exit__`)
    #     following codes are executed automatically:
    #     - optimizer.step()
    #     - self.untoggle_optimizer(optimizer)
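
For illustration only, the same behavior could also be sketched with contextlib.contextmanager instead of a dedicated class. The helper name toggled_optimizer_ctx is hypothetical and not part of the proposal; this variant steps only when the block finishes without an exception:

import contextlib

@contextlib.contextmanager
def toggled_optimizer_ctx(module, optimizer, zero_grad: bool = True):
    # hypothetical sketch mirroring OptimizerToggle above
    module.toggle_optimizer(optimizer)
    if zero_grad:
        optimizer.zero_grad()
    try:
        yield optimizer
        optimizer.step()  # step only if the block raised no exception
    finally:
        module.untoggle_optimizer(optimizer)  # always untoggle, even on error

Whether to step when the block raises is a design choice the class-based sketch above leaves open.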

Use case (GAN)

Without the suggestion (current approach):

class GAN(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.automatic_optimization = False
        self.generator = Generator(...)
        self.discriminator = Discriminator(...)
        ...

    ...

    def training_step(self, batch):
        x, y_real = batch
        y_fake = self.generator(x)

        d_opt, g_opt = self.optimizers()

        self.toggle_optimizer(d_opt)
        d_opt.zero_grad()
        p_real = self.discriminator(y_real)
        p_fake = self.discriminator(y_fake.detach())
        loss_real = self.adversarial_loss(p_real, type="real")
        loss_fake = self.adversarial_loss(p_fake, type="fake")
        loss = loss_real + loss_fake
        self.manual_backward(loss)
        d_opt.step()
        self.untoggle_optimizer(d_opt)

        self.toggle_optimizer(g_opt)
        self.generator.zero_grad()  # equivalent to g_opt.zero_grad()
        p = self.discriminator(y_fake)
        loss = self.adversarial_loss(p, type="real")
        self.manual_backward(loss)
        g_opt.step()
        self.untoggle_optimizer(g_opt)

With the suggestion (proposed approach):

class GAN(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.automatic_optimization = False
        self.generator = Generator(...)
        self.discriminator = Discriminator(...)
        ...

    ...

    def training_step(self, batch):
        x, y_real = batch
        y_fake = self.generator(x)

        d_opt, g_opt = self.optimizers()

        with self.toggle_optimizer(d_opt):
            p_real = self.discriminator(y_real)
            p_fake = self.discriminator(y_fake.detach())
            loss_real = self.adversarial_loss(p_real, type="real")
            loss_fake = self.adversarial_loss(p_fake, type="fake")
            d_loss = loss_real + loss_fake
            self.manual_backward(d_loss)

        with self.toggle_optimizer(g_opt, zero_grad=False):
            # We can still call `zero_grad` manually :)
            self.generator.zero_grad()
            p = self.discriminator(y_fake)
            g_loss = self.adversarial_loss(p, type="real")
            self.manual_backward(g_loss)

        # Besides, we can access variables outside the statements
        self.log("d_loss", d_loss, prog_bar=True)
        self.log("g_loss", g_loss, prog_bar=True)

Alternatives

New method

For several reasons, including backward compatibility, it may be a bad idea to modify the existing behavior of a method.
Therefore, I also propose an alternative approach that calls toggle_optimizer() internally.

def alternative_approach(self, optimizer, zero_grad: bool = True) -> OptimizerToggle:
    return OptimizerToggle(self, optimizer, zero_grad)

class OptimizerToggle: 
    def __init__(self, module, optimizer, zero_grad: bool = True):
        self.module = module
        self.optimizer = optimizer
        self.zero_grad = zero_grad

    def __enter__(self):
        self.module.toggle_optimizer(self.optimizer)
        if self.zero_grad:
            self.optimizer.zero_grad()

    def __exit__(self, type, value, traceback):
        EXCEPTION_RELATED_LOGIC(type, value, traceback)
        self.optimizer.step()
        self.module.untoggle_optimizer(self.optimizer)

Candidate names for the new method

  1. session
    Inspired by tf.Session() of TensorFlow v1

    with self.session(optimizer):
        pass
  2. focus_optimizer
    Because focusing on a certain optimizer is what we actually want to do.

    with self.focus_optimizer(optimizer):
        pass
  3. switch_optimizer
    One of the synonyms of toggle

    with self.switch_optimizer(optimizer):
        pass
  4. Other suggestions are welcome :)

Additional context

No response

cc @Borda @carmocca @justusschock @awaelchli

@kaparoo kaparoo added feature Is an improvement or enhancement needs triage Waiting to be triaged by maintainers labels Apr 6, 2023
@awaelchli (Contributor)

I think this is a very reasonable request :) I'm in favor. The untoggling could be easily forgotten, and with the context manager it is one less thing to think about.

@awaelchli awaelchli added lightningmodule pl.LightningModule optimization and removed needs triage Waiting to be triaged by maintainers labels Apr 7, 2023
@awaelchli awaelchli added this to the 2.1 milestone Apr 7, 2023
@carmocca (Contributor)

I'm in favor of the context manager option to toggle and untoggle. I suggest naming it with toggled_optimizer(opt)

But I don't think it should call zero_grad and step automatically; in my eyes that goes against the design of manual optimization. One would also then want to call manual_backward inside, and it becomes a strange hybrid between manual and automatic optimization.
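
For concreteness, usage under that design might look like the following sketch (toggled_optimizer follows the name suggested above, and compute_loss is a hypothetical method):

def training_step(self, batch):
    opt = self.optimizers()

    # the context manager only toggles/untoggles; zero_grad, backward,
    # and step remain explicit, as in regular manual optimization
    with self.toggled_optimizer(opt):
        opt.zero_grad()
        loss = self.compute_loss(batch)  # hypothetical loss computation
        self.manual_backward(loss)
        opt.step()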

kaparoo (Author) commented Apr 12, 2023

Thank you for your valuable comments @awaelchli @carmocca !

I also don't think it's right to force zero_grad and step to run automatically; as mentioned in the issue, modules and optimizers generally have a many-to-many relationship. However, I think it's worthwhile to design the manager so that the user can choose whether those two steps are performed automatically.

Therefore, the new design will be:

def toggled_optimizer(self, optimizer: Optimizer, auto_zero: bool = True, auto_step: bool = True) -> OptimizerToggle:
    return OptimizerToggle(self, optimizer, auto_zero, auto_step)

class OptimizerToggle(object):
    def __init__(self, module: LightningModule, optimizer: Optimizer, auto_zero: bool = True, auto_step: bool = True) -> None:
        self.module = module
        self.optimizer = optimizer
        self.auto_zero = auto_zero
        self.auto_step = auto_step

    def __enter__(self) -> None:
        ...
        self.module.toggle_optimizer(self.optimizer)
        
        # zero the gradients automatically unless the user opts out
        # (by default, zero_grad sets the grads to `None`)
        if self.auto_zero:
            self.optimizer.zero_grad()

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        ...

        # step automatically unless the user needs a specific closure or arguments
        if self.auto_step:
            self.optimizer.step()
        
        self.module.untoggle_optimizer(self.optimizer)

Reason 1

In my experience, even though zero_grad and step both accept arguments, they are usually called without any (e.g., zero_grad(), not zero_grad(set_to_none=False)). Therefore, even if the context manager calls these two methods internally, there are no meaningful side effects. This design helps users who want to focus on the loss computation without modifying the optimization process, and when modifications are required, an explicit declaration such as with toggled_optimizer(opt, auto_zero=False, auto_step=False) (1) prevents misunderstanding.

Reason 2

Moreover, even with the context manager, manual_backward is still called explicitly between zero_grad and step (2), and the user writes with toggled_optimizer(opt) themselves instead of relying on the optimization provided by the Trainer. So I think this still counts as manual optimization.

Supplement

with self.toggled_optimizer(opt, auto_zero=False, auto_step=False):   # ---- (1)
    opt.zero_grad(set_to_none=False)

    loss = LOSS_COMPUTATION(...)
    self.manual_backward(loss)    # ---- (2)

    opt.step(CLOSURE, **KWARGS)

or (if no modification of the optimization is needed):

with self.toggled_optimizer(opt):
    # OptimizerToggle runs `opt.zero_grad()` with `set_to_none=True`

    loss = LOSS_COMPUTATION(...)
    self.manual_backward(loss)    # ---- (2)

    # OptimizerToggle runs `opt.step()` without any closure and argument

kaparoo (Author) commented Jun 4, 2023

Hello! It's already been about two months since I wrote this issue :)
If there are no additional comments, can I submit a pull request related to this issue?
I am a graduate student, and so far I have not been able to focus on this issue due to coursework.
However, I will be able to propose a draft in July.

carmocca (Contributor) commented Jun 5, 2023

The context manager usability improvement is a good idea. But I would still not add the optional extra arguments. The reason is that it adds yet another way to do things while only removing two lines of code, and it takes away from the expressivity of PyTorch's autograd notation.

Feel free to submit the PR. Other reviewers might have a different view.

@awaelchli (Contributor)

I agree, the context manager is nice but it should only toggle. It's a nice feature because you can't forget to "untoggle". I think it should not do anything extra beyond that, because in manual optimization, control of step and zero grad is a feature, not a burden. IMO, if automation is desired, one should choose automatic optimization from the start.
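
A toggle-only context manager along these lines could be sketched roughly as follows (an illustrative sketch of a possible LightningModule method, not the final API):

from contextlib import contextmanager

@contextmanager
def toggled_optimizer(self, optimizer):
    # only toggle and guarantee untoggling on exit; zero_grad, backward,
    # and step stay under the user's control, as discussed above
    self.toggle_optimizer(optimizer)
    try:
        yield optimizer
    finally:
        self.untoggle_optimizer(optimizer)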

@awaelchli awaelchli modified the milestones: 2.1, future Aug 29, 2023
@awaelchli awaelchli added the help wanted Open to be worked on label Sep 20, 2023
@rustamzh rustamzh linked a pull request Apr 28, 2025 that will close this issue