diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aef0134dd..91a7905c98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,11 +61,19 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni This is a wrapper for both `amici.run_simulation` and `amici.run_simulations`, depending on the type of the `edata` argument. It also supports passing some `Solver` options as keyword arguments. -* `amici.ModelPtr` now supports sufficient pickling for use in - multi-processing contexts. This works only if the amici-generated model - package exists in the same file system location and does not change until - unpickling. -* `amici.ExpData` is now picklable. +* Improved `pickle` support for `amici.{ModelPtr,Solver,ExpData`. + Note that AMICI's pickling support is only intended for short-term storage + or inter-process communication. + Reading pickled objects after updating AMICI or the model code will almost + certainly fail. + * `amici.ModelPtr` now supports sufficient pickling for use in + multi-processing contexts. This works only if the amici-generated model + package exists in the same file system location and does not change until + unpickling. + * `amici.Solver` is now picklable if amici was built with HDF5 support. + This only works on shared file systems, as the solver state is stored in a + temporary HDF5 file. + * `amici.ExpData` is now picklable. ## v0.X Series diff --git a/include/amici/solver.h b/include/amici/solver.h index b8181e4633..6b52a0766f 100644 --- a/include/amici/solver.h +++ b/include/amici/solver.h @@ -75,6 +75,12 @@ class Solver { */ SUNContext get_sun_context() const; + /** + * @brief Get the name of this class. + * @return Class name. + */ + virtual std::string get_class_name() const = 0; + /** * @brief runs a forward simulation until the specified timepoint * diff --git a/include/amici/solver_cvodes.h b/include/amici/solver_cvodes.h index 4f454e970e..5216fd3ec3 100644 --- a/include/amici/solver_cvodes.h +++ b/include/amici/solver_cvodes.h @@ -38,6 +38,8 @@ class CVodeSolver : public Solver { */ Solver* clone() const override; + std::string get_class_name() const override {return "CVodeSolver"; }; + void reinit( realtype t0, AmiVector const& yy0, AmiVector const& yp0 ) const override; diff --git a/include/amici/solver_idas.h b/include/amici/solver_idas.h index ed56e174cd..7b57edac86 100644 --- a/include/amici/solver_idas.h +++ b/include/amici/solver_idas.h @@ -35,6 +35,8 @@ class IDASolver : public Solver { */ Solver* clone() const override; + std::string get_class_name() const override {return "IDASolver"; }; + void reinit_post_process_f(realtype tnext) const override; void reinit_post_process_b(realtype tnext) const override; diff --git a/python/sdist/amici/swig_wrappers.py b/python/sdist/amici/swig_wrappers.py index 53092fb806..5fa0fc5962 100644 --- a/python/sdist/amici/swig_wrappers.py +++ b/python/sdist/amici/swig_wrappers.py @@ -420,3 +420,29 @@ def restore_edata( assert hasattr(edata, key) setattr(edata, key, value) return edata + + +def restore_solver(cls: type, cls_name: str, hdf5_file: str) -> Solver: + """ + Recreate a Solver or SolverPtr instance from an HDF5 file. + + For use in Solver.__reduce__. + + :param cls: + Class of the original object ({CVode,IDA}Solver or SolverPtr). + :param cls_name: + Name of the (pointed to) solver class ("CVodeSolver" or "IDASolver"). + :param hdf5_file: + HDF5 file from which to read the solver settings. + """ + if cls_name == "CVodeSolver": + solver = amici.CVodeSolver() + elif cls_name == "IDASolver": + solver = amici.IDASolver() + else: + raise ValueError(f"Unknown solver class name: {cls_name}") + + if not issubclass(cls, Solver): + solver = cls(solver) + read_solver_settings_from_hdf5(hdf5_file, solver) + return solver diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 783fed7d9d..dedd6bf165 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -712,3 +712,25 @@ def test_pickle_edata(): edata_pickled = pickle.loads(pickle.dumps(edata)) assert edata == edata_pickled + + +@pytest.mark.skipif( + not amici.hdf5_enabled, + reason="AMICI build without HDF5 support", +) +def test_pickle_solver(): + for solver in ( + amici.CVodeSolver(), + amici.IDASolver(), + amici.SolverPtr(amici.CVodeSolver()), + amici.SolverPtr(amici.IDASolver()), + ): + solver.set_max_steps(1234) + solver.set_sensitivity_order(amici.SensitivityOrder.first) + solver_pickled = pickle.loads(pickle.dumps(solver)) + assert type(solver) is type(solver_pickled) + assert solver.get_max_steps() == solver_pickled.get_max_steps() + assert ( + solver.get_sensitivity_order() + == solver_pickled.get_sensitivity_order() + ) diff --git a/swig/solver.i b/swig/solver.i index cec99e31f7..1d7672f7fd 100644 --- a/swig/solver.i +++ b/swig/solver.i @@ -101,17 +101,46 @@ def _solver_repr(self: "Solver"): f" state_ordering: {self.get_state_ordering()}", ">" ]) + +def _solver_reduce(self: "Solver"): + """ + For now, we just store solver settings in a temporary HDF5 file. + This is sufficient for multiprocessing use cases, but will not survive + reboots and will not work in distributed (MPI) settings. + This requires that amici was compiled with HDF5 support. + """ + from amici.swig_wrappers import restore_solver + from tempfile import NamedTemporaryFile + from amici import write_solver_settings_to_hdf5 + import os + with NamedTemporaryFile(suffix=".h5", delete=False) as tmpfile: + tmpfilename = tmpfile.name + write_solver_settings_to_hdf5(self, tmpfilename) + + return ( + restore_solver, + (self.__class__, self.get_class_name(), tmpfilename,), + ) + %} %extend amici::CVodeSolver { %pythoncode %{ def __repr__(self): return _solver_repr(self) + +def __reduce__(self): + return _solver_reduce(self) + %} }; %extend amici::IDASolver { %pythoncode %{ def __repr__(self): return _solver_repr(self) + +def __reduce__(self): + return _solver_reduce(self) + %} }; @@ -122,6 +151,9 @@ def __repr__(self): def __deepcopy__(self, memo): return self.clone() + +def __reduce__(self): + return _solver_reduce(self) %} }; @@ -129,6 +161,10 @@ def __deepcopy__(self, memo): %pythoncode %{ def __deepcopy__(self, memo): return self.clone() + +def __reduce__(self): + return _solver_reduce(self) + %} };