diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 94b5e175bf88a2..3ef30fc8ae5528 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -849,29 +849,29 @@ def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=No
             ):
                 self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")

-                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-                config.use_cache = False
-                config.return_dict = True
-                model = model_class(config)
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.use_cache = False
+            config.return_dict = True
+            model = model_class(config)

-                model.to(torch_device)
-                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
-                model.train()
+            model.to(torch_device)
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
+            model.train()

-                # unfreeze additional layers
-                for p in model.parameters():
-                    p.requires_grad_(True)
+            # unfreeze additional layers
+            for p in model.parameters():
+                p.requires_grad_(True)

-                optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

-                inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-                loss = model(**inputs).loss
-                loss.backward()
-                optimizer.step()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+            optimizer.step()

-                for k, v in model.named_parameters():
-                    if v.requires_grad:
-                        self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
+            for k, v in model.named_parameters():
+                if v.requires_grad:
+                    self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")

     def test_training(self):
         if not self.model_tester.is_training: