Skip to content

Commit d82ae2f

Browse files
committed
fix(dyn/rates): RNN/GRU/LSTM cell reset_state crash; ThresholdLinearModel noise dt-scaling
- RNNCell/GRUCell/LSTMCell.reset_state() crashed (ValueError) in default bp.reset_state usage; build state via variable() not parameter((None,...)) (High)
- ThresholdLinearModel scaled its Euler-Maruyama noise by dt instead of sqrt(dt), making the noise intensity dt-dependent (Medium)

Findings recorded in docs/issues-found-20260619-dyn-rates-base.md
1 parent 5d06443 commit d82ae2f

5 files changed

Lines changed: 232 additions & 13 deletions

File tree

brainpy/dyn/rates/populations.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,23 +1043,23 @@ def update(self, inp_e=None, inp_i=None):
10431043
input_e = inp_e if (inp_e is not None) else 0.
10441044
input_i = inp_i if (inp_i is not None) else 0.
10451045

1046-
de = -self.e + self.beta_e * bm.maximum(input_e, 0.)
1046+
de = (-self.e + self.beta_e * bm.maximum(input_e, 0.)) / self.tau_e
1047+
de = self.e + de * dt
10471048
with jax.ensure_compile_time_eval():
10481049
has_noise = bm.any(self.noise_e != 0.)
1049-
10501050
if has_noise:
1051-
de += bm.random.randn(*self.varshape) * self.noise_e
1052-
de = de / self.tau_e
1053-
self.e.value = bm.maximum(self.e + de * dt, 0.)
1051+
# Euler-Maruyama: the stochastic term scales as sqrt(dt), not dt, so the
1052+
# noise intensity is independent of the integration step (P10-M1).
1053+
de += self.noise_e / self.tau_e * bm.sqrt(dt) * bm.random.randn(*self.varshape)
1054+
self.e.value = bm.maximum(de, 0.)
10541055

1055-
di = -self.i + self.beta_i * bm.maximum(input_i, 0.)
1056+
di = (-self.i + self.beta_i * bm.maximum(input_i, 0.)) / self.tau_i
1057+
di = self.i + di * dt
10561058
with jax.ensure_compile_time_eval():
10571059
has_noise = bm.any(self.noise_i != 0.)
1058-
10591060
if has_noise:
1060-
di += bm.random.randn(*self.varshape) * self.noise_i
1061-
di = di / self.tau_i
1062-
self.i.value = bm.maximum(self.i + di * dt, 0.)
1061+
di += self.noise_i / self.tau_i * bm.sqrt(dt) * bm.random.randn(*self.varshape)
1062+
self.i.value = bm.maximum(di, 0.)
10631063
return self.e.value
10641064

10651065
def clear_input(self):

brainpy/dyn/rates/rates_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
# ==============================================================================
1616
from unittest import TestCase
1717

18+
import jax.numpy as jnp
19+
import numpy as np
1820
from absl.testing import parameterized
1921

2022
import brainpy as bp
2123
import brainpy.math as bm
24+
from brainpy.context import share
2225
from brainpy.dyn.rates import populations
2326

2427

@@ -54,6 +57,43 @@ def test_tlm(self):
5457
self.assertTrue(tlm.tau_e is not None)
5558

5659

