Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,8 @@ def generate_single_fake_waveform(
positive_amplitude=(0.1, 0.25),
smooth_ms=(0.03, 0.07),
spatial_decay=(10.0, 45.0),
spatial_base=(0.1, 1.0),
spatial_power=(2.0, 4.0),
propagation_speed=(250.0, 350.0), # um / ms
ellipse_shrink=(0.4, 1),
ellipse_angle=(0, np.pi * 2),
Expand Down Expand Up @@ -1701,6 +1703,7 @@ def generate_templates(
upsample_factor=None,
unit_params=None,
mode="ellipsoid",
spatial_profile="exponential",
):
"""
Generate some templates from the given channel positions and neuron positions.
Expand Down Expand Up @@ -1741,6 +1744,8 @@ def generate_templates(
* "positive_amplitude": the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1)
* "smooth_ms": the gaussian smooth in ms (default range: (0.03-0.07))
* "spatial_decay": the spatial constant (default range: (20-40))
* "spatial_base": denominator term for power spatial decay decay (default range: 0.1-1.0)
* "spatial_power": exponent for power spatial decay decay (default range: 2-4)
* "propagation_speed": mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)).

Values can be:
Expand All @@ -1751,10 +1756,11 @@ def generate_templates(
Method used to calculate the distance between unit and channel location.
Ellipsoid injects some anisotropy dependent on unit shape, sphere is equivalent
to Euclidean distance.
spatial_profile : "exponential" | "power", default: "exponential"
Spatial footpring decay curve family.

mode : "sphere" | "ellipsoid", default: "ellipsoid"
Mode for how to calculate distances

* "exponential": alpha * exp(-r / spatial_decay)
* "power": alpha / (spatial_base + (r/spatial_decay) ** spatial_power)

Returns
-------
Expand All @@ -1765,7 +1771,6 @@ def generate_templates(

"""
unit_params = unit_params or dict()
rng = np.random.default_rng(seed=seed)

# neuron location must be 3D
assert units_locations.shape[1] == 3
Expand Down Expand Up @@ -1831,7 +1836,12 @@ def generate_templates(
z_angle=params["ellipse_angle"][u],
)

channel_factors = alpha * np.exp(-distances / spatial_decay)
if spatial_profile == "exponential":
channel_factors = alpha * np.exp(-distances / spatial_decay)
elif spatial_profile == "power":
channel_factors = (distances / spatial_decay) ** params["spatial_power"][u]
channel_factors = alpha / (params["spatial_base"][u] + channel_factors)

wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :]

# This mimic a propagation delay for distant channel
Expand Down
20 changes: 20 additions & 0 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,26 @@ def test_generate_templates():
dtype="float32",
unit_params=dict(alpha=np.ones(num_units) * 500.0, smooth_ms=(0.04, 0.05)),
)
assert templates.ndim == 3
assert templates.shape[2] == num_chans
assert templates.shape[0] == num_units

# power case
templates = generate_templates(
channel_locations,
unit_locations,
sampling_frequency,
ms_before,
ms_after,
upsample_factor=None,
seed=42,
dtype="float32",
spatial_profile="power",
unit_params=dict(alpha=np.ones(num_units) * 500.0, smooth_ms=(0.04, 0.05)),
)
assert templates.ndim == 3
assert templates.shape[2] == num_chans
assert templates.shape[0] == num_units

# upsampling case
templates = generate_templates(
Expand Down
Loading