Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit b09a40e

Browse files
committed
Preprend the initial state to the result of the scan loop
1 parent 1d2cc07 commit b09a40e

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,26 @@ def body_fn(carry, x):
568568
body_fn, init_carry, sequences, length=n_steps
569569
)
570570

571+
# We need to preprend the initial values so the output matches
572+
# the raw `Scan` output.
573+
if len(outer_in["mit_sot"]) > 0:
574+
results = tuple([
575+
jnp.concatenate([init[:-n_steps], out], axis=0)
576+
for init, out in zip(outer_in["mit_sot"], results)
577+
])
578+
# TODO: HERE IS THE REASON WHY test_scan_multiple_none_output
579+
elif len(outer_in["sit_sot"]) > 0:
580+
results = tuple([
581+
jnp.concatenate([init[:-n_steps], out], axis=0)
582+
for init, out in zip(outer_in["sit_sot"], results)
583+
])
584+
585+
breakpoint()
586+
571587
shared_output = tuple(last_carry["shared"])
572588
outer_outputs = results + shared_output
573589

590+
574591
if len(outer_outputs) == 1:
575592
return outer_outputs[0]
576593

0 commit comments

Comments
 (0)