diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1c9ece728f..719e50c2da 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1658,6 +1658,8 @@ def generate_single_fake_waveform( positive_amplitude=(0.1, 0.25), smooth_ms=(0.03, 0.07), spatial_decay=(10.0, 45.0), + spatial_base=(0.1, 1.0), + spatial_power=(2.0, 4.0), propagation_speed=(250.0, 350.0), # um / ms ellipse_shrink=(0.4, 1), ellipse_angle=(0, np.pi * 2), @@ -1701,6 +1703,7 @@ def generate_templates( upsample_factor=None, unit_params=None, mode="ellipsoid", + spatial_profile="exponential", ): """ Generate some templates from the given channel positions and neuron positions. @@ -1741,6 +1744,8 @@ def generate_templates( * "positive_amplitude": the positive amplitude in a.u. (default range: (0.05-0.15)) (negative is always -1) * "smooth_ms": the gaussian smooth in ms (default range: (0.03-0.07)) * "spatial_decay": the spatial constant (default range: (20-40)) + * "spatial_base": denominator term for power spatial decay decay (default range: 0.1-1.0) + * "spatial_power": exponent for power spatial decay decay (default range: 2-4) * "propagation_speed": mimic a propagation delay with a kind of a "speed" (default range: (250., 350.)). Values can be: @@ -1751,10 +1756,11 @@ def generate_templates( Method used to calculate the distance between unit and channel location. Ellipsoid injects some anisotropy dependent on unit shape, sphere is equivalent to Euclidean distance. + spatial_profile : "exponential" | "power", default: "exponential" + Spatial footpring decay curve family. - mode : "sphere" | "ellipsoid", default: "ellipsoid" - Mode for how to calculate distances - + * "exponential": alpha * exp(-r / spatial_decay) + * "power": alpha / (spatial_base + (r/spatial_decay) ** spatial_power) Returns ------- @@ -1765,7 +1771,6 @@ def generate_templates( """ unit_params = unit_params or dict() - rng = np.random.default_rng(seed=seed) # neuron location must be 3D assert units_locations.shape[1] == 3 @@ -1831,7 +1836,12 @@ def generate_templates( z_angle=params["ellipse_angle"][u], ) - channel_factors = alpha * np.exp(-distances / spatial_decay) + if spatial_profile == "exponential": + channel_factors = alpha * np.exp(-distances / spatial_decay) + elif spatial_profile == "power": + channel_factors = (distances / spatial_decay) ** params["spatial_power"][u] + channel_factors = alpha / (params["spatial_base"][u] + channel_factors) + wfs = wf[:, np.newaxis] * channel_factors[np.newaxis, :] # This mimic a propagation delay for distant channel diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index e0e28d09cd..da82ba9267 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -463,6 +463,26 @@ def test_generate_templates(): dtype="float32", unit_params=dict(alpha=np.ones(num_units) * 500.0, smooth_ms=(0.04, 0.05)), ) + assert templates.ndim == 3 + assert templates.shape[2] == num_chans + assert templates.shape[0] == num_units + + # power case + templates = generate_templates( + channel_locations, + unit_locations, + sampling_frequency, + ms_before, + ms_after, + upsample_factor=None, + seed=42, + dtype="float32", + spatial_profile="power", + unit_params=dict(alpha=np.ones(num_units) * 500.0, smooth_ms=(0.04, 0.05)), + ) + assert templates.ndim == 3 + assert templates.shape[2] == num_chans + assert templates.shape[0] == num_units # upsampling case templates = generate_templates(