Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit a1b7b5c

Browse files
committed
Enclose non-sequences in the loop body function
1 parent bf97e57 commit a1b7b5c

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def scan(*outer_inputs):
2626
raise NotImplementedError("sit-sot not supported")
2727
if len(outer_in_shared):
2828
raise NotImplementedError("shared variables not supported")
29-
if len(outer_in_non_seqs):
30-
raise NotImplementedError("non sequence are not supported")
3129

3230
# If `output_infos` is empty we need to create an empty initial carry
3331
# value with the output's shape and dtype
@@ -39,6 +37,7 @@ def scan(*outer_inputs):
3937

4038
init_carry = outer_in_sit_sot
4139
sequences = outer_in_seqs
40+
non_sequences = outer_in_non_seqs
4241

4342
def scan_inner_in_args(carry, x):
4443
"""Create an inner-input expression.
@@ -59,7 +58,7 @@ def scan_inner_in_args(carry, x):
5958
else:
6059
inner_in_sit_sot = carry
6160

62-
return sum([inner_in_seqs, inner_in_sit_sot], [])
61+
return sum([inner_in_seqs, inner_in_sit_sot, non_sequences], [])
6362

6463
def scan_new_carry(inner_outputs):
6564
"""Create a new carry expression

tests/link/jax/test_scan.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@
4848
None,
4949
lambda op: op.info.n_nit_sot > 0,
5050
),
51+
# nit-sot, non_seq
52+
(
53+
lambda c: at.as_tensor(2.0) * c,
54+
[],
55+
[{}],
56+
[at.dscalar("c")],
57+
3,
58+
[1.0],
59+
None,
60+
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
61+
),
5162
],
5263
)
5364
def test_xit_xot_types(

0 commit comments

Comments
 (0)