diff --git a/src/_time_machine.c b/src/_time_machine.c index e45c675b..211f8a42 100644 --- a/src/_time_machine.c +++ b/src/_time_machine.c @@ -511,6 +511,14 @@ _time_machine_patch(PyObject *module, PyObject *unused) Py_RETURN_NONE; PyObject *datetime_module = PyImport_ImportModule("datetime"); + if (datetime_module == NULL) { + return NULL; // Propagate ImportError + } + PyObject *time_module = PyImport_ImportModule("time"); + if (time_module == NULL) { + Py_DECREF(datetime_module); + return NULL; // Propagate ImportError + } PyObject *datetime_class = PyObject_GetAttrString(datetime_module, "datetime"); PyCFunctionObject *datetime_datetime_now = @@ -532,8 +540,6 @@ _time_machine_patch(PyObject *module, PyObject *unused) Py_DECREF(datetime_class); Py_DECREF(datetime_module); - PyObject *time_module = PyImport_ImportModule("time"); - /* time.clock_gettime(), only available on Unix platforms. */ @@ -625,6 +631,14 @@ _time_machine_unpatch(PyObject *module, PyObject *unused) Py_RETURN_NONE; PyObject *datetime_module = PyImport_ImportModule("datetime"); + if (datetime_module == NULL) { + return NULL; // Propagate ImportError + } + PyObject *time_module = PyImport_ImportModule("time"); + if (time_module == NULL) { + Py_DECREF(datetime_module); + return NULL; // Propagate ImportError + } PyObject *datetime_class = PyObject_GetAttrString(datetime_module, "datetime"); PyCFunctionObject *datetime_datetime_now = @@ -646,8 +660,6 @@ _time_machine_unpatch(PyObject *module, PyObject *unused) Py_DECREF(datetime_class); Py_DECREF(datetime_module); - PyObject *time_module = PyImport_ImportModule("time"); - /* time.clock_gettime(), only available on Unix platforms. */ diff --git a/tests/test_time_machine.py b/tests/test_time_machine.py index c58f770f..14eb7fd9 100644 --- a/tests/test_time_machine.py +++ b/tests/test_time_machine.py @@ -827,6 +827,42 @@ def test_uuid1(): # error handling tests +@pytest.mark.parametrize( + "module", + [ + "datetime", + "time", + ], +) +def test_start_import_error(module): + with ( + mock.patch.dict(sys.modules, {module: None}), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + time_machine.travel(EPOCH).start() + + assert excinfo.value.args == (f"import of {module} halted; None in sys.modules",) + + +@pytest.mark.parametrize( + "module", + [ + "datetime", + "time", + ], +) +def test_stop_import_error(module): + traveller = time_machine.travel(EPOCH) + traveller.start() + with ( + mock.patch.dict(sys.modules, {module: None}), + pytest.raises(ModuleNotFoundError) as excinfo, + ): + traveller.stop() + + assert excinfo.value.args == (f"import of {module} halted; None in sys.modules",) + + @pytest.mark.parametrize( "func, args", [