Skip to content

Commit dac6aaa

Browse files
committed
Enable pickling of ExpData
1 parent b8307a1 commit dac6aaa

File tree

5 files changed

+78
-0
lines changed

5 files changed

+78
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
6565
multi-processing contexts. This works only if the amici-generated model
6666
package exists in the same file system location and does not change until
6767
unpickling.
68+
* `amici.ExpData` is now picklable.
6869

6970
## v0.X Series
7071

python/sdist/amici/swig_wrappers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,26 @@ def file_checksum(
397397
for chunk in iter(lambda: f.read(chunk_size), b""):
398398
h.update(chunk)
399399
return h.hexdigest()
400+
401+
402+
def restore_edata(
403+
init_args: Sequence,
404+
simulation_parameter_dict: dict[str, Any],
405+
) -> amici_swig.ExpData:
406+
"""
407+
Recreate an ExpData instance.
408+
409+
For use in ExpData.__reduce__.
410+
"""
411+
edata = amici_swig.ExpData(*init_args)
412+
413+
edata.pscale = amici.parameter_scaling_from_int_vector(
414+
simulation_parameter_dict.pop("pscale")
415+
)
416+
for key, value in simulation_parameter_dict.items():
417+
if key == "timepoints":
418+
# timepoints are set during ExpData construction
419+
continue
420+
assert hasattr(edata, key)
421+
setattr(edata, key, value)
422+
return edata

python/tests/test_swig_interface.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,3 +697,19 @@ def test_pickle_model(sbml_example_presimulation_module):
697697
model.get_steady_state_sensitivity_mode()
698698
!= model_pickled.get_steady_state_sensitivity_mode()
699699
)
700+
701+
702+
def test_pickle_edata():
703+
ny = 2
704+
nz = 3
705+
ne = 4
706+
nt = 5
707+
edata = amici.ExpData(ny, nz, ne, range(nt))
708+
edata.set_observed_data(list(np.arange(ny * nt, dtype=float)))
709+
edata.pscale = amici.parameter_scaling_from_int_vector(
710+
[amici.ParameterScaling.log10] * 5
711+
)
712+
import pickle
713+
714+
edata_pickled = pickle.loads(pickle.dumps(edata))
715+
assert edata == edata_pickled

swig/amici.i

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,23 @@ wrap_unique_ptr(ExpDataPtr, amici::ExpData)
155155
%naturalvar amici::SimulationParameters::reinitialization_state_idxs_sim;
156156
%naturalvar amici::SimulationParameters::reinitialization_state_idxs_presim;
157157

158+
%extend amici::SimulationParameters {
159+
%pythoncode %{
160+
def __iter__(self):
161+
for attr_name in dir(self):
162+
if (
163+
not attr_name.startswith('_')
164+
and attr_name not in ("this", "thisown")
165+
and not callable(attr_val := getattr(self, attr_name))
166+
):
167+
if isinstance(attr_val, (DoubleVector, ParameterScalingVector)):
168+
yield attr_name, tuple(attr_val)
169+
else:
170+
yield attr_name, attr_val
171+
%}
172+
}
173+
174+
158175
// DO NOT IGNORE amici::SimulationParameters, amici::ModelDimensions, amici::CpuTimer
159176
%ignore amici::ModelContext;
160177
%ignore amici::ContextManager;

swig/edata.i

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,27 @@ def __deepcopy__(self, memo):
102102
# invoke copy constructor
103103
return type(self)(self)
104104

105+
def __reduce__(self):
106+
from amici.swig_wrappers import restore_edata
107+
108+
return (
109+
restore_edata,
110+
(
111+
# ExpData ctor arguments
112+
(
113+
self.nytrue(),
114+
self.nztrue(),
115+
self.nmaxevent(),
116+
self.get_timepoints(),
117+
self.get_observed_data(),
118+
self.get_observed_data_std_dev(),
119+
self.get_observed_events(),
120+
self.get_observed_events_std_dev(),
121+
),
122+
dict(self)
123+
),
124+
{}
125+
)
105126
%}
106127
};
107128
%extend std::unique_ptr<amici::ExpData> {

0 commit comments

Comments
 (0)