Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 3de5fb3

Browse files
committed
Support while loops in the JAX dispatcher
1 parent 1e8c7fe commit 3de5fb3

File tree

2 files changed

+261
-33
lines changed

2 files changed

+261
-33
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 230 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,232 @@ def parse_outer_inputs(outer_inputs):
3535

3636
return outer_in
3737

38+
breakpoint()
3839
if op.info.as_while:
39-
raise NotImplementedError("While loops are not supported in the JAX backend.")
40+
# The inner function returns a boolean as the last value.
41+
return make_jax_while_fn(scan_inner_fn, parse_outer_inputs, input_taps)
4042
else:
41-
return make_jax_scan_fn(
42-
scan_inner_fn,
43-
parse_outer_inputs,
44-
input_taps,
45-
)
43+
return make_jax_scan_fn(scan_inner_fn, parse_outer_inputs, input_taps)
44+
45+
46+
def make_jax_while_fn(
47+
scan_inner_fn: Callable,
48+
parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
49+
input_taps: Dict,
50+
):
51+
"""Create a `jax.lax.while_loop` function to perform `Scan` computations when it
52+
is used as while loop.
53+
54+
`jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
55+
must return a value of the same type (Pytree structure, shape and dtype of
56+
the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
57+
current value and returns a boolean that indicates whether to keep iterating
58+
or not.
59+
60+
The JAX `while_loop` needs to perform the following operations:
61+
62+
1. Extract the inner-inputs;
63+
2. Build the initial carry value;
64+
3. Inside the loop:
65+
1. `carry` -> inner-inputs;
66+
2. inner-outputs -> `carry`
67+
4. Post-process the `carry` storage and return outputs
68+
"""
69+
70+
def build_while_carry(outer_in):
71+
"""Build the inputs to `jax.lax.scan` from the outer-inputs."""
72+
init_carry = {
73+
name: outer_in[name]
74+
for name in ["mit_sot", "sit_sot", "nit_sot", "shared", "sequences", "non_sequences"]
75+
}
76+
init_carry["step"] = 0
77+
return init_carry
78+
79+
def build_inner_outputs_map(outer_in):
80+
"""Map the inner-output variables to their position in the tuple returned by the inner function.
81+
82+
TODO: Copied from the scan builder
83+
84+
Inner-outputs are ordered as follow:
85+
- mit-mot-outputs
86+
- mit-sot-outputs
87+
- sit-sot-outputs
88+
- nit-sots (no carry)
89+
- shared-outputs
90+
[+ while-condition]
91+
92+
"""
93+
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
94+
95+
offset = 0
96+
inner_output_idx = defaultdict(list)
97+
for name in inner_outputs_names:
98+
num_outputs = len(outer_in[name])
99+
for i in range(num_outputs):
100+
inner_output_idx[name].append(offset + i)
101+
offset += num_outputs
102+
103+
return inner_output_idx
104+
105+
def from_carry_storage(carry, step, input_taps):
106+
"""Fetch the inner inputs from the values stored in the carry array.
107+
108+
`Scan` passes storage arrays as inputs, which are then read from and
109+
updated in the loop body. At each step we need to read from this array
110+
the inputs that will be passed to the inner function.
111+
112+
This mechanism is necessary because we handle multiple-input taps within
113+
the `scan` instead of letting users manage the memory in the use cases
114+
where this is necessary.
115+
116+
TODO: Copied from the scan builder
117+
118+
"""
119+
120+
def fetch(carry, step, offset):
121+
return carry[step + offset]
122+
123+
inner_inputs = []
124+
for taps, carry_element in zip(input_taps, carry):
125+
storage_size = -min(taps)
126+
offsets = [storage_size + tap for tap in taps]
127+
inner_inputs.append(
128+
[fetch(carry_element, step, offset) for offset in offsets]
129+
)
130+
131+
return sum(inner_inputs, [])
132+
133+
def to_carry_storage(inner_outputs, carry, step, input_taps):
134+
"""Create the new carry array from the inner output
135+
136+
`Scan` passes storage arrays as inputs, which are then read from and
137+
updated in the loop body. At each step we need to update this array
138+
with the outputs of the inner function
139+
140+
TODO: Copied from the scan builder
141+
142+
"""
143+
new_carry_element = []
144+
for taps, carry_element, output in zip(input_taps, carry, inner_outputs):
145+
new_carry_element.append(
146+
[carry_element.at[step - tap].set(output) for tap in taps]
147+
)
148+
149+
return sum(new_carry_element, [])
150+
151+
def while_loop(*outer_inputs):
152+
153+
outer_in = parse_outer_inputs(outer_inputs)
154+
init_carry = build_while_carry(outer_in)
155+
inner_output_idx = build_inner_outputs_map(outer_in)
156+
157+
def inner_inputs_from_carry(carry):
158+
"""Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.
159+
160+
Inner-inputs are ordered as follows:
161+
- sequences
162+
- mit-mot inputs
163+
- mit-sot inputs
164+
- sit-sot inputs
165+
- shared-inputs
166+
- non-sequences
167+
168+
"""
169+
current_step = carry["step"]
170+
171+
inner_in_seqs = carry["sequences"][current_step]
172+
inner_in_mit_sot = from_carry_storage(
173+
carry["mit_sot"], current_step, input_taps["mit_sot"]
174+
)
175+
inner_in_sit_sot = from_carry_storage(
176+
carry["sit_sot"], current_step, input_taps["sit_sot"]
177+
)
178+
inner_in_shared = carry.get("shared", [])
179+
inner_in_non_sequences = carry.get("non_sequences", [])
180+
181+
return sum(
182+
[
183+
inner_in_seqs,
184+
inner_in_mit_sot,
185+
inner_in_sit_sot,
186+
inner_in_shared,
187+
inner_in_non_sequences,
188+
],
189+
[],
190+
)
191+
192+
def carry_from_inner_outputs(inner_outputs):
193+
step = carry["step"]
194+
new_carry = {
195+
"mit_sot": [],
196+
"sit_sot": [],
197+
"nit-sot": [],
198+
"shared": [],
199+
"step": step + 1,
200+
"sequences": carry["sequences"],
201+
"non_sequences": carry["non_sequences"],
202+
}
203+
204+
if "shared" in inner_output_idx:
205+
shared_inner_outputs = [
206+
inner_outputs[idx] for idx in inner_output_idx["shared"]
207+
]
208+
new_carry["shared"] = shared_inner_outputs
209+
210+
if "mit_sot" in inner_output_idx:
211+
mit_sot_inner_outputs = [
212+
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
213+
]
214+
new_carry["mit_sot"] = to_carry_storage(
215+
mit_sot_inner_outputs, carry["mit_sot"], step, input_taps["mit_sot"]
216+
)
217+
218+
if "sit_sot" in inner_output_idx:
219+
sit_sot_inner_outputs = [
220+
inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
221+
]
222+
new_carry["sit_sot"] = to_carry_storage(
223+
sit_sot_inner_outputs, carry["sit_sot"], step, input_taps["sit_sot"]
224+
)
225+
if "nit_sot" in inner_output_idx:
226+
nit_sot_inner_outputs = [
227+
inner_outputs[idx] for idx in inner_output_idx["nit_sot"]
228+
]
229+
new_carry["nit_sot"] = to_carry_storage(
230+
nit_sot_inner_outputs, carry["nit_sot"], step, input_taps["nit_sot"]
231+
)
232+
233+
return new_carry
234+
235+
def cond_fn(carry):
236+
# The inner-function of `Scan` returns a boolean as the last
237+
# value. This needs to be included in `carry`.
238+
# TODO: Will it return `False` if the number of steps is exceeded?
239+
return carry["do_continue"]
240+
241+
def body_fn(carry):
242+
inner_inputs = inner_inputs_from_carry(carry)
243+
inner_outputs = scan_inner_fn(*inner_inputs)
244+
new_carry = carry_from_inner_outputs(inner_outputs)
245+
return new_carry
246+
247+
# TODO
248+
# The `Scan` implementation in the C backend will execute the
249+
# function once before checking the termination condition, while
250+
# `jax.lax.while_loop` checks the condition first. We thus need to call
251+
# `body_fn` once before calling `jax.lax.while_loop`. This allows us,
252+
# along with `n_steps`, to build the storage array for the `nit-sot`s
253+
# since there is no way to know their shape and dtype before executing
254+
# the function.
255+
carry = body_fn(init_carry)
256+
carry = jax.lax.while_loop(body_fn, cond_fn, carry)
257+
258+
# TODO: Post-process the storage arrays
259+
outer_outputs = carry
260+
261+
return outer_outputs
262+
263+
return while_loop
46264

