Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 353 additions & 8 deletions elephant/test/test_spike_train_surrogates.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,14 +628,109 @@ def test_bin_shuffling_empty_train(self):
n_surrogates=1)[0]
self.assertEqual(np.sum(surrogate_train.to_bool_array()), 0)

def test_bin_shuffling_spike_count_preserved(self):
"""Total spike count must be unchanged after bin shuffling."""
spiketrain = neo.SpikeTrain(
[90, 93, 97, 100, 105, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
bin_size = 5 * pq.ms
max_displacement = 10
binned = conv.BinnedSpikeTrain(spiketrain, bin_size)
original_count = int(np.sum(binned.to_bool_array()))
for surrogate in surr.bin_shuffling(
binned, max_displacement=max_displacement, n_surrogates=10):
self.assertEqual(int(np.sum(surrogate.to_bool_array())),
original_count)

def test_bin_shuffling_window_spike_counts_preserved(self):
"""Spikes must not cross non-overlapping window boundaries."""
max_displacement = 5 # displacement_window = 10 bins
bin_size = 10 * pq.ms # each window covers 100 ms
# 4 windows: [0-99], [100-199], [200-299], [300-399] ms
spiketrain = neo.SpikeTrain(
[15, 55, 135, 215, 265, 315, 375] * pq.ms, t_stop=400 * pq.ms)
binned = conv.BinnedSpikeTrain(spiketrain, bin_size)
bool_arr = binned.to_bool_array()[0]
window_size = 2 * max_displacement
n_windows = len(bool_arr) // window_size
orig_counts = [int(np.sum(bool_arr[i * window_size:
(i + 1) * window_size]))
for i in range(n_windows)]
for surrogate in surr.bin_shuffling(
binned, max_displacement=max_displacement, n_surrogates=20):
surr_bool = surrogate.to_bool_array()[0]
for i in range(n_windows):
surr_count = int(np.sum(
surr_bool[i * window_size:(i + 1) * window_size]))
self.assertEqual(surr_count, orig_counts[i],
f"Spike count changed in window {i}")

def test_bin_shuffling_surrogates_differ(self):
"""Multiple surrogates must not all be identical."""
spiketrain = neo.SpikeTrain(
[50, 100, 150, 200, 250] * pq.ms, t_stop=500 * pq.ms)
bin_size = 5 * pq.ms
binned = conv.BinnedSpikeTrain(spiketrain, bin_size)
n_surrogates = 10
surrogate_trains = surr.bin_shuffling(
binned, max_displacement=20, n_surrogates=n_surrogates)
bool_arrays = [s.to_bool_array()[0] for s in surrogate_trains]
all_equal = all(np.array_equal(bool_arrays[0], b)
for b in bool_arrays[1:])
self.assertFalse(all_equal,
"All surrogates produced identical spike trains")

def test_bin_shuffling_neo_spiketrain_input(self):
"""Passing neo.SpikeTrain + bin_size must return list of neo.SpikeTrain."""
spiketrain = neo.SpikeTrain(
[90, 93, 97, 100, 105, 150, 180, 350] * pq.ms, t_stop=500 * pq.ms)
bin_size = 10 * pq.ms
max_displacement = 5
n_surrogates = 3
surrogate_trains = surr.bin_shuffling(
spiketrain, max_displacement=max_displacement,
bin_size=bin_size, n_surrogates=n_surrogates)
self.assertIsInstance(surrogate_trains, list)
self.assertEqual(len(surrogate_trains), n_surrogates)
for surrogate_train in surrogate_trains:
self.assertIsInstance(surrogate_train, neo.SpikeTrain)
self.assertEqual(surrogate_train.t_start, spiketrain.t_start)
self.assertEqual(surrogate_train.t_stop, spiketrain.t_stop)
if len(surrogate_train) > 1:
assert_array_less(0., np.diff(surrogate_train.magnitude))

def test_bin_shuffling_neo_spiketrain_missing_bin_size_raises(self):
"""Passing neo.SpikeTrain without bin_size must raise ValueError."""
spiketrain = neo.SpikeTrain([50, 100, 200] * pq.ms, t_stop=500 * pq.ms)
with self.assertRaises(ValueError):
surr.bin_shuffling(spiketrain, max_displacement=5)

def test_bin_shuffling_sliding_on_neo_spiketrain_warns(self):
"""Using sliding=True with a neo.SpikeTrain input must issue UserWarning."""
spiketrain = neo.SpikeTrain([50, 100, 200] * pq.ms, t_stop=500 * pq.ms)
with self.assertWarns(UserWarning):
surr.bin_shuffling(spiketrain, max_displacement=5,
bin_size=10 * pq.ms, sliding=True)

def test_bin_shuffling_via_surrogates_wrapper(self):
"""surrogates() with method='bin_shuffling' must return neo.SpikeTrain."""
spiketrain = neo.SpikeTrain(
[90, 93, 97, 100, 105] * pq.ms, t_stop=200 * pq.ms)
bin_size = 5 * pq.ms
n_surrogates = 3
result = surr.surrogates(spiketrain, n_surrogates=n_surrogates,
method='bin_shuffling',
dt=10 * pq.ms, bin_size=bin_size)
self.assertIsInstance(result, list)
self.assertEqual(len(result), n_surrogates)
for s in result:
self.assertIsInstance(s, neo.SpikeTrain)

def test_trial_shuffling_output_format(self):
spiketrain = \
[neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms,
t_stop=.2 * pq.s),
neo.SpikeTrain([90, 93, 97, 100, 105, 150, 180, 190] * pq.ms,
t_stop=.2 * pq.s)]
# trial_length = 200 * pq.ms
# trial_separation = 50 * pq.ms
n_surrogates = 2
dither = 10 * pq.ms
surrogate_trains = surr.trial_shifting(
Expand All @@ -646,12 +741,15 @@ def test_trial_shuffling_output_format(self):

self.assertIsInstance(surrogate_trains[0], list)
self.assertIsInstance(surrogate_trains[0][0], neo.SpikeTrain)
for surrogate_train in surrogate_trains[0]:
self.assertEqual(surrogate_train.units, spiketrain[0].units)
self.assertEqual(surrogate_train.t_start, spiketrain[0].t_start)
self.assertEqual(surrogate_train.t_stop, spiketrain[0].t_stop)
self.assertEqual(len(surrogate_train), len(spiketrain[0]))
assert_array_less(0., np.diff(surrogate_train)) # check ordering
# check all surrogates, not just the first one
for surr_idx, surrogate in enumerate(surrogate_trains):
for trial_idx, surrogate_train in enumerate(surrogate):
ref_trial = spiketrain[trial_idx]
self.assertEqual(surrogate_train.units, ref_trial.units)
self.assertEqual(surrogate_train.t_start, ref_trial.t_start)
self.assertEqual(surrogate_train.t_stop, ref_trial.t_stop)
self.assertEqual(len(surrogate_train), len(ref_trial))
assert_array_less(0., np.diff(surrogate_train))

def test_trial_shuffling_empty_train(self):

Expand Down Expand Up @@ -699,6 +797,253 @@ def test_trial_shuffling_empty_train_concatenated(self):
trial_length=trial_length, trial_separation=trial_separation)[0]
self.assertEqual(len(surrogate_train), 0)

def test_trial_shifting_spikes_within_bounds(self):
"""All surrogate spikes must stay within each trial's [t_start, t_stop]."""
trials = [
neo.SpikeTrain([10, 50, 90, 150, 180] * pq.ms, t_stop=200 * pq.ms),
neo.SpikeTrain([20, 60, 110, 160, 195] * pq.ms, t_stop=200 * pq.ms),
]
dither = 30 * pq.ms
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=5)
for surrogate in surrogate_trains:
for trial_idx, surrogate_trial in enumerate(surrogate):
t_start = trials[trial_idx].t_start
t_stop = trials[trial_idx].t_stop
self.assertTrue(
np.all(surrogate_trial >= t_start),
f"Spike below t_start in trial {trial_idx}")
self.assertTrue(
np.all(surrogate_trial < t_stop),
f"Spike at or above t_stop in trial {trial_idx}")

