Skip to content

Commit dca2e80

Browse files
authored
fix(math/sparse): coo_to_csr bounds check + sparse/event/delay regression suite (#845)
fix(math/sparse): validate pre_ids in coo_to_csr; lock sparse/event/delay correctness with tests - coo_to_csr silently produced a corrupt CSR (indptr[-1] != nse) on out-of-range pre_ids; now raises a clear ValueError (Medium) - Added regression tests comparing CSR/COO/event matvec+matmat and jitconn against dense references (incl. transpose and autodiff), plus TimeDelay / LengthDelay ring-buffer modulo-wrap tests, locking in prior correctness fixes Findings recorded in docs/issues-found-20260619-math-sparse-event.md
1 parent efc7dc9 commit dca2e80

8 files changed

Lines changed: 641 additions & 1 deletion

File tree

brainpy/math/delayvars_coverage_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,40 @@ def test_update_value_none_without_target_raises():
178178
assert ld.delay_target is None
179179
with pytest.raises(ValueError):
180180
ld.update(None)
181+
182+
183+
# ---------------------------------------------------------------------------
184+
# ring-buffer correctness regressions (guards the ``% num_delay_step`` modulo
185+
# in ``TimeDelay._true_fn`` and the rotate index in ``LengthDelay``)
186+
# ---------------------------------------------------------------------------
187+
188+
def test_time_delay_ring_buffer_wraps_modulo():
189+
"""``_true_fn`` must read ``data[(idx + step) % num_delay_step]``.
190+
191+
Without the modulo, when the read index wraps past the end of the buffer JAX
192+
clamps the out-of-bounds index to the last slot and returns a stale value.
193+
We feed a long ramp (many wraps) and check the exact-step (no-interp) reads.
194+
"""
195+
dt = 0.1
196+
delay_len = 1.0 # exact multiple of dt -> exact-step (``_true_fn``) branch
197+
d = TimeDelay(bm.zeros(1), delay_len=delay_len, dt=dt, before_t0=lambda t: t)
198+
# ``num_delay_step == 11``; iterate well past one full wrap of the buffer.
199+
n = 37
200+
for i in range(n):
201+
d.update(bm.asarray([float(i)]))
202+
ct = float(d.current_time[0])
203+
last = n - 1 # the most recently stored ramp value
204+
# delay d_ms -> value stored ``round(d_ms/dt)`` steps before ``last``.
205+
for d_ms in [0.0, 0.1, 0.3, 0.5, 1.0]:
206+
got = float(d(ct - d_ms)[0])
207+
expected = last - round(d_ms / dt)
208+
assert abs(got - expected) < 1e-4, (d_ms, got, expected)
209+
210+
211+
def test_length_delay_ramp_matches_reference():
212+
for method in (ROTATE_UPDATE, CONCAT_UPDATE):
213+
d = LengthDelay(bm.zeros(1), delay_len=5, update_method=method)
214+
for i in range(23): # many wraps for the rotate buffer (len 6)
215+
d.update(bm.asarray([float(i)]))
216+
got = [float(d(k)[0]) for k in range(6)]
217+
assert got == [22 - k for k in range(6)], (method, got)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# -*- coding: utf-8 -*-
2+
"""Regression tests for ``brainpy/math/event/csr_matmat.py``.
3+
4+
Guards the event-driven (binary) CSR matmat, especially the ``transpose=True``
5+
branch (must compute ``Aᵀ @ E``), against a dense numpy reference.
6+
"""
7+
8+
import jax.numpy as jnp
9+
import numpy as np
10+
11+
import brainevent
12+
13+
from brainpy.math.event.csr_matmat import csrmm
14+
15+
16+
_ROWS = np.array([0, 0, 1, 2, 2])
17+
_COLS = np.array([1, 3, 0, 1, 3])
18+
_VALS = np.array([2., 4., 1., 3., 2.])
19+
_SHAPE = (3, 4)
20+
21+
22+
def _dense():
23+
m = np.zeros(_SHAPE, dtype=np.float32)
24+
for v, r, c in zip(_VALS, _ROWS, _COLS):
25+
m[r, c] = v
26+
return m
27+
28+
29+
def _csr():
30+
indptr, indices, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
31+
data = jnp.asarray(_VALS)[np.asarray(order)]
32+
return data, np.asarray(indices), np.asarray(indptr)
33+
34+
35+
def test_event_csrmm_no_transpose_matches_dense():
36+
data, indices, indptr = _csr()
37+
E = np.array([[True, False], [False, True], [True, True], [False, False]])
38+
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=False))
39+
assert out.shape == (3, 2)
40+
np.testing.assert_allclose(out, _dense() @ E.astype(np.float32), rtol=1e-5, atol=1e-5)
41+
42+
43+
def test_event_csrmm_transpose_matches_dense():
44+
# transpose=True must compute Aᵀ @ E (Aᵀ is (4,3), E is (3,2)).
45+
data, indices, indptr = _csr()
46+
E = np.array([[True, False], [False, True], [True, True]])
47+
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=True))
48+
assert out.shape == (4, 2)
49+
np.testing.assert_allclose(out, _dense().T @ E.astype(np.float32), rtol=1e-5, atol=1e-5)
50+
51+
52+
def test_event_csrmm_matches_float_csrmm_with_binary_input():
53+
# An all-True event matrix multiplied by the binary path equals the dense
54+
# product restricted to the selected entries.
55+
data, indices, indptr = _csr()
56+
E = np.array([[True, True], [True, True], [True, True], [True, True]])
57+
out = np.asarray(csrmm(data, indices, indptr, jnp.asarray(E), shape=_SHAPE, transpose=False))
58+
np.testing.assert_allclose(out, _dense() @ E.astype(np.float32), rtol=1e-5, atol=1e-5)

