Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit dd6c260

Browse files
committed
Return only the last value of the shared variables
1 parent 971645b commit dd6c260

File tree

1 file changed

+39
-2
lines changed

1 file changed

+39
-2
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,50 @@ def scan_new_carry(carry, inner_outputs):
177177

178178
return new_carry
179179

180+
def scan_new_outputs(inner_outputs):
181+
"""Create a new outer-output value from the inner-output value.
182+
183+
Outer-outputs are ordered as follows:
184+
- mit-mot-outputs
185+
- mit-sot-outputs
186+
- sit-sot-outputs
187+
- nit-sots
188+
- shared-outputs
189+
190+
The shared output corresponds to the last value found in the last
191+
carry value returned by `jax.lax.scan`. It is thus not returned in
192+
the body function.
193+
194+
"""
195+
outer_outputs = []
196+
if "mit_sot" in inner_output_idx:
197+
outer_outputs.append(
198+
[inner_outputs[idx] for idx in inner_output_idx["mit_sot"]]
199+
)
200+
if "sit_sot" in inner_output_idx:
201+
outer_outputs.append(
202+
[inner_outputs[idx] for idx in inner_output_idx["sit_sot"]]
203+
)
204+
if "nit_sot" in inner_output_idx:
205+
outer_outputs.append(
206+
[inner_outputs[idx] for idx in inner_output_idx["nit_sot"]]
207+
)
208+
209+
return tuple(sum(outer_outputs, []))
210+
180211
def body_fn(carry, x):
181212
inner_in_args = scan_inner_in_args(carry, x)
182213
inner_outputs = scan_inner_func(*inner_in_args)
183214
new_carry = scan_new_carry(carry, inner_outputs)
184-
return new_carry, inner_outputs
215+
outer_outputs = scan_new_outputs(inner_outputs)
216+
return new_carry, outer_outputs
217+
218+
last_carry, results = jax.lax.scan(
219+
body_fn, init_carry, sequences, length=n_steps
220+
)
185221

186-
_, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)
222+
shared_output = tuple(last_carry["shared"])
223+
results = results + shared_output
187224

188225
if len(results) == 1:
189226
return results[0]

0 commit comments

Comments
 (0)