Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 2232d76

Browse files
committed
Support while loops in the JAX dispatcher
1 parent b78a011 commit 2232d76

File tree

5 files changed

+351
-34
lines changed

5 files changed

+351
-34
lines changed

aesara/link/jax/dispatch/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def jax_funcify_FunctionGraph(
4646
fgraph_name="jax_funcified_fgraph",
4747
**kwargs,
4848
):
49+
4950
return fgraph_to_python(
5051
fgraph,
5152
jax_funcify,

aesara/link/jax/dispatch/scan.py

Lines changed: 259 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Callable, Dict, List
33

44
import jax
5+
import jax.numpy as jnp
56

67
from aesara.link.jax.dispatch.basic import jax_funcify
78
from aesara.scan.op import Scan
@@ -14,6 +15,7 @@ def jax_funcify_Scan(op, node, **kwargs):
1415
input_taps = {
1516
"mit_sot": op.info.mit_sot_in_slices,
1617
"sit_sot": op.info.sit_sot_in_slices,
18+
"nit_sot": op.info.sit_sot_in_slices,
1719
}
1820

1921
# Outer-inputs are the inputs to the `Scan` apply node, built from the
@@ -36,13 +38,258 @@ def parse_outer_inputs(outer_inputs):
3638
return outer_in
3739

3840
if op.info.as_while:
39-
raise NotImplementedError("While loops are not supported in the JAX backend.")
41+
# We can only compile a `Scan` node that acts as a `while` loop to JAX
42+
# if only the last computed value is ever used in the outer function.
43+
# TODO: Determine if that's the case
44+
# TODO: Rewrite the graph if that's the case
45+
# TODO: Implement a simple `while` loop that returns the last step
46+
return make_jax_while_fn(scan_inner_fn, parse_outer_inputs, input_taps)
4047
else:
41-
return make_jax_scan_fn(
42-
scan_inner_fn,
43-
parse_outer_inputs,
44-
input_taps,
45-
)
48+
return make_jax_scan_fn(scan_inner_fn, parse_outer_inputs, input_taps)
49+
50+
51+
def make_jax_while_fn(
    scan_inner_fn: Callable,
    parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
    input_taps: Dict,
):
    """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
    is used as a while loop.

    `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fn` that
    must return a value of the same type (Pytree structure, shape and dtype of
    the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
    current value and returns a boolean that indicates whether to keep iterating
    or not.

    The JAX `while_loop` needs to perform the following operations:

    1. Extract the inner-inputs;
    2. Build the initial carry value;
    3. Inside the loop:
        1. `carry` -> inner-inputs;
        2. inner-outputs -> `carry`
    4. Post-process the `carry` storage and return outputs

    Parameters
    ----------
    scan_inner_fn
        The JAX-compiled inner function of the `Scan`. Its last output is used
        as the termination flag.
    parse_outer_inputs
        Callable that splits the outer-inputs into a dictionary with (at least)
        the keys ``"mit_sot"``, ``"sit_sot"``, ``"nit_sot"``, ``"shared"``,
        ``"sequences"`` and ``"non_sequences"``.
    input_taps
        Dictionary with the taps of the ``"mit_sot"`` and ``"sit_sot"`` inputs.

    Returns
    -------
    A function that takes the outer-inputs and returns the values held in the
    carry after the last executed step: a single value when there is exactly
    one output, a list otherwise.

    """

    def build_while_carry(outer_in):
        """Build the initial `carry` of `jax.lax.while_loop` from the outer-inputs."""
        return {
            "mit_sot": [],
            "mit_sot_storage": outer_in["mit_sot"],
            "sit_sot": [],
            "sit_sot_storage": outer_in["sit_sot"],
            "shared": outer_in["shared"],
            "sequences": outer_in["sequences"],
            "non_sequences": outer_in["non_sequences"],
            "step": 0,
            "do_stop": False,
        }

    def build_inner_outputs_map(outer_in):
        """Map the inner-output variables to their position in the tuple returned by the inner function.

        TODO: Copied from the scan builder

        Inner-outputs are ordered as follows:
        - mit-mot-outputs
        - mit-sot-outputs
        - sit-sot-outputs
        - nit-sots (no carry)
        - shared-outputs
        [+ while-condition]

        Only the categories that have at least one output get a key in the
        returned mapping; the callers rely on this through `in` checks.

        """
        inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]

        offset = 0
        inner_output_idx = defaultdict(list)
        for name in inner_outputs_names:
            num_outputs = len(outer_in[name])
            for i in range(num_outputs):
                inner_output_idx[name].append(offset + i)
            offset += num_outputs

        return inner_output_idx

    def from_carry_storage(carry, step, input_taps):
        """Fetch the inner inputs from the values stored in the carry array.

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to read from this array
        the inputs that will be passed to the inner function.

        This mechanism is necessary because we handle multiple-input taps within
        the `scan` instead of letting users manage the memory in the use cases
        where this is necessary.

        For each storage array, one value is read per tap at position
        ``step + (-min(taps)) + tap``.

        TODO: Copied from the scan builder

        """

        def fetch(carry, step, offset):
            return carry[step + offset]

        inner_inputs = []
        for taps, carry_element in zip(input_taps, carry):
            storage_size = -min(taps)
            offsets = [storage_size + tap for tap in taps]
            inner_inputs.append(
                [fetch(carry_element, step, offset) for offset in offsets]
            )

        # Flatten the per-storage lists into a single list of inner-inputs
        return sum(inner_inputs, [])

    def to_carry_storage(inner_outputs, carry, step, input_taps):
        """Create the new carry array from the inner output

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to update this array
        with the outputs of the inner function

        TODO: Copied from the scan builder

        """
        new_carry_element = []
        for taps, carry_element, output in zip(input_taps, carry, inner_outputs):
            new_carry_element.append(
                [carry_element.at[step - tap].set(output) for tap in taps]
            )

        return sum(new_carry_element, [])

    def while_loop(*outer_inputs):

        outer_in = parse_outer_inputs(outer_inputs)
        init_carry = build_while_carry(outer_in)
        inner_output_idx = build_inner_outputs_map(outer_in)

        def inner_inputs_from_carry(carry):
            """Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.

            The inner-inputs are built in the following order:
            - mit-sot inputs
            - sit-sot inputs
            - shared-inputs
            - non-sequences

            NOTE: sequences and mit-mot inputs are not forwarded to the inner
            function by this implementation.

            """
            current_step = carry["step"]

            inner_in_mit_sot = from_carry_storage(
                carry["mit_sot_storage"], current_step, input_taps["mit_sot"]
            )
            inner_in_sit_sot = from_carry_storage(
                carry["sit_sot_storage"], current_step, input_taps["sit_sot"]
            )
            inner_in_shared = carry.get("shared", [])
            inner_in_non_sequences = carry.get("non_sequences", [])

            return sum(
                [
                    inner_in_mit_sot,
                    inner_in_sit_sot,
                    inner_in_shared,
                    inner_in_non_sequences,
                ],
                [],
            )

        def carry_from_inner_outputs(carry, inner_outputs):
            """Build the next `carry` from the previous one and the inner-outputs.

            The returned dictionary always has the same keys so that the Pytree
            structure stays constant across iterations.
            """
            step = carry["step"]
            new_carry = {
                "mit_sot": [],
                "sit_sot": [],
                "sit_sot_storage": [],
                "nit_sot": [],
                "mit_sot_storage": [],
                "shared": [],
                "step": step + 1,
                "sequences": carry["sequences"],
                "non_sequences": carry["non_sequences"],
                # The while-condition is the last inner-output
                "do_stop": inner_outputs[-1],
            }

            if "shared" in inner_output_idx:
                shared_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["shared"]
                ]
                new_carry["shared"] = shared_inner_outputs

            if "mit_sot" in inner_output_idx:
                mit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
                ]
                new_carry["mit_sot"] = mit_sot_inner_outputs
                new_carry["mit_sot_storage"] = to_carry_storage(
                    mit_sot_inner_outputs,
                    carry["mit_sot_storage"],
                    step,
                    input_taps["mit_sot"],
                )

            if "sit_sot" in inner_output_idx:
                sit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
                ]
                new_carry["sit_sot"] = sit_sot_inner_outputs
                new_carry["sit_sot_storage"] = to_carry_storage(
                    sit_sot_inner_outputs,
                    carry["sit_sot_storage"],
                    step,
                    input_taps["sit_sot"],
                )

            if "nit_sot" in inner_output_idx:
                nit_sot_inner_outputs = [
                    inner_outputs[idx] for idx in inner_output_idx["nit_sot"]
                ]
                new_carry["nit_sot"] = nit_sot_inner_outputs

            return new_carry

        def cond_fn(carry):
            # The inner-function of `Scan` returns a boolean as the last
            # value. This needs to be included in `carry`.
            # `jnp.logical_not` is used instead of `~` so that the result is a
            # proper boolean negation whatever the dtype of the flag (`~` is a
            # bitwise operator and only negates correctly for boolean dtypes).
            # TODO: Will it return `False` if the number of steps is exceeded?
            return jnp.logical_not(carry["do_stop"])

        def body_fn(carry):
            inner_inputs = inner_inputs_from_carry(carry)
            inner_outputs = scan_inner_fn(*inner_inputs)
            new_carry = carry_from_inner_outputs(carry, inner_outputs)
            return new_carry

        # The `Scan` implementation in the C backend will execute the
        # function once before checking the termination condition, while
        # `jax.lax.while_loop` checks the condition first. We thus need to call
        # `body_fn` once before calling `jax.lax.while_loop`. This allows us,
        # along with `n_steps`, to build the storage array for the `nit-sot`s
        # since there is no way to know their shape and dtype before executing
        # the function.
        inner_inputs = inner_inputs_from_carry(init_carry)
        inner_outputs = scan_inner_fn(*inner_inputs)
        carry = carry_from_inner_outputs(init_carry, inner_outputs)
        carry = jax.lax.while_loop(cond_fn, body_fn, carry)

        # Post-process the storage arrays
        # We make sure that the outputs are not scalars in case an array
        # is expected downstream since `Scan` is supposed to always return arrays
        carry["sit_sot"] = [jnp.atleast_1d(element) for element in carry["sit_sot"]]
        carry["mit_sot"] = [jnp.atleast_1d(element) for element in carry["mit_sot"]]
        # The converted values must be written back under "nit_sot", the key
        # that is read when the results are gathered below.
        carry["nit_sot"] = [jnp.atleast_1d(element) for element in carry["nit_sot"]]

        outer_outputs = ["mit_sot", "sit_sot", "nit_sot", "shared"]
        results = sum([carry[output] for output in outer_outputs], [])
        if len(results) == 1:
            return results[0]
        else:
            return results

    return while_loop