60+
class TestThresholdLinearModelNoise(TestCase):
61+
"""P10-M1: noise must follow Euler-Maruyama ``sqrt(dt)`` scaling."""
62+
63+
@staticmethod
64+
def _noise_increment_std(dt):
65+
# Drive a fresh model with no drift (beta_e=0, tau_e=1) from e=0 so that one
66+
# step gives e = max(noise_e/tau_e * sqrt(dt) * randn, 0). Measure the std of
67+
# the (clamped) increment; the positive-half std is proportional to the
68+
# increment std, so its ratio across dt isolates the dt scaling.
69+
bm.random.seed(0)
70+
bm.set_dt(dt)
71+
m = bp.rates.ThresholdLinearModel(20000, noise_e=1.0, beta_e=0.0, tau_e=1.0)
72+
m.reset_state()
73+
share.save(t=0.0, dt=dt, i=0)
74+
out = np.asarray(m.update(inp_e=0.0))
75+
pos = out[out > 0]
76+
return float(pos.std())
77+
78+
def test_threshold_linear_model_noise_scales_as_sqrt_dt(self):
79+
s_small = self._noise_increment_std(0.01)
80+
s_large = self._noise_increment_std(0.1)
81+
ratio = s_large / s_small
82+
# sqrt(dt): ratio ~ sqrt(0.1/0.01) = sqrt(10) ~ 3.162.
83+
# The buggy dt scaling gives ratio ~ 10.
84+
self.assertAlmostEqual(ratio, np.sqrt(10.0), delta=0.2)
85+
86+
def test_threshold_linear_model_noise_finite(self):
87+
bm.random.seed(0)
88+
bm.set_dt(0.1)
89+
m = bp.rates.ThresholdLinearModel(8, noise_e=1.0, noise_i=0.5)
90+
m.reset_state()
91+
share.save(t=0.0, dt=0.1, i=0)
92+
out = jnp.asarray(m.update(inp_e=1.0, inp_i=1.0))
93+
self.assertEqual(out.shape, (8,))
94+
self.assertTrue(bool(jnp.isfinite(out).all()))
95+
96+
5797
class TestPopulation(parameterized.TestCase):
5898
@parameterized.named_parameters(
5999
{'testcase_name': f'noise_of_{name}', 'neuron': name}

brainpy/dyn/rates/rnncells.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
self.state[:] = self.state2train
125125

126126
def reset_state(self, batch_or_mode=None, **kwargs):
127-
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out,), allow_none=False)
127+
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out)
128128
if self.train_state:
129129
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
130130
self.state[:] = self.state2train
@@ -236,7 +236,7 @@ def __init__(
236236
self.state[:] = self.state2train
237237

238238
def reset_state(self, batch_or_mode=None, **kwargs):
239-
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out), allow_none=False)
239+
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out)
240240
if self.train_state:
241241
self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False)
242242
self.state[:] = self.state2train
@@ -372,7 +372,7 @@ def __init__(
372372
self.state[:] = self.state2train
373373

374374
def reset_state(self, batch_or_mode=None, **kwargs):
375-
self.state.value = parameter(self._state_initializer, (batch_or_mode, self.num_out * 2), allow_none=False)
375+
self.state.value = variable(self._state_initializer, batch_or_mode, self.num_out * 2)
376376
if self.train_state:
377377
self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False)
378378
self.state[:] = self.state2train

brainpy/dyn/rates/rnncells_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,5 +170,38 @@ def test_Conv3dLSTMCell_NonBatching(self):
170170
output = layer(input)
171171

172172

