Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _extra_metadata_from_folder(self, folder):
for segment_index, rs in enumerate(self.segments):
time_file = folder / f"times_cached_seg{segment_index}.npy"
if time_file.is_file():
time_vector = np.load(time_file)
time_vector = np.load(time_file, mmap_mode="r")
rs.time_vector = time_vector

def _extra_metadata_to_folder(self, folder):
Expand Down
25 changes: 25 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,31 @@ def test_save_and_load_time_shift(self, request, fixture_name, tmp_path):
times_recording.get_times(segment_index=idx), loaded_recording.get_times(segment_index=idx)
)

@pytest.mark.parametrize("save_format", ["binary", "zarr"])
def test_shift_times_after_load(self, request, save_format, tmp_path):
"""
Shift times on a recording loaded from disk as a read-only np.memmap
(binary folder) and a lazy zarr.Array (zarr). Neither supports an in-place
`+=`, so `shift_times` must shift a writable copy.
"""
_, times_recording, all_times = self._get_fixture_data(request, "time_vector_recording")

folder = tmp_path / "rec"
times_recording.save(format=save_format, folder=folder)
load_path = folder.with_suffix(".zarr") if save_format == "zarr" else folder
loaded = si.load(load_path)

# Confirm we are actually exercising a non-writeable / non-ndarray path.
for idx in range(loaded.get_num_segments()):
tv = loaded.segments[idx].time_vector
assert not (isinstance(tv, np.ndarray) and tv.flags.writeable)

shift = 123.456
loaded.shift_times(shift)

for idx in range(loaded.get_num_segments()):
assert np.allclose(loaded.get_times(segment_index=idx), all_times[idx] + shift, rtol=0, atol=1e-8)

def _store_all_times(self, recording):
"""
Convenience function to store original times of all segments to a dict.
Expand Down
53 changes: 45 additions & 8 deletions src/spikeinterface/core/time_series.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, TYPE_CHECKING, TypeAlias
import warnings

import numpy as np

from spikeinterface.core.base import BaseExtractor, BaseSegment

if TYPE_CHECKING:
import zarr

# A recording segment's time vector: a 1-D array of per-sample times (in seconds).
# The backing store depends on how the recording was created/loaded:
# - np.ndarray : set_times() (writeable, in-memory)
# - np.memmap : BinaryFolderRecording load via np.load(..., mmap_mode="r")
# -- *read-only* ; see BaseRecording._extra_metadata_from_folder
# - zarr.Array : ZarrRecordingExtractor load
# -- *read-only* ; see ZarrRecordingExtractor.__init__
# Code reading `.time_vector` must not assume it is writeable (see `shift_times`).
TimeVector: TypeAlias = "np.ndarray | zarr.Array" # np.memmap is an np.ndarray subclass


class TimeSeries(ABC):
"""
Expand Down Expand Up @@ -146,8 +161,8 @@ def get_times(self, segment_index=None, start_frame=None, end_frame=None) -> np.

Returns
-------
np.array
The 1d times array
np.ndarray
The 1d times array. If the times were mem-mapped, loads them into memory.
"""
segment_index = self._check_segment_index(segment_index)
rs = self.segments[segment_index]
Expand Down Expand Up @@ -211,8 +226,9 @@ def set_times(self, times, segment_index=None, with_warning=True):

Parameters
----------
times : 1d np.array
The time vector
times : 1d array-like
The time vector. Lazy/read-only input (e.g. memmap or zarr.Array) is loaded
into memory and cast to float64 before being stored on the segment.
segment_index : int or None, default: None
The segment index (required for multi-segment)
with_warning : bool, default: True
Expand Down Expand Up @@ -273,7 +289,12 @@ def shift_times(self, shift: int | float, segment_index: int | None = None) -> N
rs = self.segments[segment_index]

if self.has_time_vector(segment_index=segment_index):
rs.time_vector += shift
if isinstance(rs.time_vector, np.ndarray) and rs.time_vector.flags.writeable:
# If this is an in-memory numpy array
rs.time_vector += shift # in-place, no copy
else:
# If this is a read-only memmap or zarr.Array
rs.time_vector = np.asarray(rs.time_vector) + shift
else:
new_start_time = 0 + shift if rs.t_start is None else rs.t_start + shift
rs.t_start = new_start_time
Expand Down Expand Up @@ -366,7 +387,23 @@ class TimeSeriesSegment(BaseSegment):
"""Per-segment time-series class. Provides time handling methods (sample/time conversion,
start/end time, time vectors) on top of ``BaseSegment``."""

def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):
def __init__(
self,
sampling_frequency: float | None = None,
t_start: float | None = None,
time_vector: TimeVector | None = None,
) -> None:
"""
Parameters
----------
sampling_frequency : float | None, default: None
Sampling frequency in Hz. Mutually exclusive with `time_vector`.
t_start : float | None, default: None
Start time (s) used when times are regular (no `time_vector`).
time_vector : TimeVector | None, default: None
Explicit per-sample times. May be a writeable np.ndarray, a read-only
np.memmap, or a lazy zarr.Array.
"""
# sampling_frequency and time_vector are exclusive
if sampling_frequency is None:
assert time_vector is not None, "Pass either 'sampling_frequency' or 'time_vector'"
Expand All @@ -377,7 +414,7 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):

self.sampling_frequency = sampling_frequency
self.t_start = t_start
self.time_vector = time_vector
self.time_vector: TimeVector | None = time_vector

BaseSegment.__init__(self)

Expand Down
Loading