def test_trial_shifting_spike_count_preserved(self):
"""Circular wrap-around must preserve the spike count of every trial."""
trials = [
neo.SpikeTrain([10, 50, 90, 150, 190] * pq.ms, t_stop=200 * pq.ms),
neo.SpikeTrain([5, 70, 130] * pq.ms, t_stop=200 * pq.ms),
]
dither = 100 * pq.ms # intentionally large to force wrap-arounds
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=10)
for surrogate in surrogate_trains:
for trial_idx, surrogate_trial in enumerate(surrogate):
self.assertEqual(len(surrogate_trial), len(trials[trial_idx]))

def test_trial_shifting_circular_wraparound(self):
"""A spike shifted past t_stop must reappear near t_start."""
# Single spike at 195 ms in a 200 ms trial; a positive shift of +10 ms
# wraps it to ~5 ms.
trial = neo.SpikeTrain([195] * pq.ms, t_stop=200 * pq.ms)
t_start_s = trial.t_start.rescale(pq.s).magnitude
t_stop_s = trial.t_stop.rescale(pq.s).magnitude
trial_dur_s = t_stop_s - t_start_s

dither_s = 0.010 # 10 ms in seconds
spike_s = 0.195

# Replicate the internal shift with a known positive delta (+10 ms).
shifted = spike_s + dither_s # 0.205 s — beyond t_stop
wrapped = np.remainder(shifted - t_start_s, trial_dur_s) + t_start_s
self.assertGreaterEqual(wrapped, t_start_s)
self.assertLess(wrapped, t_stop_s)
# Wrapped value should be approximately 0.005 s (i.e. 5 ms)
self.assertAlmostEqual(wrapped, t_start_s + 0.005, places=10)

