
Support context manager when toggling optimizers #17294

Closed
@kaparoo

Description & Motivation

TL;DR

Replace this:

self.toggle_optimizer(optimizer)
optimizer.zero_grad()
DO_SOMETHING()
optimizer.step()
self.untoggle_optimizer(optimizer)

With this:

with self.toggle_optimizer(optimizer):
    DO_SOMETHING()

Description

Since PyTorch Lightning 2.0, manual optimization is available, and it is typically used when working with multiple optimizers in a single LightningModule instance.

However, if you want to use n optimizers in sequence, you have to keep writing the same boilerplate code:

# assume that self.automatic_optimization = False
def training_step(self, ...):
    opt1, opt2, opt3, ... = self.optimizers()

    # the following calls are repeated for every optimizer:
    # - self.toggle_optimizer(opt{i})
    # - opt{i}.zero_grad()
    # - opt{i}.step()
    # - self.untoggle_optimizer(opt{i})

    # for opt1
    self.toggle_optimizer(opt1)
    opt1.zero_grad()
    
    loss1 = self.compute_loss1(...)
    self.manual_backward(loss1)
    
    opt1.step()
    self.untoggle_optimizer(opt1)

    # for opt2
    self.toggle_optimizer(opt2)
    opt2.zero_grad()
    
    loss2 = self.compute_loss2(...)
    self.manual_backward(loss2)
    
    opt2.step()
    self.untoggle_optimizer(opt2)

    # for opt3
    self.toggle_optimizer(opt3)
    opt3.zero_grad()
    
    loss3 = self.compute_loss3(...)
    self.manual_backward(loss3)
    
    opt3.step()
    self.untoggle_optimizer(opt3)

    # ... repeat until n-th opt 

The pattern above makes it difficult to focus on the actual logic and often leads to mistakes. Fortunately, Python offers the with statement to prevent this repetition and the mistakes that result from it.

For example, when you open a file in Python, you may use:

file = open("foo.txt", mode="r")
some_related_func(file)
file.close()  # <- you must close the file manually

But you can also use the simpler form:

with open("foo.txt", mode="r") as file:
    some_related_func(file)
    # when the block is ended, file.close() is called automatically

Likewise, if the optimizer boilerplate were managed automatically and implicitly by LightningModule, the readability and overall quality of the code would be much better.

Pitch

Suggestion

There are two major changes in LightningModule.toggle_optimizer():

  1. let toggle_optimizer() return a context manager that implements the two magic methods __enter__ and __exit__.
  2. let toggle_optimizer() take a boolean parameter zero_grad (default: True).
    This is because a model can be linked to multiple optimizers and vice versa.
    For such cases, automatic zero_grad should be left as an option for the programmer.

NOTE Defining a new method might be a better approach than modifying toggle_optimizer. See the Alternatives section.

Solution

CAUTION I have not yet checked whether this solution has any problems.

def toggle_optimizer(self, optimizer, zero_grad: bool = True) -> OptimizerToggle:
    return OptimizerToggle(self, optimizer, zero_grad)

# The name is just one suggestion :)
class OptimizerToggle: 
    def __init__(self, module, optimizer, zero_grad: bool = True):
        self.module = module
        self.optimizer = optimizer
        self.zero_grad = zero_grad

    # Executed when entering `with` statement
    def __enter__(self):

        ORIGINAL_TOGGLE_OPTIMIZER_LOGIC(self.module, self.optimizer)  # just an example

        if self.zero_grad:
            self.optimizer.zero_grad()

    # Executed when exiting the block of the `with` statement
    # The parameters describe any exception raised inside the block
    def __exit__(self, exc_type, exc_value, traceback):

        EXCEPTION_RELATED_LOGIC(exc_type, exc_value, traceback)  # also, just an example

        self.optimizer.step()
        self.module.untoggle_optimizer(self.optimizer)

Then, we can use the with statement as follows:

# The below line is equivalent to `with self.toggle_optimizer(optimizer, zero_grad=True):`
with self.toggle_optimizer(optimizer):
    # 1. self.toggle_optimizer(optimizer) creates a context manager
    #     with parameters `optimizer` and `zero_grad` (+ LightningModule instance: `self`).
    #     Also, the logic defined in the original `toggle_optimizer()` is executed.
    # 2. If `zero_grad` is True, the context manager calls
    #     `optimizer.zero_grad()` automatically.

    loss = LOSS_RELATED_FUNC(...)
    self.manual_backward(loss)

    # 3. When the context manager exits (i.e., `__exit__` is called),
    #     the following code runs automatically:
    #     - optimizer.step()
    #     - self.untoggle_optimizer(optimizer)
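
One detail the sketch above leaves open is error handling. If `__exit__` returns a falsy value, any exception raised inside the block still propagates to the caller, and it is probably undesirable to call `optimizer.step()` after a failed block. A possible refinement (my own assumption, not part of the original pitch) would be:

    def __exit__(self, exc_type, exc_value, traceback):
        # Assumption: step only when the block finished without an exception
        if exc_type is None:
            self.optimizer.step()
        # Always untoggle so the LightningModule is left in a consistent state
        self.module.untoggle_optimizer(self.optimizer)
        # Returning None (falsy) lets any exception propagate to the caller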

Use case (GAN)

Without the suggestion (current approach):

class GAN(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.automatic_optimization = False
        self.generator = Generator(...)
        self.discriminator = Discriminator(...)
        ...

    ...

    def training_step(self, batch):
        x, y_real = batch
        y_fake = self.generator(x)

        d_opt, g_opt = self.optimizers()

        self.toggle_optimizer(d_opt)
        d_opt.zero_grad()
        p_real = self.discriminator(y_real)
        p_fake = self.discriminator(y_fake.detach())
        loss_real = self.adversarial_loss(p_real, type="real")
        loss_fake = self.adversarial_loss(p_fake, type="fake")
        loss = loss_real + loss_fake
        self.manual_backward(loss)
        d_opt.step()
        self.untoggle_optimizer(d_opt)

        self.toggle_optimizer(g_opt)
        self.generator.zero_grad()  # equivalent to g_opt.zero_grad()
        p = self.discriminator(y_fake)
        loss = self.adversarial_loss(p, type="real")
        self.manual_backward(loss)
        g_opt.step()
        self.untoggle_optimizer(g_opt)

With the suggestion (proposed approach):

class GAN(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.automatic_optimization = False
        self.generator = Generator(...)
        self.discriminator = Discriminator(...)
        ...

    ...

    def training_step(self, batch):
        x, y_real = batch
        y_fake = self.generator(x)

        d_opt, g_opt = self.optimizers()

        with self.toggle_optimizer(d_opt):
            p_real = self.discriminator(y_real)
            p_fake = self.discriminator(y_fake.detach())
            loss_real = self.adversarial_loss(p_real, type="real")
            loss_fake = self.adversarial_loss(p_fake, type="fake")
            d_loss = loss_real + loss_fake
            self.manual_backward(d_loss)

        with self.toggle_optimizer(g_opt, zero_grad=False):
            # We can still call `zero_grad` manually :)
            self.generator.zero_grad()
            p = self.discriminator(y_fake)
            g_loss = self.adversarial_loss(p, type="real")
            self.manual_backward(g_loss)

        # Also, variables defined inside the `with` blocks are still accessible here
        self.log("d_loss", d_loss, prog_bar=True)
        self.log("g_loss", g_loss, prog_bar=True)

Alternatives

New method

For several reasons, including backward compatibility, modifying the existing behavior of a method may be a bad idea.
Therefore, I also propose an alternative: a new method that calls toggle_optimizer() internally.

def alternative_approach(self, optimizer, zero_grad: bool = True) -> OptimizerToggle:
    return OptimizerToggle(self, optimizer, zero_grad)

class OptimizerToggle: 
    def __init__(self, module, optimizer, zero_grad: bool = True):
        self.module = module
        self.optimizer = optimizer
        self.zero_grad = zero_grad

    def __enter__(self):
        self.module.toggle_optimizer(self.optimizer)
        if self.zero_grad:
            self.optimizer.zero_grad()

    def __exit__(self, exc_type, exc_value, traceback):
        EXCEPTION_RELATED_LOGIC(exc_type, exc_value, traceback)
        self.optimizer.step()
        self.module.untoggle_optimizer(self.optimizer)
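
As a further variation (my own sketch, not part of the proposal above), the same behavior could be written with contextlib.contextmanager instead of a dedicated class. The method name is again just a placeholder, the import path assumes Lightning 2.x, and only toggle_optimizer, untoggle_optimizer, zero_grad, and step are existing calls:

from contextlib import contextmanager

from lightning.pytorch import LightningModule


class MyModule(LightningModule):
    @contextmanager
    def alternative_approach(self, optimizer, zero_grad: bool = True):
        # Code before `yield` runs when the `with` block is entered
        self.toggle_optimizer(optimizer)
        if zero_grad:
            optimizer.zero_grad()
        try:
            yield optimizer
            # Reached only if the block completed without an exception
            optimizer.step()
        finally:
            # Runs whether or not an exception occurred
            self.untoggle_optimizer(optimizer)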

Candidate names for the new method

  1. session
    Inspired by tf.Session() of TensorFlow v1

    with self.session(optimizer):
        pass
  2. focus_optimizer
    Because focusing on a certain optimizer is what we actually want to do.

    with self.focus_optimizer(optimizer):
        pass
  3. switch_optimizer
    One of the synonyms of toggle

    with self.switch_optimizer(optimizer):
        pass
  4. Other suggestions are welcomed :)

Additional context

No response

cc @Borda @carmocca @justusschock @awaelchli
