Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 973ec08

Browse files
committed
Support while loops in the JAX dispatcher
1 parent 1e8c7fe commit 973ec08

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,61 @@ def parse_outer_inputs(outer_inputs):
3636
return outer_in
3737

3838
if op.info.as_while:
39-
raise NotImplementedError("While loops are not supported in the JAX backend.")
39+
# The inner function returns a boolean as the last value.
40+
return make_jax_while_fn(scan_inner_fn, parse_outer_inputs, input_taps)
4041
else:
41-
return make_jax_scan_fn(
42-
scan_inner_fn,
43-
parse_outer_inputs,
44-
input_taps,
45-
)
42+
return make_jax_scan_fn(scan_inner_fn, parse_outer_inputs, input_taps)
43+
44+
45+
def make_jax_while_fn(
46+
scan_inner_fn: Callable,
47+
parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
48+
input_taps: Dict,
49+
):
50+
"""Create a `jax.lax.while_loop` function to perform `Scan` computations when it
51+
is used as while loop.
52+
53+
`jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
54+
must return a value of the same type (Pytree structure, shape and dtype of
55+
the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
56+
current value and returns a boolean that indicates whether to keep iterating
57+
or not.
58+
59+
The JAX `while_loop` needs to perform the following operations:
60+
61+
1. Extract the inner-inputs;
62+
2. Build the initial carry value;
63+
3. Inside the loop:
64+
1. `carry` -> inner-inputs;
65+
2. inner-outputs -> `carry`
66+
4. Post-process the `carry` storage and return outputs
67+
"""
68+
69+
def build_while_carry(outer_in):
70+
return outer_in
71+
72+
def while_loop(*outer_inputs):
73+
74+
outer_in = parse_outer_inputs(outer_inputs)
75+
init_carry = build_while_carry(outer_in)
76+
77+
def cond_fn(carry):
78+
# The inner-function of `Scan` returns a boolean as the last
79+
# value. This needs to be included in `carry`.
80+
# TODO: Will it return `False` if the number of steps is exceeded?
81+
return carry["do_continue"]
82+
83+
def body_fn(carry):
84+
return carry
85+
86+
carry = jax.lax.while_loop(body_fn, cond_fn, init_carry)
87+
88+
# Here post-process the result
89+
outer_outputs = carry
90+
91+
return outer_outputs
92+
93+
return while_loop
4694

4795

4896
def make_jax_scan_fn(
@@ -58,7 +106,8 @@ def make_jax_scan_fn(
58106
stacked to the previous outputs. We use this to our advantage to build
59107
`Scan` outputs without having to post-process the storage arrays.
60108
61-
The JAX scan function needs to perform the following operations:
109+
The JAX `scan` function needs to perform the following operations:
110+
62111
1. Extract the inner-inputs;
63112
2. Build the initial `carry` and `sequence` values;
64113
3. Inside the loop:
@@ -265,11 +314,11 @@ def body_fn(carry, x):
265314
)
266315

267316
shared_output = tuple(last_carry["shared"])
268-
results = results + shared_output
317+
outer_outputs = results + shared_output
269318

270-
if len(results) == 1:
271-
return results[0]
319+
if len(outer_outputs) == 1:
320+
return outer_outputs[0]
272321

273-
return results
322+
return outer_outputs
274323

275324
return scan

tests/link/jax/test_scan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aesara.link.jax.linker import JAXLinker
1212
from aesara.scan.basic import scan
1313
from aesara.scan.op import Scan
14+
from aesara.scan.utils import until
1415
from aesara.tensor.math import gammaln, log
1516
from aesara.tensor.random.utils import RandomStream
1617
from aesara.tensor.type import ivector, lscalar, scalar
@@ -24,6 +25,15 @@
2425
jax_mode = Mode(JAXLinker(), opts)
2526

2627

28+
def test_while():
29+
res, updates = scan(
30+
fn=lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
31+
outputs_info=[{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
32+
n_steps=3,
33+
)
34+
jax_fn = function((), res, updates=updates, mode="JAX")
35+
36+
2737
def test_sit_sot():
2838
a_at = at.scalar("a", dtype="floatX")
2939

0 commit comments

Comments
 (0)