def test_trial_shifting_wraparound_exact(self):
"""A spike shifted past t_stop must wrap to the exact correct position.

Uses a fixed seed so the shift magnitude is known in advance.
With random.seed(0) the shift is +13.776... ms (positive), which moves
the spike at 195 ms to 208.776... ms and must wrap to 8.776... ms.
The test verifies the exact output of trial_shifting, not just the math.
"""
# setUp has already called random.seed(0).
# Pre-draw the shift, then reset so trial_shifting sees the same value.
dither_s = 0.020
expected_shift_s = dither_s * (2 * random.random() - 1)
random.seed(0)

trial = neo.SpikeTrain([60, 100, 195] * pq.ms, t_stop=200 * pq.ms)
surrogate = surr.trial_shifting(
[trial], dither=20 * pq.ms, n_surrogates=1)[0][0]

orig_s = np.array([0.060, 0.100, 0.195])
expected_s = np.sort(np.remainder(orig_s + expected_shift_s, 0.200))

# With seed 0 the shift is positive (~+13.78 ms), so 195 ms wraps to
# ~8.78 ms and must appear as the first spike in the sorted output.
self.assertLess(expected_s[0], 0.020,
"Wrapped spike should be near t_start, not t_stop")
np.testing.assert_allclose(
surrogate.rescale(pq.s).magnitude, expected_s, rtol=1e-10,
err_msg="Surrogate times do not match expected wrap-around shift")

def test_trial_shifting_sorted_output(self):
"""Surrogate spike trains must be sorted within every trial."""
trials = [
neo.SpikeTrain([30, 80, 120, 170] * pq.ms, t_stop=200 * pq.ms),
neo.SpikeTrain([15, 55, 95, 185] * pq.ms, t_stop=200 * pq.ms),
]
dither = 50 * pq.ms
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=10)
for surrogate in surrogate_trains:
for surrogate_trial in surrogate:
assert_array_less(0., np.diff(surrogate_trial.magnitude))

def test_trial_shifting_surrogates_differ(self):
"""Multiple surrogates must not all produce identical spike trains."""
np.random.seed(None) # use true randomness to avoid seed collision
random.seed(None)
trials = [
neo.SpikeTrain([50, 100, 150] * pq.ms, t_stop=200 * pq.ms),
]
dither = 20 * pq.ms
n_surrogates = 10
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=n_surrogates)
# Collect the first (and only) trial's spike times for each surrogate
first_spikes = [surrogate_trains[i][0].magnitude
for i in range(n_surrogates)]
# At least two surrogates must differ
all_equal = all(
np.allclose(first_spikes[0], s) for s in first_spikes[1:])
self.assertFalse(all_equal,
"All surrogates produced identical spike trains")

def test_trial_shifting_trials_shifted_independently(self):
"""Within a single surrogate, different trials receive different shifts."""
np.random.seed(None)
random.seed(None)
# Use 5 identical trials so any difference in shift is detectable.
trial = neo.SpikeTrain([50, 100, 150] * pq.ms, t_stop=200 * pq.ms)
n_trials = 5
trials = [trial.copy() for _ in range(n_trials)]
dither = 30 * pq.ms

# Run many surrogates and check that the shifts across trials differ.
found_differing = False
for _ in range(20):
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=1)
surrogate = surrogate_trains[0]
shifts = []
for surr_trial in surrogate:
# The shift can be recovered mod trial_duration because the
# spike times are wrapped — just check they are not all the same.
shifts.append(surr_trial.magnitude[0])
if len(set(np.round(shifts, 6))) > 1:
found_differing = True
break
self.assertTrue(found_differing,
"All trials had identical shifts in every run")

def test_trial_shifting_isi_preservation(self):
"""For a trial shifted without wrap-around, ISIs must be unchanged."""
# Place spikes well inside the trial so a small dither cannot cause
# any spike to wrap around.
trial = neo.SpikeTrain([60, 80, 100, 120, 140] * pq.ms,
t_stop=200 * pq.ms)
dither = 5 * pq.ms # too small to push any spike out of [0, 200) ms

