From e7b70b39d1b05fa645181c383b6e6135bf448fe6 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 18 Aug 2025 16:51:53 +0100 Subject: [PATCH] Attempt to fix potential reference leaks Fixes #531 --- src/_time_machine.c | 20 ++++++++++++++++---- tests/test_time_machine.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/_time_machine.c b/src/_time_machine.c index e45c675b..211f8a42 100644 --- a/src/_time_machine.c +++ b/src/_time_machine.c @@ -511,6 +511,14 @@ _time_machine_patch(PyObject *module, PyObject *unused) Py_RETURN_NONE; PyObject *datetime_module = PyImport_ImportModule("datetime"); + if (datetime_module == NULL) { + return NULL; // Propagate ImportError + } + PyObject *time_module = PyImport_ImportModule("time"); + if (time_module == NULL) { + Py_DECREF(datetime_module); + return NULL; // Propagate ImportError + } PyObject *datetime_class = PyObject_GetAttrString(datetime_module, "datetime"); PyCFunctionObject *datetime_datetime_now = @@ -532,8 +540,6 @@ _time_machine_patch(PyObject *module, PyObject *unused) Py_DECREF(datetime_class); Py_DECREF(datetime_module); - PyObject *time_module = PyImport_ImportModule("time"); - /* time.clock_gettime(), only available on Unix platforms. */ @@ -625,6 +631,14 @@ _time_machine_unpatch(PyObject *module, PyObject *unused) Py_RETURN_NONE; PyObject *datetime_module = PyImport_ImportModule("datetime"); + if (datetime_module == NULL) { + return NULL; // Propagate ImportError + } + PyObject *time_module = PyImport_ImportModule("time"); + if (time_module == NULL) { + Py_DECREF(datetime_module); + return NULL; // Propagate ImportError + } PyObject *datetime_class = PyObject_GetAttrString(datetime_module, "datetime"); PyCFunctionObject *datetime_datetime_now = @@ -646,8 +660,6 @@ _time_machine_unpatch(PyObject *module, PyObject *unused) Py_DECREF(datetime_class); Py_DECREF(datetime_module); - PyObject *time_module = PyImport_ImportModule("time"); - /* time.clock_gettime(), only available on Unix platforms. */ diff --git a/tests/test_time_machine.py b/tests/test_time_machine.py index c58f770f..14eb7fd9 100644 --- a/tests/test_time_machine.py +++ b/tests/test_time_machine.py @@ -827,6 +827,42 @@ def test_uuid1(): # error handling tests +@pytest.mark.parametrize( + "module", + [ + "datetime", + "time", + ], +) +def test_start_import_error(module): + with ( + mock.patch.dict(sys.modules, {module: None}), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + time_machine.travel(EPOCH).start() + + assert excinfo.value.args == (f"import of {module} halted; None in sys.modules",) + + +@pytest.mark.parametrize( + "module", + [ + "datetime", + "time", + ], +) +def test_stop_import_error(module): + traveller = time_machine.travel(EPOCH) + traveller.start() + with ( + mock.patch.dict(sys.modules, {module: None}), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + traveller.stop() + + assert excinfo.value.args == (f"import of {module} halted; None in sys.modules",) + + @pytest.mark.parametrize( "func, args", [