Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 2fbcea5

Browse files
committed
Add test for Scan dispatcher with a RandomVariable
1 parent dd6c260 commit 2fbcea5

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

tests/link/jax/test_scan.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@ def test_sit_sot():
3838
assert np.allclose(fn(1.0), jax_fn(1.0))
3939

4040

41-
@pytest.mark.xfail(
42-
reason="Returns correct results but raises exception due to stucture of shared variable."
43-
)
4441
def test_nit_sot_shared():
4542
res, updates = scan(
4643
fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
@@ -50,9 +47,12 @@ def test_nit_sot_shared():
5047
)
5148

5249
jax_fn = function((), res, updates=updates, mode="JAX")
53-
print(jax_fn())
50+
res_jax = jax_fn()
5451
fn = function((), res, updates=updates)
55-
print(fn())
52+
res = fn()
53+
54+
assert res_jax.shape == res.shape
55+
assert not np.all(res_jax == res_jax[0])
5656

5757

5858
def test_mit_sot():
@@ -80,8 +80,6 @@ def test_mit_sot_2():
8080
)
8181
jax_fn = function((), res, updates=updates, mode="JAX")
8282
fn = function((), res, updates=updates)
83-
print(jax_fn())
84-
print(fn())
8583
assert np.allclose(fn(), jax_fn())
8684

8785

0 commit comments

Comments
 (0)