Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 4df1d5a

Browse files
committed
Add test with fixed number of steps
1 parent b40dd44 commit 4df1d5a

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/link/jax/test_scan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from packaging.version import parse as version_parse
44

5+
import aesara
56
import aesara.tensor as at
67
from aesara.configdefaults import config
78
from aesara.graph.fg import FunctionGraph
@@ -14,6 +15,28 @@
1415
jax = pytest.importorskip("jax")
1516

1617

18+
def test_jax_fixed_length():
19+
"""
20+
21+
In this test we make sure that we can compile a scan loop
22+
with a fixed number of steps.
23+
24+
"""
25+
A = at.vector("A")
26+
27+
result, _ = aesara.scan(
28+
fn=lambda prior_result, A: prior_result * A,
29+
outputs_info=at.ones_like(A),
30+
non_sequences=A,
31+
n_steps=4,
32+
)
33+
34+
final_result = result
35+
36+
out_fg = FunctionGraph([A], [final_result])
37+
compare_jax_and_py(out_fg, [np.array([2.0])])
38+
39+
1740
@pytest.mark.xfail(
1841
version_parse(jax.__version__) >= version_parse("0.2.12"),
1942
reason="Omnistaging cannot be disabled",

0 commit comments

Comments
 (0)