Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit e38f768

Browse files
committed
Fix dynamic slicing of arrays
1 parent 886f808 commit e38f768

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ def jax_inner_func(carry, x):
145145
# match the raw `Scan` `Op` output and, thus, work with a downstream
146146
# `Subtensor` `Op` introduced by the `scan` helper function.
147147
def append_scan_out(scan_in_part, scan_out_part):
148-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
148+
start_indices = [0] * scan_in_part.ndim
149+
slice_sizes = list(scan_in_part.shape)
150+
slice_sizes[0] = slice_sizes[0] - n_steps
151+
scan_in_part_sliced = jax.lax.dynamic_slice(
152+
scan_in_part, start_indices, slice_sizes
153+
)
154+
return jnp.concatenate([scan_in_part_sliced, scan_out_part], axis=0)
149155

150156
if scan_args.outer_in_mit_sot:
151157
scan_out_final = [

0 commit comments

Comments
 (0)