brainpy/math/pre_syn_post_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# -*- coding: utf-8 -*-
2+
"""Regression tests for ``brainpy/math/pre_syn_post.py``.
3+
4+
Targets the event-driven CSR routing of ``pre2post_event_sum`` (which delegates
5+
to ``event.csrmv(transpose=True)``) and the empty-group / structural guards of
6+
``syn2post_mean`` and ``syn2post_softmax``.
7+
"""
8+
9+
import numpy as np
10+
11+
import brainpy.math as bm
12+
from brainpy.math.pre_syn_post import (
13+
pre2post_event_sum,
14+
syn2post_sum,
15+
syn2post_mean,
16+
syn2post_softmax,
17+
)
18+
19+
20+
# pre_num=3, post_num=4 CSR: pre0 -> {1,3}, pre1 -> {0}, pre2 -> {1,3}
21+
_INDICES = np.array([1, 3, 0, 1, 3])
22+
_INDPTR = np.array([0, 2, 3, 5])
23+
_POST_NUM = 4
24+
25+
26+
def test_pre2post_event_sum_scalar_value():
27+
events = np.array([True, False, True]) # pre 0 and 2 fire
28+
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.))
29+
np.testing.assert_array_equal(out, [0., 2., 0., 2.])
30+
31+
32+
def test_pre2post_event_sum_vector_value():
33+
events = np.array([True, False, True])
34+
vals = np.array([10., 20., 30., 40., 50.])
35+
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=vals))
36+
# pre0: post1+=10, post3+=20; pre2: post1+=40, post3+=50
37+
np.testing.assert_array_equal(out, [0., 50., 0., 70.])
38+
39+
40+
def test_pre2post_event_sum_matches_dense_transpose():
41+
# equivalent dense Aᵀ @ events, with A (pre_num x post_num) of all-ones weights
42+
events = np.array([True, True, False])
43+
A = np.zeros((3, _POST_NUM), dtype=np.float32)
44+
for pre in range(3):
45+
for j in range(_INDPTR[pre], _INDPTR[pre + 1]):
46+
A[pre, _INDICES[j]] = 1.0
47+
out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.))
48+
np.testing.assert_allclose(out, A.T @ events.astype(np.float32), rtol=1e-5, atol=1e-5)
49+
50+
51+
def test_syn2post_sum_matches_reference():
52+
syn = np.array([1., 2., 3., 4.])
53+
post_ids = np.array([0, 0, 2, 2])
54+
out = np.asarray(syn2post_sum(syn, post_ids, 3))
55+
np.testing.assert_array_equal(out, [3., 0., 7.])
56+
57+
58+
def test_syn2post_mean_empty_group_is_zero_not_nan():
59+
syn = np.array([2., 4., 6.])
60+
post_ids = np.array([0, 0, 2]) # group 1 is empty
61+
out = np.asarray(syn2post_mean(syn, post_ids, 3))
62+
assert not np.any(np.isnan(out))
63+
np.testing.assert_allclose(out, [3., 0., 6.], rtol=1e-6, atol=1e-6)
64+
65+
66+
def test_syn2post_softmax_normalizes_per_group():
67+
syn = np.array([1., 2., 3., 4.])
68+
post_ids = np.array([0, 0, 1, 1])
69+
out = np.asarray(syn2post_softmax(syn, post_ids, 2))
70+
# within each post group the softmax weights sum to 1
71+
np.testing.assert_allclose(out[:2].sum(), 1.0, rtol=1e-5, atol=1e-5)
72+
np.testing.assert_allclose(out[2:].sum(), 1.0, rtol=1e-5, atol=1e-5)
73+
# values match a manual softmax of [1,2] and [3,4]
74+
s01 = np.exp([1., 2.] - np.max([1., 2.])); s01 /= s01.sum()
75+
np.testing.assert_allclose(out[:2], s01, rtol=1e-5, atol=1e-5)