47265

48266
def make_jax_scan_fn(
@@ -58,7 +276,8 @@ def make_jax_scan_fn(
58276
stacked to the previous outputs. We use this to our advantage to build
59277
`Scan` outputs without having to post-process the storage arrays.
60278
61-
The JAX scan function needs to perform the following operations:
279+
The JAX `scan` function needs to perform the following operations:
280+
62281
1. Extract the inner-inputs;
63282
2. Build the initial `carry` and `sequence` values;
64283
3. Inside the loop:
@@ -151,7 +370,6 @@ def scan(*outer_inputs):
151370
outer_in = parse_outer_inputs(outer_inputs)
152371
n_steps, sequences, init_carry = build_jax_scan_inputs(outer_in)
153372
inner_output_idx = build_inner_outputs_map(outer_in)
154-
155373
def scan_inner_in_args(carry, x):
156374
"""Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
157375
@@ -265,11 +483,11 @@ def body_fn(carry, x):
265483
)
266484

267485
shared_output = tuple(last_carry["shared"])
268-
results = results + shared_output
486+
outer_outputs = results + shared_output
269487

270-
if len(results) == 1:
271-
return results[0]
488+
if len(outer_outputs) == 1:
489+
return outer_outputs[0]
272490

273-
return results
491+
return outer_outputs
274492

275493
return scan

tests/link/jax/test_scan.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from aesara.link.jax.linker import JAXLinker
1212
from aesara.scan.basic import scan
1313
from aesara.scan.op import Scan
14+
from aesara.scan.utils import until
1415
from aesara.tensor.math import gammaln, log
1516
from aesara.tensor.random.utils import RandomStream
1617
from aesara.tensor.type import ivector, lscalar, scalar
@@ -24,6 +25,15 @@
2425
jax_mode = Mode(JAXLinker(), opts)
2526

2627

28+
def test_while():
29+
res, updates = scan(
30+
fn=lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
31+
outputs_info=[{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
32+
n_steps=3,
33+
)
34+
jax_fn = function((), res, updates=updates, mode="JAX")
35+
36+
2737
def test_sit_sot():
2838
a_at = at.scalar("a", dtype="floatX")
2939

@@ -87,27 +97,27 @@ def test_mit_sot_2():
8797
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
8898
[
8999
# sequences
90-
(
91-
lambda a_t: 2 * a_t,
92-
[at.dvector("a")],
93-
[{}],
94-
[],
95-
None,
96-
[np.arange(10)],
97-
None,
98-
lambda op: op.info.n_seqs > 0,
99-
),
100-
# nit-sot
101-
(
102-
lambda: at.as_tensor(2.0),
103-
[],
104-
[{}],
105-
[],
106-
3,
107-
[],
108-
None,
109-
lambda op: op.info.n_nit_sot > 0,
110-
),
100+
# (
101+
# lambda a_t: 2 * a_t,
102+
# [at.dvector("a")],
103+
# [{}],
104+
# [],
105+
# None,
106+
# [np.arange(10)],
107+
# None,
108+
# lambda op: op.info.n_seqs > 0,
109+
# ),
110+
# # nit-sot
111+
# (
112+
# lambda: at.as_tensor(2.0),
113+
# [],
114+
# [{}],
115+
# [],
116+
# 3,
117+
# [],
118+
# None,
119+
# lambda op: op.info.n_nit_sot > 0,
120+
# ),
111121
# nit-sot, non_seq
112122
(
113123
lambda c: at.as_tensor(2.0) * c,

0 commit comments

Comments
 (0)