Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 558ade1

Browse files
committed
Add test for scan with multiple None outputs
1 parent 6204b04 commit 558ade1

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/link/jax/test_scan.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,29 @@ def input_step_fn(y_tm1, y_tm3, a):
356356

357357
test_input_vals = [np.array(10.0).astype(config.floatX)]
358358
compare_jax_and_py(out_fg, test_input_vals)
359+
360+
361+
def test_scan_multiple_none_output():
362+
A = at.dvector("A")
363+
364+
def power_step(prior_result, x):
365+
return prior_result * x, prior_result * x * x, prior_result * x * x * x
366+
367+
result, _ = scan(
368+
power_step,
369+
non_sequences=[A],
370+
outputs_info=[at.ones_like(A), None, None],
371+
n_steps=3,
372+
)
373+
374+
FunctionGraph([A], result)
375+
test_input_vals = (np.array([1.0, 2.0]),)
376+
377+
jax_fn = function((A,), result, mode="JAX")
378+
jax_res = jax_fn(*test_input_vals)
379+
380+
fn = function((A,), result)
381+
res = fn(*test_input_vals)
382+
383+
for output_jax, output in zip(jax_res, res):
384+
assert np.allclose(jax_res, res)

0 commit comments

Comments
 (0)