orig_isis = np.diff(trial.rescale(pq.ms).magnitude)
n_surrogates = 20
for surrogate in surr.trial_shifting([trial], dither=dither,
n_surrogates=n_surrogates):
surr_isis = np.diff(surrogate[0].rescale(pq.ms).magnitude)
np.testing.assert_allclose(
surr_isis, orig_isis,
err_msg="ISIs changed for a shift that caused no wrap-around")

def test_trial_shifting_via_surrogates_wrapper_list_input(self):
"""surrogates() with method='trial_shifting' and a list input must work."""
trials = [
neo.SpikeTrain([90, 93, 97, 100, 105] * pq.ms,
t_stop=200 * pq.ms),
neo.SpikeTrain([40, 80, 120, 160] * pq.ms,
t_stop=200 * pq.ms),
]
dt = 15 * pq.ms
n_surrogates = 3
result = surr.surrogates(trials, n_surrogates=n_surrogates,
method='trial_shifting', dt=dt)

self.assertIsInstance(result, list)
self.assertEqual(len(result), n_surrogates)
for surrogate in result:
self.assertIsInstance(surrogate, list)
self.assertEqual(len(surrogate), len(trials))
for trial_idx, surrogate_trial in enumerate(surrogate):
self.assertIsInstance(surrogate_trial, neo.SpikeTrain)
self.assertEqual(len(surrogate_trial), len(trials[trial_idx]))
self.assertEqual(surrogate_trial.t_start,
trials[trial_idx].t_start)
self.assertEqual(surrogate_trial.t_stop,
trials[trial_idx].t_stop)

def test_trial_shifting_ground_truth(self):
"""Exact spike times must match the expected circular shift for a known seed.

setUp seeds random with 0 before every test. We consume one draw to
find the shift the function will apply, reset the seed so the function
sees the same draw, then compare the output exactly.
"""
dither_s = 0.020 # 20 ms in seconds
expected_shift_s = dither_s * (2 * random.random() - 1)
random.seed(0) # reset so trial_shifting draws the same value

trial = neo.SpikeTrain([50, 80, 120] * pq.ms, t_stop=200 * pq.ms)
surrogate = surr.trial_shifting(
[trial], dither=20 * pq.ms, n_surrogates=1)[0][0]

orig_s = np.array([0.050, 0.080, 0.120])
expected_s = np.sort(
np.remainder(orig_s + expected_shift_s, 0.200))
np.testing.assert_allclose(
surrogate.rescale(pq.s).magnitude, expected_s, rtol=1e-10,
err_msg="Surrogate spike times do not match expected circular shift")

def test_trial_shifting_shift_bounded_by_dither(self):
"""Maximum displacement of any spike must not exceed dither.

Spikes are placed far from both edges so no wrap-around can occur,
making the displacement directly measurable.
"""
trial = neo.SpikeTrain([80, 100, 120] * pq.ms, t_stop=200 * pq.ms)
dither = 25 * pq.ms
dither_ms = dither.rescale(pq.ms).magnitude
orig_ms = trial.rescale(pq.ms).magnitude
for surrogate in surr.trial_shifting([trial], dither=dither,
n_surrogates=50):
surr_ms = surrogate[0].rescale(pq.ms).magnitude
displacements = np.abs(surr_ms - orig_ms)
self.assertTrue(
np.all(displacements <= dither_ms + 1e-9),
f"Displacement {displacements.max():.4f} ms exceeds "
f"dither {dither_ms:.4f} ms")

def test_trial_shifting_heterogeneous_trial_lengths(self):
"""Trials with different t_start/t_stop values are handled correctly."""
trials = [
neo.SpikeTrain([10, 50, 90] * pq.ms,
t_start=0 * pq.ms, t_stop=100 * pq.ms),
neo.SpikeTrain([210, 250, 290] * pq.ms,
t_start=200 * pq.ms, t_stop=300 * pq.ms),
neo.SpikeTrain([520, 560] * pq.ms,
t_start=500 * pq.ms, t_stop=600 * pq.ms),
]
dither = 20 * pq.ms
surrogate_trains = surr.trial_shifting(trials, dither=dither,
n_surrogates=5)
for surrogate in surrogate_trains:
for trial_idx, surrogate_trial in enumerate(surrogate):
t_start = trials[trial_idx].t_start
t_stop = trials[trial_idx].t_stop
self.assertEqual(len(surrogate_trial), len(trials[trial_idx]))
self.assertTrue(np.all(surrogate_trial >= t_start))
self.assertTrue(np.all(surrogate_trial < t_stop))


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading