diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0a9b26931b..593ab3d64b 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -409,7 +409,7 @@ def _extra_metadata_from_folder(self, folder): for segment_index, rs in enumerate(self.segments): time_file = folder / f"times_cached_seg{segment_index}.npy" if time_file.is_file(): - time_vector = np.load(time_file) + time_vector = np.load(time_file, mmap_mode="r") rs.time_vector = time_vector def _extra_metadata_to_folder(self, folder): diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index a1e5aa47bf..643426d6c7 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -375,6 +375,31 @@ def test_save_and_load_time_shift(self, request, fixture_name, tmp_path): times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx) ) + @pytest.mark.parametrize("save_format", ["binary", "zarr"]) + def test_shift_times_after_load(self, request, save_format, tmp_path): + """ + Shift times on a recording loaded from disk as a read-only np.memmap + (binary folder) and a lazy zarr.Array (zarr). Neither supports an in-place + `+=`, so `shift_times` must shift a writable copy. + """ + _, times_recording, all_times = self._get_fixture_data(request, "time_vector_recording") + + folder = tmp_path / "rec" + times_recording.save(format=save_format, folder=folder) + load_path = folder.with_suffix(".zarr") if save_format == "zarr" else folder + loaded = si.load(load_path) + + # Confirm we are actually exercising a non-writeable / non-ndarray path. + for idx in range(loaded.get_num_segments()): + tv = loaded.segments[idx].time_vector + assert not (isinstance(tv, np.ndarray) and tv.flags.writeable) + + shift = 123.456 + loaded.shift_times(shift) + + for idx in range(loaded.get_num_segments()): + assert np.allclose(loaded.get_times(segment_index=idx), all_times[idx] + shift, rtol=0, atol=1e-8) + def _store_all_times(self, recording): """ Convenience function to store original times of all segments to a dict. diff --git a/src/spikeinterface/core/time_series.py b/src/spikeinterface/core/time_series.py index d4d4717dff..7dd19230c6 100644 --- a/src/spikeinterface/core/time_series.py +++ b/src/spikeinterface/core/time_series.py @@ -1,11 +1,26 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, TYPE_CHECKING, TypeAlias import warnings import numpy as np from spikeinterface.core.base import BaseExtractor, BaseSegment +if TYPE_CHECKING: + import zarr + +# A recording segment's time vector: a 1-D array of per-sample times (in seconds). +# The backing store depends on how the recording was created/loaded: +# - np.ndarray : set_times() (writeable, in-memory) +# - np.memmap : BinaryFolderRecording load via np.load(..., mmap_mode="r") +# -- *read-only* ; see BaseRecording._extra_metadata_from_folder +# - zarr.Array : ZarrRecordingExtractor load +# -- *read-only* ; see ZarrRecordingExtractor.__init__ +# Code reading `.time_vector` must not assume it is writeable (see `shift_times`). +TimeVector: TypeAlias = "np.ndarray | zarr.Array" # np.memmap is an np.ndarray subclass + class TimeSeries(ABC): """ @@ -146,8 +161,8 @@ def get_times(self, segment_index=None, start_frame=None, end_frame=None) -> np. Returns ------- - np.array - The 1d times array + np.ndarray + The 1d times array. If the times were mem-mapped, loads them into memory. """ segment_index = self._check_segment_index(segment_index) rs = self.segments[segment_index] @@ -211,8 +226,9 @@ def set_times(self, times, segment_index=None, with_warning=True): Parameters ---------- - times : 1d np.array - The time vector + times : 1d array-like + The time vector. Lazy/read-only input (e.g. memmap or zarr.Array) is loaded + into memory and cast to float64 before being stored on the segment. segment_index : int or None, default: None The segment index (required for multi-segment) with_warning : bool, default: True @@ -273,7 +289,12 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N rs = self.segments[segment_index] if self.has_time_vector(segment_index=segment_index): - rs.time_vector += shift + if isinstance(rs.time_vector, np.ndarray) and rs.time_vector.flags.writeable: + # If this is an in-memory numpy array + rs.time_vector += shift # in-place, no copy + else: + # If this is a read-only memmap or zarr.Array + rs.time_vector = np.asarray(rs.time_vector) + shift else: new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift rs.t_start = new_start_time @@ -366,7 +387,23 @@ class TimeSeriesSegment(BaseSegment): """Per-segment time-series class. Provides time handling methods (sample/time conversion, start/end time, time vectors) on top of ``BaseSegment``.""" - def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): + def __init__( + self, + sampling_frequency: float | None = None, + t_start: float | None = None, + time_vector: TimeVector | None = None, + ) -> None: + """ + Parameters + ---------- + sampling_frequency : float | None, default: None + Sampling frequency in Hz. Mutually exclusive with `time_vector`. + t_start : float | None, default: None + Start time (s) used when times are regular (no `time_vector`). + time_vector : TimeVector | None, default: None + Explicit per-sample times. May be a writeable np.ndarray, a read-only + np.memmap, or a lazy zarr.Array. + """ # sampling_frequency and time_vector are exclusive if sampling_frequency is None: assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'" @@ -377,7 +414,7 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None): self.sampling_frequency = sampling_frequency self.t_start = t_start - self.time_vector = time_vector + self.time_vector: TimeVector | None = time_vector BaseSegment.__init__(self)