brainpy/math/sparse/coo_mv_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# -*- coding: utf-8 -*-
2+
"""Regression tests for ``brainpy/math/sparse/coo_mv.py``.
3+
4+
``coomv`` converts COO indices to CSR (``brainevent.coo2csr``) before delegating
5+
to ``brainevent.CSR``. These tests check both orientations and the scalar-weight
6+
broadcast path against a dense numpy reference, with unsorted COO triples (so the
7+
``coo2csr`` permutation of ``data`` is exercised).
8+
"""
9+
10+
import jax
11+
import jax.numpy as jnp
12+
import numpy as np
13+
14+
from brainpy.math.sparse.coo_mv import coomv
15+
16+
17+
# Deliberately UNSORTED COO triples for a 3 x 4 matrix:
18+
# [[0, 2, 0, 4],
19+
# [1, 0, 0, 0],
20+
# [0, 3, 0, 2]]
21+
_ROWS = np.array([2, 0, 1, 0, 2])
22+
_COLS = np.array([1, 1, 0, 3, 3])
23+
_VALS = np.array([3., 2., 1., 4., 2.])
24+
_SHAPE = (3, 4)
25+
26+
27+
def _dense():
28+
m = np.zeros(_SHAPE, dtype=np.float32)
29+
for v, r, c in zip(_VALS, _ROWS, _COLS):
30+
m[r, c] = v
31+
return m
32+
33+
34+
def test_coomv_no_transpose_matches_dense():
35+
v = jnp.arange(4, dtype=jnp.float32)
36+
out = np.asarray(coomv(_VALS, _ROWS, _COLS, v, shape=_SHAPE, transpose=False))
37+
assert out.shape == (3,)
38+
np.testing.assert_allclose(out, _dense() @ np.asarray(v), rtol=1e-5, atol=1e-5)
39+
40+
41+
def test_coomv_transpose_matches_dense():
42+
v = jnp.arange(3, dtype=jnp.float32)
43+
out = np.asarray(coomv(_VALS, _ROWS, _COLS, v, shape=_SHAPE, transpose=True))
44+
assert out.shape == (4,)
45+
np.testing.assert_allclose(out, _dense().T @ np.asarray(v), rtol=1e-5, atol=1e-5)
46+
47+
48+
def test_coomv_scalar_weight_broadcast():
49+
# scalar weight -> every stored entry uses the same value.
50+
v = jnp.arange(4, dtype=jnp.float32)
51+
out = np.asarray(coomv(2.0, _ROWS, _COLS, v, shape=_SHAPE, transpose=False))
52+
ref = np.zeros(_SHAPE, dtype=np.float32)
53+
ref[_ROWS, _COLS] = 2.0
54+
np.testing.assert_allclose(out, ref @ np.asarray(v), rtol=1e-5, atol=1e-5)
55+
56+
57+
def test_coomv_grad_scalar_weight():
58+
v = jnp.arange(4, dtype=jnp.float32)
59+
60+
def f(s):
61+
return coomv(s, _ROWS, _COLS, v, shape=_SHAPE, transpose=False).sum()
62+
63+
g = float(jax.grad(f)(2.0))
64+
# d/ds sum(A(s) @ v) = sum over stored entries of v[col]
65+
np.testing.assert_allclose(g, float(jnp.asarray(v)[_COLS].sum()), rtol=1e-5, atol=1e-5)

brainpy/math/sparse/csr_mm_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
"""Regression tests for ``brainpy/math/sparse/csr_mm.py``.
3+
4+
Guards the ``transpose=True`` branch of :func:`csrmm` (must compute ``Aᵀ @ B``,
5+
not ``B @ A``) against a dense numpy reference, including its autodiff.
6+
"""
7+
8+
import jax
9+
import jax.numpy as jnp
10+
import numpy as np
11+
12+
import brainevent
13+
14+
from brainpy.math.sparse.csr_mm import csrmm
15+
16+
17+
# 3 x 4 sparse matrix:
18+
# [[0, 2, 0, 4],
19+
# [1, 0, 0, 0],
20+
# [0, 3, 0, 2]]
21+
_ROWS = np.array([0, 0, 1, 2, 2])
22+
_COLS = np.array([1, 3, 0, 1, 3])
23+
_VALS = np.array([2., 4., 1., 3., 2.])
24+
_SHAPE = (3, 4)
25+
26+
27+
def _dense():
28+
m = np.zeros(_SHAPE, dtype=np.float32)
29+
for v, r, c in zip(_VALS, _ROWS, _COLS):
30+
m[r, c] = v
31+
return m
32+
33+
34+
def _csr():
35+
indptr, indices, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
36+
data = jnp.asarray(_VALS)[np.asarray(order)]
37+
return data, np.asarray(indices), np.asarray(indptr)
38+
39+
40+
def test_csrmm_no_transpose_matches_dense():
41+
data, indices, indptr = _csr()
42+
B = jnp.arange(4 * 2, dtype=jnp.float32).reshape(4, 2)
43+
out = np.asarray(csrmm(data, indices, indptr, B, shape=_SHAPE, transpose=False))
44+
assert out.shape == (3, 2)
45+
np.testing.assert_allclose(out, _dense() @ np.asarray(B), rtol=1e-5, atol=1e-5)
46+
47+
48+
def test_csrmm_transpose_matches_dense():
49+
# transpose=True must compute Aᵀ @ B, where Aᵀ is (4, 3) and B is (3, 2).
50+
data, indices, indptr = _csr()
51+
B = jnp.arange(3 * 2, dtype=jnp.float32).reshape(3, 2)
52+
out = np.asarray(csrmm(data, indices, indptr, B, shape=_SHAPE, transpose=True))
53+
assert out.shape == (4, 2)
54+
np.testing.assert_allclose(out, _dense().T @ np.asarray(B), rtol=1e-5, atol=1e-5)
55+
56+
57+
def test_csrmm_transpose_grad_matches_dense():
58+
data, indices, indptr = _csr()
59+
B = jnp.arange(3 * 2, dtype=jnp.float32).reshape(3, 2)
60+
61+
def f(d):
62+
return csrmm(d, indices, indptr, B, shape=_SHAPE, transpose=True).sum()
63+
64+
g = np.asarray(jax.grad(f)(_csr()[0]))
65+
# dense reference gradient wrt the stored values
66+
dense_ref = _dense()
67+
68+
def fd(flat):
69+
m = jnp.zeros(_SHAPE, dtype=jnp.float32)
70+
m = m.at[_ROWS, _COLS].set(flat)
71+
return (m.T @ B).sum()
72+
73+
# values in CSR order correspond to coo2csr ``order``
74+
_, _, order = brainevent.coo2csr(_ROWS, _COLS, shape=_SHAPE)
75+
g_ref = np.asarray(jax.grad(fd)(jnp.asarray(_VALS)))[np.asarray(order)]
76+
np.testing.assert_allclose(g, g_ref, rtol=1e-5, atol=1e-5)

