diff --git a/pytorch_pfn_extras/training/_trainer.py b/pytorch_pfn_extras/training/_trainer.py index fe1c46d1..bd8803ec 100644 --- a/pytorch_pfn_extras/training/_trainer.py +++ b/pytorch_pfn_extras/training/_trainer.py @@ -247,9 +247,15 @@ def run( - :meth:`pytorch_pfn_extras.training._evaluator.Evaluator` """ if train_len is None: - train_len = len(train_loader) # type: ignore[arg-type] + if hasattr(train_loader, "__len__"): + train_len = len(train_loader) # type: ignore[arg-type] + else: + train_len = 1 if eval_len is None and val_loader is not None: - eval_len = len(val_loader) # type: ignore[arg-type] + if hasattr(eval_len, "__len__"): + eval_len = len(val_loader) # type: ignore[arg-type] + else: + eval_len = 1 self._train_len = train_len self._eval_len = eval_len diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py index b65faba5..60e35ae3 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py @@ -948,3 +948,39 @@ def test_create_distributed_evaluator(): with mock.patch.object(dist, "is_initialized", return_value=True): evaluator = engine.create_evaluator(models=model, distributed=True) assert isinstance(evaluator, DistributedEvaluator) + + +def test_trainer_run_with_iterator(): + model = MyModel() + model = MyModelWithLossFn(model) + model = ppe.to(model, "cpu") + optimier = mock.MagicMock(spec=torch.optim.Optimizer) + evaluator = engine.create_evaluator(models=model) + trainer = engine.create_trainer( + models=model, optimizers=optimier, max_epochs=1, evaluator=evaluator + ) + train_iterator = ( + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ) + valid_iterator = ( + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(5) + ) + trainer.run(train_iterator, valid_iterator) + assert trainer._train_len == 1 + assert trainer._eval_len == 1