diff --git a/pytorch_pfn_extras/training/extensions/_snapshot.py b/pytorch_pfn_extras/training/extensions/_snapshot.py index ce6dadcc1..b94808a6d 100644 --- a/pytorch_pfn_extras/training/extensions/_snapshot.py +++ b/pytorch_pfn_extras/training/extensions/_snapshot.py @@ -444,9 +444,8 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None: filename = filename(manager) else: filename = filename.format(manager) - outdir = manager.out writer( # type: ignore - filename, outdir, serialized_target, savefun=self._savefun + filename, serialized_target, savefun=self._savefun ) def finalize(self, manager: ExtensionsManagerProtocol) -> None: diff --git a/pytorch_pfn_extras/training/extensions/log_report.py b/pytorch_pfn_extras/training/extensions/log_report.py index c78448d0e..ce653219f 100644 --- a/pytorch_pfn_extras/training/extensions/log_report.py +++ b/pytorch_pfn_extras/training/extensions/log_report.py @@ -223,11 +223,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: # write to the log file log_name = self._filename.format(**stats_cpu) - out = manager.out savefun = LogWriterSaveFunc(self._format, self._append) writer( log_name, - out, self._log_looker.get(), savefun=savefun, append=self._append, diff --git a/pytorch_pfn_extras/training/extensions/profile_report.py b/pytorch_pfn_extras/training/extensions/profile_report.py index 95514d275..00d008119 100644 --- a/pytorch_pfn_extras/training/extensions/profile_report.py +++ b/pytorch_pfn_extras/training/extensions/profile_report.py @@ -131,7 +131,6 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: ) writer( log_name, - out, self._log, # type: ignore savefun=savefun, append=self._append, diff --git a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py index a54d0446f..5722aafe9 100644 --- a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py +++ b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py @@ -453,7 +453,6 @@ def save_plot_using_module( writer( self._filename, - manager.out, (fig, plt), # type: ignore savefun=matplotlib_savefun, ) diff --git a/pytorch_pfn_extras/writing/_parallel_writer.py b/pytorch_pfn_extras/writing/_parallel_writer.py index 2c784e15a..a0b53cdd6 100644 --- a/pytorch_pfn_extras/writing/_parallel_writer.py +++ b/pytorch_pfn_extras/writing/_parallel_writer.py @@ -34,16 +34,13 @@ def __init__( def _save_with_exitcode( self, filename: str, - out_dir: str, target: _TargetType, savefun: _SaveFun, append: bool, **savefun_kwargs: Any, ) -> None: try: - self.save( - filename, out_dir, target, savefun, append, **savefun_kwargs - ) + self.save(filename, target, savefun, append, **savefun_kwargs) except Exception as e: thread = threading.current_thread() thread.exitcode = -1 # type: ignore[attr-defined] @@ -56,7 +53,6 @@ def _save_with_exitcode( def create_worker( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, @@ -65,7 +61,7 @@ def create_worker( ) -> threading.Thread: return threading.Thread( target=self._save_with_exitcode, - args=(filename, out_dir, target, savefun, append), + args=(filename, target, savefun, append), kwargs=savefun_kwargs, ) @@ -97,7 +93,6 @@ def __init__( def create_worker( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, @@ -106,6 +101,6 @@ def create_worker( ) -> multiprocessing.Process: return multiprocessing.Process( target=self.save, - args=(filename, out_dir, target, savefun, append), + args=(filename, target, savefun, append), kwargs=savefun_kwargs, ) diff --git a/pytorch_pfn_extras/writing/_queue_writer.py b/pytorch_pfn_extras/writing/_queue_writer.py index b4348570b..bc66e5c41 100644 --- a/pytorch_pfn_extras/writing/_queue_writer.py +++ b/pytorch_pfn_extras/writing/_queue_writer.py @@ -14,9 +14,7 @@ _Worker, ) -_QueUnit = Optional[ - Tuple[_TaskFun, str, str, _TargetType, Optional[_SaveFun], bool] -] +_QueUnit = Optional[Tuple[_TaskFun, str, _TargetType, Optional[_SaveFun], bool]] class QueueWriter(Writer, Generic[_Worker]): @@ -66,16 +64,13 @@ def __init__( def __call__( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, append: bool = False, ) -> None: assert not self._finalized - self._queue.put( - (self._task, filename, out_dir, target, savefun, append) - ) + self._queue.put((self._task, filename, target, savefun, append)) def create_task(self, savefun: _SaveFun) -> _TaskFun: return SimpleWriter(savefun=savefun) @@ -93,9 +88,7 @@ def consume(self, q: "queue.Queue[_QueUnit]") -> None: q.task_done() return else: - task[0]( - task[1], task[2], task[3], savefun=task[4], append=task[5] - ) + task[0](task[1], task[2], savefun=task[3], append=task[4]) q.task_done() def finalize(self) -> None: diff --git a/pytorch_pfn_extras/writing/_simple_writer.py b/pytorch_pfn_extras/writing/_simple_writer.py index b26db45a7..98cf0e425 100644 --- a/pytorch_pfn_extras/writing/_simple_writer.py +++ b/pytorch_pfn_extras/writing/_simple_writer.py @@ -44,7 +44,6 @@ def __init__( def __call__( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, @@ -52,4 +51,4 @@ def __call__( ) -> None: if savefun is None: savefun = self._savefun - self.save(filename, out_dir, target, savefun, append, **self._kwds) + self.save(filename, target, savefun, append, **self._kwds) diff --git a/pytorch_pfn_extras/writing/_tensorboard_writer.py b/pytorch_pfn_extras/writing/_tensorboard_writer.py index 8e004e972..47edc894f 100644 --- a/pytorch_pfn_extras/writing/_tensorboard_writer.py +++ b/pytorch_pfn_extras/writing/_tensorboard_writer.py @@ -51,7 +51,6 @@ def __del__(self) -> None: def __call__( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, diff --git a/pytorch_pfn_extras/writing/_writer_base.py b/pytorch_pfn_extras/writing/_writer_base.py index 10f96c58d..e0ed214b1 100644 --- a/pytorch_pfn_extras/writing/_writer_base.py +++ b/pytorch_pfn_extras/writing/_writer_base.py @@ -235,7 +235,6 @@ def __init__( def __call__( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, @@ -280,7 +279,6 @@ def finalize(self) -> None: def save( self, filename: str, - out_dir: str, target: _TargetType, savefun: _SaveFun, append: bool, @@ -375,7 +373,6 @@ def __init__( def __call__( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, @@ -389,7 +386,6 @@ def __call__( self._filename = filename self._worker = self.create_worker( filename, - out_dir, target, savefun=savefun, append=append, @@ -402,7 +398,6 @@ def __call__( def create_worker( self, filename: str, - out_dir: str, target: _TargetType, *, savefun: Optional[_SaveFun] = None, diff --git a/tests/pytorch_pfn_extras_tests/test_writing.py b/tests/pytorch_pfn_extras_tests/test_writing.py index b8ea93274..931ebd1ed 100644 --- a/tests/pytorch_pfn_extras_tests/test_writing.py +++ b/tests/pytorch_pfn_extras_tests/test_writing.py @@ -15,7 +15,7 @@ def test_tensorboard_writing(): writer = ppe.writing.TensorBoardWriter( out_dir=tempd, filename_suffix="_test" ) - writer(None, None, data) + writer(None, data) # Check that the file was generated for snap in os.listdir(tempd): assert "_test" in snap diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py index c6578b111..139a1b276 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py @@ -1,4 +1,5 @@ import multiprocessing +import os import tempfile import threading from unittest import mock @@ -11,10 +12,12 @@ def test_simple_writer(): target = mock.MagicMock() - w = writing.SimpleWriter(foo=True) savefun = mock.MagicMock() with tempfile.TemporaryDirectory() as tempd: - w("myfile.dat", tempd, target, savefun=savefun) + w = writing.SimpleWriter(foo=True, out_dir=tempd) + filename = "myfile.dat" + w(filename, target, savefun=savefun) + assert os.path.exists(os.path.join(tempd, filename)) assert savefun.call_count == 1 assert savefun.call_args[0][0] == target assert savefun.call_args[1]["foo"] is True @@ -22,14 +25,14 @@ def test_simple_writer(): def test_standard_writer(): target = mock.MagicMock() - w = writing.StandardWriter() worker = mock.MagicMock() worker.exitcode = 0 name = spshot_writers_path + ".StandardWriter.create_worker" with mock.patch(name, return_value=worker): with tempfile.TemporaryDirectory() as tempd: - w("myfile.dat", tempd, target) - w("myfile.dat", tempd, target) + w = writing.StandardWriter(out_dir=tempd) + w("myfile.dat", target) + w("myfile.dat", target) w.finalize() assert worker.start.call_count == 2 @@ -38,36 +41,36 @@ def test_standard_writer(): def test_thread_writer_create_worker(): target = mock.MagicMock() - w = writing.ThreadWriter() with tempfile.TemporaryDirectory() as tempd: - worker = w.create_worker("myfile.dat", tempd, target, append=False) + w = writing.ThreadWriter(out_dir=tempd) + worker = w.create_worker("myfile.dat", target, append=False) assert isinstance(worker, threading.Thread) - w("myfile2.dat", tempd, "test") + w("myfile2.dat", "test") w.finalize() def test_thread_writer_fail(): - w = writing.ThreadWriter(savefun=None) with tempfile.TemporaryDirectory() as tempd: - w("myfile2.dat", tempd, "test") + w = writing.ThreadWriter(savefun=None, out_dir=tempd) + w("myfile2.dat", "test") with pytest.raises(RuntimeError): w.finalize() def test_process_writer_create_worker(): target = mock.MagicMock() - w = writing.ProcessWriter() with tempfile.TemporaryDirectory() as tempd: - worker = w.create_worker("myfile.dat", tempd, target, append=False) + w = writing.ProcessWriter(out_dir=tempd) + worker = w.create_worker("myfile.dat", target, append=False) assert isinstance(worker, multiprocessing.Process) - w("myfile2.dat", tempd, "test") + w("myfile2.dat", "test") w.finalize() def test_process_writer_fail(): - w = writing.ProcessWriter(savefun=None) with tempfile.TemporaryDirectory() as tempd: - w("myfile2.dat", tempd, "test") + w = writing.ProcessWriter(savefun=None, out_dir=tempd) + w("myfile2.dat", "test") with pytest.raises(RuntimeError): w.finalize() @@ -82,11 +85,10 @@ def test_queue_writer(): ] with mock.patch(names[0], return_value=q): with mock.patch(names[1], return_value=consumer): - w = writing.QueueWriter() - with tempfile.TemporaryDirectory() as tempd: - w("myfile.dat", tempd, target) - w("myfile.dat", tempd, target) + w = writing.QueueWriter(out_dir=tempd) + w("myfile.dat", target) + w("myfile.dat", target) w.finalize() assert consumer.start.call_count == 1 diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py index da469f5c7..5708a50a5 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py @@ -527,7 +527,11 @@ def test_extensions_accessing_models_without_flag(priority): if priority is not None: extension.priority = priority manager = training.ExtensionsManager( - m, optimizer, 1, iters_per_epoch=5, extensions=[extension] + m, + optimizer, + 1, + iters_per_epoch=5, + extensions=[extension], ) while not manager.stop_trigger: with pytest.raises(RuntimeError):