brainpy/math/sparse/utils.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,57 @@ def coo_to_csr(
3434
*,
3535
num_row: int
3636
) -> Tuple[jnp.ndarray, jnp.ndarray]:
37-
"""convert pre_ids, post_ids to (indices, indptr)."""
37+
"""Convert COO ``(pre_ids, post_ids)`` connectivity to CSR ``(indices, indptr)``.
38+
39+
Parameters
40+
----------
41+
pre_ids : ndarray
42+
Row (pre-synaptic) index of each non-zero entry. Every value must be in
43+
``[0, num_row)``.
44+
post_ids : ndarray
45+
Column (post-synaptic) index of each non-zero entry, aligned with
46+
``pre_ids``.
47+
num_row : int
48+
Number of rows of the sparse matrix (``shape[0]``).
49+
50+
Returns
51+
-------
52+
indices : ndarray
53+
CSR column indices of shape ``(nse,)``.
54+
indptr : ndarray
55+
CSR row pointers of shape ``(num_row + 1,)`` and dtype ``int32``.
56+
57+
Raises
58+
------
59+
ValueError
60+
If any ``pre_ids`` falls outside ``[0, num_row)``. Such an entry would
61+
otherwise be silently dropped from ``indptr`` (its scatter index is
62+
out-of-bounds), producing a structurally invalid CSR in which
63+
``indptr[-1] != len(indices)``.
64+
65+
Notes
66+
-----
67+
This is an eager preprocessing helper: it relies on ``jnp.unique`` (whose
68+
output size is data-dependent) and therefore cannot be traced under
69+
``jit``/``vmap``.
70+
"""
3871
pre_ids = as_jax(pre_ids)
3972
post_ids = as_jax(post_ids)
4073

74+
# Validate the pre (row) indices eagerly. An out-of-range ``pre_id`` would be
75+
# silently dropped by the out-of-bounds ``.at[].set`` scatter below, yielding
76+
# a corrupt CSR (``indptr[-1] != nse``) instead of an error. ``coo_to_csr``
77+
# already cannot be ``jit``-traced (``jnp.unique``), so this concrete check
78+
# does not regress any JAX transformation behaviour.
79+
if pre_ids.size > 0:
80+
pre_min = int(jnp.min(pre_ids))
81+
pre_max = int(jnp.max(pre_ids))
82+
if pre_min < 0 or pre_max >= num_row:
83+
raise ValueError(
84+
f'"pre_ids" must lie in [0, num_row) = [0, {num_row}), '
85+
f'but got values in [{pre_min}, {pre_max}].'
86+
)
87+
4188
# sorting
4289
sort_ids = jnp.argsort(pre_ids, stable=True)
4390
post_ids = post_ids[sort_ids]

0 commit comments

Comments
 (0)