Skip to content

Commit f6d1a29

Browse files
committed
Enable pickling of Solver and SolverPtr
Enable pickling of `Solver` and `SolverPtr`. For now, only with limited functionality and through storing solver settings in HDF5 files.
1 parent 7e7557e commit f6d1a29

File tree

7 files changed

+95
-0
lines changed

7 files changed

+95
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
6666
package exists in the same file system location and does not change until
6767
unpickling.
6868
* `amici.ExpData` is now picklable.
69+
* `amici.Solver` is now picklable if amici was built with HDF5 support.
6970

7071
## v0.X Series
7172

include/amici/solver.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class Solver {
7575
*/
7676
SUNContext get_sun_context() const;
7777

78+
/**
79+
* @brief Get the name of this class.
80+
* @return Class name.
81+
*/
82+
virtual std::string get_class_name() const = 0;
83+
7884
/**
7985
* @brief runs a forward simulation until the specified timepoint
8086
*

include/amici/solver_cvodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class CVodeSolver : public Solver {
3838
*/
3939
Solver* clone() const override;
4040

41+
std::string get_class_name() const override {return "CVodeSolver"; };
42+
4143
void reinit(
4244
realtype t0, AmiVector const& yy0, AmiVector const& yp0
4345
) const override;

include/amici/solver_idas.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class IDASolver : public Solver {
3535
*/
3636
Solver* clone() const override;
3737

38+
std::string get_class_name() const override {return "IDASolver"; };
39+
3840
void reinit_post_process_f(realtype tnext) const override;
3941

4042
void reinit_post_process_b(realtype tnext) const override;

python/sdist/amici/swig_wrappers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,3 +420,29 @@ def restore_edata(
420420
assert hasattr(edata, key)
421421
setattr(edata, key, value)
422422
return edata
423+
424+
425+
def restore_solver(cls: type, cls_name: str, hdf5_file: str) -> Solver:
426+
"""
427+
Recreate a Solver or SolverPtr instance from an HDF5 file.
428+
429+
For use in Solver.__reduce__.
430+
431+
:param cls:
432+
Class of the original object ({CVode,IDA}Solver or SolverPtr).
433+
:param cls_name:
434+
Name of the (pointed to) solver class ("CVodeSolver" or "IDASolver").
435+
:param hdf5_file:
436+
HDF5 file from which to read the solver settings.
437+
"""
438+
if cls_name == "CVodeSolver":
439+
solver = amici.CVodeSolver()
440+
elif cls_name == "IDASolver":
441+
solver = amici.IDASolver()
442+
else:
443+
raise ValueError(f"Unknown solver class name: {cls_name}")
444+
445+
if not issubclass(cls, Solver):
446+
solver = cls(solver)
447+
read_solver_settings_from_hdf5(hdf5_file, solver)
448+
return solver

python/tests/test_swig_interface.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,3 +712,25 @@ def test_pickle_edata():
712712

713713
edata_pickled = pickle.loads(pickle.dumps(edata))
714714
assert edata == edata_pickled
715+
716+
717+
@pytest.mark.skipif(
718+
not amici.hdf5_enabled,
719+
reason="AMICI build without HDF5 support",
720+
)
721+
def test_pickle_solver():
722+
for solver in (
723+
amici.CVodeSolver(),
724+
amici.IDASolver(),
725+
amici.SolverPtr(amici.CVodeSolver()),
726+
amici.SolverPtr(amici.IDASolver()),
727+
):
728+
solver.set_max_steps(1234)
729+
solver.set_sensitivity_order(amici.SensitivityOrder.first)
730+
solver_pickled = pickle.loads(pickle.dumps(solver))
731+
assert type(solver) is type(solver_pickled)
732+
assert solver.get_max_steps() == solver_pickled.get_max_steps()
733+
assert (
734+
solver.get_sensitivity_order()
735+
== solver_pickled.get_sensitivity_order()
736+
)

swig/solver.i

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,46 @@ def _solver_repr(self: "Solver"):
101101
f" state_ordering: {self.get_state_ordering()}",
102102
">"
103103
])
104+
105+
def _solver_reduce(self: "Solver"):
106+
"""
107+
For now, we just store solver settings in a temporary HDF5 file.
108+
This is sufficient for multiprocessing use cases, but will not survive
109+
reboots and will not work in distributed (MPI) settings.
110+
This requires that amici was compiled with HDF5 support.
111+
"""
112+
from amici.swig_wrappers import restore_solver
113+
from tempfile import NamedTemporaryFile
114+
from amici import write_solver_settings_to_hdf5
115+
import os
116+
with NamedTemporaryFile(suffix=".h5", delete=False) as tmpfile:
117+
tmpfilename = tmpfile.name
118+
write_solver_settings_to_hdf5(self, tmpfilename)
119+
120+
return (
121+
restore_solver,
122+
(self.__class__, self.get_class_name(), tmpfilename,),
123+
)
124+
104125
%}
105126
%extend amici::CVodeSolver {
106127
%pythoncode %{
107128
def __repr__(self):
108129
return _solver_repr(self)
130+
131+
def __reduce__(self):
132+
return _solver_reduce(self)
133+
109134
%}
110135
};
111136
%extend amici::IDASolver {
112137
%pythoncode %{
113138
def __repr__(self):
114139
return _solver_repr(self)
140+
141+
def __reduce__(self):
142+
return _solver_reduce(self)
143+
115144
%}
116145
};
117146

@@ -122,13 +151,20 @@ def __repr__(self):
122151

123152
def __deepcopy__(self, memo):
124153
return self.clone()
154+
155+
def __reduce__(self):
156+
return _solver_reduce(self)
125157
%}
126158
};
127159

128160
%extend amici::Solver {
129161
%pythoncode %{
130162
def __deepcopy__(self, memo):
131163
return self.clone()
164+
165+
def __reduce__(self):
166+
return _solver_reduce(self)
167+
132168
%}
133169
};
134170

0 commit comments

Comments
 (0)