46293

47294

48295
def make_jax_scan_fn(
@@ -58,7 +305,8 @@ def make_jax_scan_fn(
58305
stacked to the previous outputs. We use this to our advantage to build
59306
`Scan` outputs without having to post-process the storage arrays.
60307
61-
The JAX scan function needs to perform the following operations:
308+
The JAX `scan` function needs to perform the following operations:
309+
62310
1. Extract the inner-inputs;
63311
2. Build the initial `carry` and `sequence` values;
64312
3. Inside the loop:
@@ -265,11 +513,11 @@ def body_fn(carry, x):
265513
)
266514

267515
shared_output = tuple(last_carry["shared"])
268-
results = results + shared_output
516+
outer_outputs = results + shared_output
269517

270-
if len(results) == 1:
271-
return results[0]
518+
if len(outer_outputs) == 1:
519+
return outer_outputs[0]
272520

273-
return results
521+
return outer_outputs
274522

275523
return scan

aesara/link/jax/dispatch/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def shape(x):
3333

3434

3535
@jax_funcify.register(Shape_i)
36-
def jax_funcify_Shape_i(op, **kwargs):
36+
def jax_funcify_Shape_i(op, node, **kwargs):
3737
i = op.i
3838

3939
def shape_i(x):

aesara/link/jax/linker.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from aesara.compile.sharedvalue import SharedVariable, shared
66
from aesara.graph.basic import Constant
7+
from aesara.graph.rewriting.basic import WalkingGraphRewriter, node_rewriter
78
from aesara.link.basic import JITLinker
89

910

@@ -12,7 +13,10 @@ class JAXLinker(JITLinker):
1213

1314
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
1415
from aesara.link.jax.dispatch import jax_funcify
16+
from aesara.scan.op import Scan
1517
from aesara.tensor.random.type import RandomType
18+
from aesara.tensor.shape import Shape_i
19+
from aesara.tensor.subtensor import Subtensor
1620

1721
shared_rng_inputs = [
1822
inp
@@ -49,6 +53,37 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
4953
fgraph.inputs.index(old_inp), reason="JAXLinker.fgraph_convert"
5054
)
5155

56+
@node_rewriter([Scan])
def check_while_returns_last_output(fgraph, node):
    """Refuse `Scan` while loops whose outputs are used in unsupported ways.

    This rewriter never modifies the graph (it always returns ``False``); it
    only inspects the clients of the outputs of `Scan` nodes that act as
    while loops and raises when one of them is used in a way the JAX
    while-loop implementation is not expected to handle.

    Parameters
    ----------
    fgraph
        The function graph being compiled; its ``clients`` mapping is read.
    node
        The `Scan` apply node under inspection.

    Raises
    ------
    NotImplementedError
        If an output has a string client, is indexed with a slice as first
        index by a `Subtensor`, or is consumed by any `Op` other than
        `Subtensor` or `Shape_i`.

    """
    op = node.op
    if not op.info.as_while:
        return False

    # Count the number of outputs of the outer function. We ignore
    # `shared` variables since they are not accumulated and not
    # returned to the user.
    num_outer_outputs = (
        op.info.n_mit_mot
        + op.info.n_mit_sot
        + op.info.n_sit_sot
        + op.info.n_nit_sot
    )
    for out in node.outputs[:num_outer_outputs]:
        for client, _ in fgraph.clients[out]:
            if isinstance(client, str):
                # NOTE(review): a string client presumably marks the variable
                # as an output of the function graph -- confirm.
                raise NotImplementedError(
                    "The JAX backend cannot compile a `Scan` while loop whose "
                    f"output {out} is used directly by the function graph "
                    f"(client: {client!r})."
                )
            elif isinstance(client.op, Subtensor):
                idx_list = client.op.idx_list
                # Only the first index is checked: a slice there means that
                # more than one step of the output is requested.
                if isinstance(idx_list[0], slice):
                    raise NotImplementedError(
                        "The JAX backend cannot compile a `Scan` while loop whose "
                        f"output {out} is sliced along its first dimension; only "
                        "a single step can be retrieved."
                    )
            elif not isinstance(client.op, Shape_i):
                raise NotImplementedError(
                    "The JAX backend cannot compile a `Scan` while loop whose "
                    f"output {out} is used by the `Op` {client.op}; only "
                    "`Subtensor` and `Shape_i` clients are supported."
                )

    return False
83+
84+
jax_opt = WalkingGraphRewriter(check_while_returns_last_output)
85+
jax_opt.rewrite(fgraph)
86+
5287
return jax_funcify(
5388
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
5489
)

0 commit comments

Comments
 (0)