173+
class Test_Rnncells_reset_state(parameterized.TestCase):
174+
"""Regression tests for P10-H1: ``reset_state(None)`` must not crash."""
175+
176+
@parameterized.product(cls=['RNNCell', 'GRUCell', 'LSTMCell'])
177+
def test_rnn_cells_reset_state_none_unbatched(self, cls):
178+
# P10-H1: explicit reset_state(None) on a non-batching cell used to raise
179+
# ``ValueError: Do not support type <class 'NoneType'>: None`` because the
180+
# state shape was built as ``(None, num_out)``.
181+
bm.random.seed()
182+
cell = getattr(bp.dyn, cls)(num_in=3, num_out=4)
183+
cell.reset_state(None)
184+
expected = (8,) if cls == 'LSTMCell' else (4,)
185+
self.assertTupleEqual(tuple(cell.state.shape), expected)
186+
187+
@parameterized.product(cls=['RNNCell', 'GRUCell', 'LSTMCell'])
188+
def test_rnn_cells_reset_state_via_bp_reset_state(self, cls):
189+
# P10-H1: the public ``bp.reset_state`` path passes ``batch_or_mode=None``.
190+
bm.random.seed()
191+
cell = getattr(bp.dyn, cls)(num_in=3, num_out=4)
192+
bp.reset_state(cell)
193+
expected = (8,) if cls == 'LSTMCell' else (4,)
194+
self.assertTupleEqual(tuple(cell.state.shape), expected)
195+
196+
@parameterized.product(cls=['RNNCell', 'GRUCell', 'LSTMCell'])
197+
def test_rnn_cells_reset_state_int_batch(self, cls):
198+
# The int-batch path must still produce a leading batch axis.
199+
bm.random.seed()
200+
cell = getattr(bp.dyn, cls)(num_in=3, num_out=4, mode=bm.batching_mode)
201+
cell.reset_state(2)
202+
expected = (2, 8) if cls == 'LSTMCell' else (2, 4)
203+
self.assertTupleEqual(tuple(cell.state.shape), expected)
204+
205+
173206
if __name__ == '__main__':
174207
absltest.main()
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Audit — dyn/rates + dyn/outs + dyn/others + dyn base/utils (P10)
2+
3+
Date: 2026-06-19
4+
Branch: `fix/audit-20260619-dyn-rates-base`
5+
Scope: `brainpy/dyn/rates/{nvar,populations,reservoir,rnncells}.py`,
6+
`brainpy/dyn/outs/{base,outputs}.py`, `brainpy/dyn/others/{common,input,noise}.py`,
7+
`brainpy/dyn/{base,utils,_docs}.py` (+ co-located `*_test.py`).
8+
9+
## Summary
10+
11+
A prior audit pass (2026-06-18, see `dev/issues-found-20260618.md`) already fixed
12+
the bulk of the verified Critical/High bugs in this slice and left regression tests
13+
in `brainpy/dyn/rates/dyn_rates_dynold_fixes_test.py`. Those fixes are present and
14+
green in the current tree:
15+
16+
- C-15 `ThresholdLinearModel` `randn(*shape)` — fixed (`populations.py:1051,1060`).
17+
- C-16 `StuartLandauOscillator.dy` `+w*x` rotational coupling — fixed (`populations.py:721`).
18+
- H-36 `LSTMCell` `h`/`c` setters slice the last axis — fixed (`rnncells.py:401,412`).
19+
- H-37 `Reservoir` recurrent noise `uniform(-1, 1)` — fixed (`reservoir.py:228`).
20+
- H-38 `Reservoir` bias added in `update()` — fixed (`reservoir.py:223-224`).
21+
22+
This pass performs a fresh review and fixes the still-present
23+
correctness/robustness issues (M-20, M-21), re-examines the previously-recorded
24+
M-18/M-19 (both turn out NOT to be bugs in the current brainstate 0.5 stack), and
25+
records remaining Low items.
26+
27+
---
28+
29+
### P10-H1 — `RNNCell`/`GRUCell`/`LSTMCell.reset_state()` crashes in default usage [High]
30+
- File: `brainpy/dyn/rates/rnncells.py:127, 239, 375`
31+
- Category: edge/error, api-drift
32+
- What: `reset_state` builds the state via
33+
`parameter(self._state_initializer, (batch_or_mode, self.num_out), allow_none=False)`.
34+
With the default `batch_or_mode=None` (the value supplied by `bp.reset_state(node)`
35+
and by a bare `node.reset_state()` in non-batching mode) the shape tuple becomes
36+
`(None, num_out)`, and `tools.size2num(None)` raises
37+
`ValueError: Do not support type <class 'NoneType'>: None`.
38+
- Why it's a bug: `bp.reset_state(net)` / `node.reset_state()` is the standard reset
39+
API. Any network containing an unbatched `RNNCell`/`GRUCell`/`LSTMCell` crashes on
40+
reset. `__init__` already builds the state correctly with
41+
`variable(jnp.zeros, self.mode, self.num_out)`, which handles `None`; only
42+
`reset_state` regressed to `parameter((None, ...))`.
43+
- Repro:
44+
```python
45+
import brainpy as bp
46+
cell = bp.dyn.RNNCell(num_in=3, num_out=4) # NonBatchingMode
47+
bp.reset_state(cell) # ValueError: ... None
48+
```
49+
- Fix: use `variable(self._state_initializer, batch_or_mode, self.num_out)` (matching
50+
`__init__`), which yields `(num_out,)` for `None`, `(B, num_out)` for an int `B`,
51+
and the mode-aware shape for a `Mode`. Applied to all three cells (LSTM uses
52+
`num_out * 2`).
53+
- Tests: `test_rnn_cells_reset_state_none_unbatched`,
54+
`test_rnn_cells_reset_state_via_bp_reset_state`,
55+
`test_rnn_cells_reset_state_int_batch` in `rnncells_test.py`.
56+
- Status: fixed
57+
58+
### P10-M1 — `ThresholdLinearModel` noise scales as `dt`, not `sqrt(dt)` [Medium]
59+
- File: `brainpy/dyn/rates/populations.py:1046-1062`
60+
- Category: numerics
61+
- What: the Euler update folds the Gaussian noise into the drift
62+
(`de += randn(*shape) * noise_e`, then `de = de / tau_e`, then
63+
`e = max(e + de * dt, 0.)`). The noise increment therefore scales linearly with
64+
`dt`. A correct Euler–Maruyama step for `tau de = (-e + beta·[I]_+) dt + noise·dW`
65+
needs the stochastic term to scale as `sqrt(dt)` (`dW ~ sqrt(dt)·N(0,1)`).
66+
- Why it's a bug: the effective noise intensity is `dt`-dependent — halving `dt`
67+
changes the realized per-step noise standard deviation by a factor of 2 instead of `sqrt(2)`, so the
68+
stationary statistics of the simulated rate change with the integration step.
69+
Measured: noise std ratio for `dt=0.1` vs `dt=0.01` is exactly `10` (the `dt`
70+
scaling); the correct Euler–Maruyama ratio is `sqrt(10) ≈ 3.16`.
71+
- Repro: static + measured (see commit message / regression test).
72+
- Fix: move the noise out of the `dt`-scaled drift and add it as a separate
73+
`sqrt(dt)` Euler–Maruyama increment:
74+
`e = max(e + (-e + beta_e·[I]_+)/tau_e · dt + noise_e/tau_e · sqrt(dt)·randn, 0)`.
75+
- Tests: `test_threshold_linear_model_noise_scales_as_sqrt_dt` in `rates_test.py`.
76+
- Status: fixed
77+
78+
### P10-L1 — `FeedbackFHN.reset_state` rebinds `self.input`/`self.input_y` instead of `.value=` [Low] (recorded only — was M-18)
79+
- File: `brainpy/dyn/rates/populations.py:370-371`
80+
- Category: style
81+
- What: `reset_state` does `self.input = variable(...)` / `self.input_y = variable(...)`
82+
whereas the sibling rate models (`FHN`, `QIF`, `StuartLandau`, `WilsonCowan`) use
83+
`self.input.value = ...`.
84+
- Why not a bug here: under brainstate 0.5, assigning a fresh `Variable` to an
85+
attribute that already holds a `State` performs an in-place value/shape update
86+
(object identity is preserved, value resets, batched reshape works), so captured
87+
references and monitors are not broken. Verified empirically.
88+
- Fix: recorded only (consistency nit; out of Critical/High/Medium scope).
89+
- Status: recorded-only
90+
91+
### P10-L2 — Prior audit M-19 (`FeedbackFHN` delay "double-count") is not a bug [Low] (recorded only)
92+
- File: `brainpy/dyn/rates/populations.py:374`
93+
- Category: correctness (false positive in prior audit)
94+
- What: 2026-06-18 M-19 claimed that because `state_delays={'x': self.x_delay}` is
95+
registered with the integrator, querying `self.x_delay(t - self.delay)` in `dx`
96+
double-counts the delay and should be `self.x_delay(t)`.
97+
- Why it's not a bug: `state_delays` only causes the integrator to call
98+
`delay.update(new_x)` after each step (buffer maintenance). `TimeDelay.__call__`
99+
takes an **absolute time** (see its docstring: `delay(-0.5)` → value at t=-0.5), so
100+
`x_delay(t - delay)` is the correct way to read `x(t - delay)`. Querying
101+
`x_delay(t)` would return the *current* value (no delay) and destroy the feedback.
102+
Verified empirically: the query returns the historical value ~`delay` ms in the
103+
past, not the current value.
104+
- Fix: recorded only — leave `x_delay(t - self.delay)` as-is. Changing it (per the
105+
earlier audit) would introduce a regression.
106+
- Status: recorded-only
107+
108+
### P10-L3 — `OutputGroup.reset_state` signature uses `batch_size` not `batch_or_mode` [Low] (recorded only)
109+
- File: `brainpy/dyn/others/input.py:102`
110+
- Category: style
111+
- What: `OutputGroup.reset_state(self, batch_size=None, ...)` while the rest of the
112+
module (`InputGroup`, `SpikeTimeGroup`, `PoissonGroup`) uses `batch_or_mode`. The
113+
body is a no-op `pass`, so callers passing positionally still work; no functional
114+
impact.
115+
- Fix: recorded only.
116+
- Status: recorded-only
117+
118+
### P10-L4 — NumPy-doc nonconformance across rates/outs docstrings [Low] (recorded only)
119+
- File: `brainpy/dyn/rates/populations.py` (and `nvar.py`, `reservoir.py`,
120+
`outs/outputs.py`), e.g. `Parameters::`, `References::`, `See Also::` literal-block
121+
markers and bare `Reference` headings.
122+
- Category: style
123+
- What: CLAUDE.md mandates underlined NumPy-doc sections; these files use the legacy
124+
`Section::` literal-block form (matching the rest of the repo, also flagged as L-14
125+
in the 2026-06-18 audit).
126+
- Fix: recorded only (repo-wide cosmetic; out of scope).
127+
- Status: recorded-only
128+
129+
---
130+
131+
## Verified-correct (checked, no change)
132+
133+
- `NVAR` feature construction: stride/`select_ids` picks exactly `delay` time points;
134+
monomial `comb_ids` and constant/linear concatenation correct (matches 2026-06-18
135+
Appendix B).
136+
- `Reservoir` spectral-radius rescaling (`Wrec *= spectral_radius / current_sr`) is
137+
applied after connectivity masking and before sparse reduction — correct ordering.
138+
- `QIF` / `FHN` / `WilsonCowan` ODE right-hand sides match their docstrings.
139+
- `MgBlock` magnesium curve `1/(1 + [Mg]/β·exp(α(V_off - V)))` matches the documented
140+
`g_inf`; `COBA`/`CUBA` outputs correct.
141+
- `OUProcess` uses `sdeint` with constant diffusion `g = sigma` → correct `sqrt(dt)`
142+
scaling; `reset_state` initializes `x` at `mean`.
143+
- `PoissonGroup` spike probability `freqs · dt / 1000` (Hz·ms) correct.
144+
- `LSTMCell`/`GRUCell` gate equations match docstrings (forget-gate `+1` bias; GRU
145+
reset/update split) — GRU confirmed correct by 2026-06-18 Appendix B.
146+
- `get_spk_type` mode → dtype mapping correct.

0 commit comments

Comments
 (0)