|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +"""Regression tests for ``brainpy/math/pre_syn_post.py``. |
| 3 | +
|
| 4 | +Targets the event-driven CSR routing of ``pre2post_event_sum`` (which delegates |
| 5 | +to ``event.csrmv(transpose=True)``) and the empty-group / structural guards of |
| 6 | +``syn2post_mean`` and ``syn2post_softmax``. |
| 7 | +""" |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
| 11 | +import brainpy.math as bm |
| 12 | +from brainpy.math.pre_syn_post import ( |
| 13 | + pre2post_event_sum, |
| 14 | + syn2post_sum, |
| 15 | + syn2post_mean, |
| 16 | + syn2post_softmax, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +# pre_num=3, post_num=4 CSR: pre0 -> {1,3}, pre1 -> {0}, pre2 -> {1,3} |
| 21 | +_INDICES = np.array([1, 3, 0, 1, 3]) |
| 22 | +_INDPTR = np.array([0, 2, 3, 5]) |
| 23 | +_POST_NUM = 4 |
| 24 | + |
| 25 | + |
| 26 | +def test_pre2post_event_sum_scalar_value(): |
| 27 | + events = np.array([True, False, True]) # pre 0 and 2 fire |
| 28 | + out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.)) |
| 29 | + np.testing.assert_array_equal(out, [0., 2., 0., 2.]) |
| 30 | + |
| 31 | + |
| 32 | +def test_pre2post_event_sum_vector_value(): |
| 33 | + events = np.array([True, False, True]) |
| 34 | + vals = np.array([10., 20., 30., 40., 50.]) |
| 35 | + out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=vals)) |
| 36 | + # pre0: post1+=10, post3+=20; pre2: post1+=40, post3+=50 |
| 37 | + np.testing.assert_array_equal(out, [0., 50., 0., 70.]) |
| 38 | + |
| 39 | + |
| 40 | +def test_pre2post_event_sum_matches_dense_transpose(): |
| 41 | + # equivalent dense Aᵀ @ events, with A (pre_num x post_num) of all-ones weights |
| 42 | + events = np.array([True, True, False]) |
| 43 | + A = np.zeros((3, _POST_NUM), dtype=np.float32) |
| 44 | + for pre in range(3): |
| 45 | + for j in range(_INDPTR[pre], _INDPTR[pre + 1]): |
| 46 | + A[pre, _INDICES[j]] = 1.0 |
| 47 | + out = np.asarray(pre2post_event_sum(events, (_INDICES, _INDPTR), _POST_NUM, values=1.)) |
| 48 | + np.testing.assert_allclose(out, A.T @ events.astype(np.float32), rtol=1e-5, atol=1e-5) |
| 49 | + |
| 50 | + |
| 51 | +def test_syn2post_sum_matches_reference(): |
| 52 | + syn = np.array([1., 2., 3., 4.]) |
| 53 | + post_ids = np.array([0, 0, 2, 2]) |
| 54 | + out = np.asarray(syn2post_sum(syn, post_ids, 3)) |
| 55 | + np.testing.assert_array_equal(out, [3., 0., 7.]) |
| 56 | + |
| 57 | + |
| 58 | +def test_syn2post_mean_empty_group_is_zero_not_nan(): |
| 59 | + syn = np.array([2., 4., 6.]) |
| 60 | + post_ids = np.array([0, 0, 2]) # group 1 is empty |
| 61 | + out = np.asarray(syn2post_mean(syn, post_ids, 3)) |
| 62 | + assert not np.any(np.isnan(out)) |
| 63 | + np.testing.assert_allclose(out, [3., 0., 6.], rtol=1e-6, atol=1e-6) |
| 64 | + |
| 65 | + |
| 66 | +def test_syn2post_softmax_normalizes_per_group(): |
| 67 | + syn = np.array([1., 2., 3., 4.]) |
| 68 | + post_ids = np.array([0, 0, 1, 1]) |
| 69 | + out = np.asarray(syn2post_softmax(syn, post_ids, 2)) |
| 70 | + # within each post group the softmax weights sum to 1 |
| 71 | + np.testing.assert_allclose(out[:2].sum(), 1.0, rtol=1e-5, atol=1e-5) |
| 72 | + np.testing.assert_allclose(out[2:].sum(), 1.0, rtol=1e-5, atol=1e-5) |
| 73 | + # values match a manual softmax of [1,2] and [3,4] |
| 74 | + s01 = np.exp([1., 2.] - np.max([1., 2.])); s01 /= s01.sum() |
| 75 | + np.testing.assert_allclose(out[:2], s01, rtol=1e-5, atol=1e-5) |
0 commit comments