diff --git a/avalanche/training/supervised/ser.py b/avalanche/training/supervised/ser.py
new file mode 100644
index 000000000..fb53d0f7c
--- /dev/null
+++ b/avalanche/training/supervised/ser.py
@@ -0,0 +1,207 @@
+from collections import defaultdict
+from typing import (
+    Callable,
+    List,
+    Optional,
+    Union,
+)
+
+import torch
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss, Module
+from torch.optim import Optimizer
+from copy import deepcopy
+
+from avalanche.training.templates.strategy_mixin_protocol import CriterionType
+from avalanche.training.utils import cycle
+from avalanche.core import SupervisedPlugin
+from avalanche.training.plugins.evaluation import (
+    EvaluationPlugin,
+    default_evaluator,
+)
+from avalanche.training.templates import SupervisedTemplate
+from avalanche.training.supervised.der import ClassBalancedBufferWithLogits
+from avalanche.models import avalanche_forward
+
+
+class SER(SupervisedTemplate):
+    """
+    Implements the SER Strategy,
+    from the "Continual Learning with Strong Experience Replay"
+    paper, https://arxiv.org/pdf/2305.13622
+    """
+
+    def __init__(
+        self,
+        *,
+        model: Module,
+        optimizer: Optimizer,
+        criterion: CriterionType = CrossEntropyLoss(),
+        mem_size: int = 200,
+        batch_size_mem: Optional[int] = None,
+        alpha: float = 0.1,
+        beta: float = 0.5,
+        train_mb_size: int = 1,
+        train_epochs: int = 1,
+        eval_mb_size: Optional[int] = 1,
+        device: Union[str, torch.device] = "cpu",
+        plugins: Optional[List[SupervisedPlugin]] = None,
+        evaluator: Union[
+            EvaluationPlugin, Callable[[], EvaluationPlugin]
+        ] = default_evaluator,
+        eval_every=-1,
+        peval_mode="epoch",
+        **kwargs
+    ):
+        """
+        :param model: PyTorch model.
+        :param optimizer: PyTorch optimizer.
+        :param criterion: loss function.
+        :param mem_size: int : Fixed memory size.
+        :param batch_size_mem: int : Size of the batch sampled from the buffer.
+        :param alpha: float : Hyperparameter weighting the backward consistency loss.
+        :param beta: float : Hyperparameter weighting the forward consistency loss.
+        :param train_mb_size: mini-batch size for training.
+        :param train_epochs: number of training epochs.
+        :param eval_mb_size: mini-batch size for eval.
+        :param device: PyTorch device where the model will be allocated.
+        :param plugins: (optional) list of StrategyPlugins.
+        :param evaluator: (optional) instance of EvaluationPlugin for logging
+            and metric computations. None to remove logging.
+        :param eval_every: the frequency of the calls to `eval` inside the
+            training loop. -1 disables the evaluation. 0 means `eval` is called
+            only at the end of the learning experience. Values >0 mean that
+            `eval` is called every `eval_every` experiences and at the end of
+            the learning experience.
+        :param peval_mode: one of {'epoch', 'experience', 'iteration'}. Decides
+            whether the periodic evaluation during training should execute
+            every `eval_every` epochs, experiences or iterations
+            (Default='epoch').
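+
+        Overall, the loss optimised in ``training_epoch`` is::
+
+            loss = CE(model(x_all), y_all)
+                   + alpha * MSE(model(x_mem), logits_mem)
+                   + beta * MSE(old_model(x_cur), model(x_cur))
+
+        where ``x_mem`` are buffer samples with their stored logits,
+        ``x_cur`` are current-experience samples and ``old_model`` is a
+        frozen copy of the model taken before the current experience.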
+ """ + super().__init__( + model=model, + optimizer=optimizer, + criterion=criterion, + train_mb_size=train_mb_size, + train_epochs=train_epochs, + eval_mb_size=eval_mb_size, + device=device, + plugins=plugins, + evaluator=evaluator, + eval_every=eval_every, + peval_mode=peval_mode, + **kwargs + ) + if batch_size_mem is None: + self.batch_size_mem = train_mb_size + else: + self.batch_size_mem = batch_size_mem + self.mem_size = mem_size + self.storage_policy = ClassBalancedBufferWithLogits( + self.mem_size, adaptive_size=True + ) + self.replay_loader = None + self.alpha = alpha + self.beta = beta + self.old_model = None + + def _before_training_exp(self, **kwargs): + buffer = self.storage_policy.buffer + if len(buffer) >= self.batch_size_mem: + self.replay_loader = cycle( + torch.utils.data.DataLoader( + buffer, + batch_size=self.batch_size_mem, + shuffle=True, + drop_last=True, + num_workers=kwargs.get("num_workers", 0), + ) + ) + else: + self.replay_loader = None + # Freeze model + self.old_model = deepcopy(self.model) + self.old_model.eval() + + super()._before_training_exp(**kwargs) + + def _after_training_exp(self, **kwargs): + self.replay_loader = None # Allow SER to be checkpointed + self.storage_policy.update(self, **kwargs) + super()._after_training_exp(**kwargs) + + def _before_forward(self, **kwargs): + super()._before_forward(**kwargs) + if self.replay_loader is None: + return None + + batch_x, batch_y, batch_tid, batch_logits = next(self.replay_loader) + batch_x, batch_y, batch_tid, batch_logits = ( + batch_x.to(self.device), + batch_y.to(self.device), + batch_tid.to(self.device), + batch_logits.to(self.device), + ) + self.mbatch[0] = torch.cat((batch_x, self.mbatch[0])) + self.mbatch[1] = torch.cat((batch_y, self.mbatch[1])) + self.mbatch[2] = torch.cat((batch_tid, self.mbatch[2])) + self.batch_logits = batch_logits + + def training_epoch(self, **kwargs): + """Training epoch. 
+
+        :param kwargs:
+        :return:
+        """
+        for self.mbatch in self.dataloader:
+            if self._stop_training:
+                break
+
+            self._unpack_minibatch()
+            self._before_training_iteration(**kwargs)
+
+            self.optimizer.zero_grad()
+            self.loss = self._make_empty_loss()
+
+            # Forward
+            self._before_forward(**kwargs)
+            self.mb_output = self.forward()
+            with torch.no_grad():
+                old_mb_output = avalanche_forward(
+                    self.old_model, self.mb_x, self.mb_task_id
+                )
+            self._after_forward(**kwargs)
+
+            if self.replay_loader is not None:
+
+                # Classification loss on current task and memory data
+                self.loss += F.cross_entropy(
+                    self.mb_output[:],
+                    self.mb_y[:],
+                )
+                # Backward consistency loss on memory data
+                self.loss += self.alpha * F.mse_loss(
+                    self.mb_output[: self.batch_size_mem],
+                    self.batch_logits,
+                )
+                # Forward consistency loss on current task data
+                self.loss += self.beta * F.mse_loss(
+                    old_mb_output[self.batch_size_mem :],
+                    self.mb_output[self.batch_size_mem :],
+                )
+
+            else:
+                self.loss += self.criterion()
+
+            self._before_backward(**kwargs)
+            self.backward()
+            self._after_backward(**kwargs)
+
+            # Optimization step
+            self._before_update(**kwargs)
+            self.optimizer_step()
+            self._after_update(**kwargs)
+
+            self._after_training_iteration(**kwargs)
+
+
+__all__ = ["SER"]
diff --git a/tests/training/test_strategies.py b/tests/training/test_strategies.py
index eed4bc4d5..281eae778 100644
--- a/tests/training/test_strategies.py
+++ b/tests/training/test_strategies.py
@@ -49,6 +49,7 @@
     ER_ACE,
     ER_AML,
     DER,
+    SER,
     LearningToPrompt,
     ExpertGateStrategy,
     MER,
@@ -1113,6 +1114,23 @@ def test_der(self):
         )
         run_strategy(benchmark, strategy)
 
+    def test_ser(self):
+        # SIT scenario
+        model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
+        with self.assertWarns(PositionalArgumentsDeprecatedWarning):
+            strategy = SER(
+                model,
+                optimizer,
+                criterion,
+                mem_size=1000,
+                batch_size_mem=10,
+                train_mb_size=10,
+                device=self.device,
+                eval_mb_size=50,
+                train_epochs=2,
+            )
+        run_strategy(benchmark, strategy)
+
     def test_feature_distillation(self):
         # SIT scenario
         model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
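
For context, a minimal sketch of how the new strategy could be driven end to end. SplitMNIST, SimpleMLP and the hyperparameter values below are illustrative choices for this sketch, not settings prescribed by the paper or by this diff:

    import torch
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import SimpleMLP
    from avalanche.training.supervised.ser import SER

    benchmark = SplitMNIST(n_experiences=5)
    model = SimpleMLP(num_classes=benchmark.n_classes)
    optimizer = SGD(model.parameters(), lr=0.01)

    strategy = SER(
        model=model,
        optimizer=optimizer,
        criterion=CrossEntropyLoss(),
        mem_size=200,        # replay buffer capacity
        batch_size_mem=32,   # memory samples drawn each iteration
        alpha=0.1,           # weight of the backward consistency loss
        beta=0.5,            # weight of the forward consistency loss
        train_mb_size=32,
        train_epochs=1,
        eval_mb_size=64,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Train on each experience in the stream, evaluating after each one.
    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)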