Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit b78a011

Browse files
committed
Refactor the JAX Scan dispatcher
1 parent 2fbcea5 commit b78a011

File tree

1 file changed

+141
-96
lines changed

1 file changed

+141
-96
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 141 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,25 @@
11
from collections import defaultdict
2+
from typing import Callable, Dict, List
23

34
import jax
45

56
from aesara.link.jax.dispatch.basic import jax_funcify
67
from aesara.scan.op import Scan
8+
from aesara.tensor.var import TensorVariable
79

810

911
@jax_funcify.register(Scan)
1012
def jax_funcify_Scan(op, node, **kwargs):
11-
scan_inner_func = jax_funcify(op.fgraph)
12-
if op.info.as_while:
13-
raise NotImplementedError("While loops are not supported in the JAX backend.")
14-
15-
sit_sot_input_taps = op.info.sit_sot_in_slices
16-
mit_sot_input_taps = op.info.mit_sot_in_slices
17-
18-
# Construct `scan_inner_func`'s arguments from the carry value and sequence
19-
# element passed to `body_fn`.
20-
#
21-
# We need to index the storage arrays carried by `jax.lax.scan` for arguments with
22-
# inputs taps.
23-
def index_carry_arrays(input_taps):
24-
"""Fetch the inner inputs from the values stored in the carry array"""
25-
# TODO: Check and refactor this
26-
storage_size = -min(input_taps)
27-
offsets = [storage_size + tap for tap in input_taps]
28-
29-
def to_inner_inputs(step, carry):
30-
return [carry[step + offset] for offset in offsets]
31-
32-
return to_inner_inputs
33-
34-
sit_sot_from_carry = [index_carry_arrays(tap) for tap in sit_sot_input_taps]
35-
mit_sot_from_carry = [index_carry_arrays(tap) for tap in mit_sot_input_taps]
36-
37-
# Construct the new carry values from the outputs of `scan_inner_func`
38-
def inner_outputs_to_carry(input_taps):
39-
"""Create the new carry array from the inner output"""
40-
# TODO: Check and refactor this
41-
def to_new_carry(step, carry, inner_outputs):
42-
return [carry.at[step - tap].set(inner_outputs) for tap in input_taps]
43-
44-
return to_new_carry
45-
46-
sit_sot_to_carry = [inner_outputs_to_carry(tap) for tap in sit_sot_input_taps]
47-
mit_sot_to_carry = [inner_outputs_to_carry(tap) for tap in mit_sot_input_taps]
48-
49-
def scan(*outer_inputs):
50-
51-
# Inputs to `aesara.scan`
13+
scan_inner_fn = jax_funcify(op.fgraph)
14+
input_taps = {
15+
"mit_sot": op.info.mit_sot_in_slices,
16+
"sit_sot": op.info.sit_sot_in_slices,
17+
}
18+
19+
# Outer-inputs are the inputs to the `Scan` apply node, built from the
20+
# the variables provided by the caller to the `scan` function at construction
21+
# time.
22+
def parse_outer_inputs(outer_inputs):
5223
outer_in = {
5324
"n_steps": outer_inputs[0],
5425
"sequences": list(op.outer_seqs(outer_inputs)),
@@ -62,22 +33,125 @@ def scan(*outer_inputs):
6233
if len(outer_in["mit_mot"]) > 0:
6334
raise NotImplementedError("mit-mot not supported")
6435

65-
# Inputs to `jax.lax.scan`
36+
return outer_in
37+
38+
if op.info.as_while:
39+
raise NotImplementedError("While loops are not supported in the JAX backend.")
40+
else:
41+
return make_jax_scan_fn(
42+
scan_inner_fn,
43+
parse_outer_inputs,
44+
input_taps,
45+
)
46+
47+
48+
def make_jax_scan_fn(
49+
scan_inner_fn: Callable,
50+
parse_outer_inputs: Callable[[TensorVariable], Dict[str, List[TensorVariable]]],
51+
input_taps: Dict,
52+
):
53+
"""Create a `jax.lax.scan` function to perform `Scan` computations.
54+
55+
`jax.lax.scan` takes an initial `carry` value and a sequence it scans over,
56+
or a number of iterations. The first output of the loop body function, the
57+
`carry`, is carried over to the next iteration. The second, the `output`, is
58+
stacked to the previous outputs. We use this to our advantage to build
59+
`Scan` outputs without having to post-process the storage arrays.
60+
61+
The JAX scan function needs to perform the following operations:
62+
1. Extract the inner-inputs;
63+
2. Build the initial `carry` and `sequence` values;
64+
3. Inside the loop:
65+
1. `carry` + sequence elements -> inner-inputs
66+
2. inner-outputs -> `carry`
67+
3. inner-outputs -> `output`
68+
4. Append the last `shared` value to the stacked `output`s
69+
70+
"""
71+
72+
def build_jax_scan_inputs(outer_in: Dict):
73+
"""Build the inputs to `jax.lax.scan` from the outer-inputs."""
6674
n_steps = outer_in["n_steps"]
6775
sequences = outer_in["sequences"]
68-
non_sequences = outer_in["non_sequences"]
69-
init_carry = {name: outer_in[name] for name in ["mit_sot", "sit_sot", "shared"]}
76+
init_carry = {
77+
name: outer_in[name]
78+
for name in ["mit_sot", "sit_sot", "shared", "non_sequences"]
79+
}
7080
init_carry["step"] = 0
81+
return n_steps, sequences, init_carry
82+
83+
def build_inner_outputs_map(outer_in):
84+
"""Map the inner-output variables to their position in the tuple returned by the inner function.
85+
86+
Inner-outputs are ordered as follows:
87+
- mit-mot-outputs
88+
- mit-sot-outputs
89+
- sit-sot-outputs
90+
- nit-sots (no carry)
91+
- shared-outputs
92+
[+ while-condition]
93+
94+
"""
95+
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
7196

72-
# Map to retrieve the inner-outputs
7397
offset = 0
7498
inner_output_idx = defaultdict(list)
75-
for name in ["mit_sot", "sit_sot", "nit_sot", "shared"]:
99+
for name in inner_outputs_names:
76100
num_outputs = len(outer_in[name])
77101
for i in range(num_outputs):
78102
inner_output_idx[name].append(offset + i)
79103
offset += num_outputs
80104

105+
return inner_output_idx
106+
107+
def from_carry_storage(carry, step, input_taps):
    """Fetch the inner inputs from the values stored in the carry arrays.

    `Scan` passes storage arrays as inputs, which are then read from and
    updated in the loop body. At each step we need to read from these arrays
    the inputs that will be passed to the inner function.

    This mechanism is necessary because we handle multiple-input taps within
    the `scan` instead of letting users manage the memory in the use cases
    where this is necessary.

    Parameters
    ----------
    carry
        One storage container per carried variable. Each container is indexed
        with a single integer position.
    step
        The current iteration number.
    input_taps
        One collection of taps per element of `carry`, in the same order.

    Returns
    -------
    list
        The values read from storage, flattened: all the taps of the first
        carried variable, then all the taps of the second one, and so on.

    """
    inner_inputs = []
    for taps, carry_element in zip(input_taps, carry):
        # Shifting every tap by `-min(taps)` makes the smallest tap read
        # position `step`, and the other taps the positions that follow it.
        storage_size = -min(taps)
        # Extend in place rather than flattening with `sum(lists, [])`, which
        # re-copies the accumulated list for every carried variable.
        inner_inputs.extend(
            carry_element[step + storage_size + tap] for tap in taps
        )

    return inner_inputs
132+
133+
def to_carry_storage(inner_outputs, carry, step, input_taps):
    """Create the new carry arrays from the inner outputs.

    `Scan` passes storage arrays as inputs, which are then read from and
    updated in the loop body. At each step we need to update these arrays
    with the outputs of the inner function.

    Parameters
    ----------
    inner_outputs
        The values computed by the inner function at this step, one per
        element of `carry` and in the same order.
    carry
        One storage array per carried variable. Each must support the
        functional update syntax ``array.at[index].set(value)``.
    step
        The current iteration number.
    input_taps
        One collection of taps per element of `carry`, in the same order.

    Returns
    -------
    list
        Exactly one updated storage array per element of `carry`, so the
        returned list has the same length as `carry` whatever the number of
        taps of each carried variable.

    """
    new_carry = []
    for taps, carry_element, output in zip(input_taps, carry, inner_outputs):
        # `from_carry_storage` reads tap `t` at position `step - min(taps) + t`,
        # so the value produced at this step (tap 0 of the next reads) belongs
        # at `step - min(taps)`. Writing once per carried variable, instead of
        # once per tap, keeps the carry the same length across iterations.
        write_position = step - min(taps)
        new_carry.append(carry_element.at[write_position].set(output))

    return new_carry
148+
149+
def scan(*outer_inputs):
150+
151+
outer_in = parse_outer_inputs(outer_inputs)
152+
n_steps, sequences, init_carry = build_jax_scan_inputs(outer_in)
153+
inner_output_idx = build_inner_outputs_map(outer_in)
154+
81155
def scan_inner_in_args(carry, x):
82156
"""Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
83157
@@ -90,54 +164,39 @@ def scan_inner_in_args(carry, x):
90164
- non-sequences
91165
92166
"""
93-
step = carry["step"]
167+
current_step = carry["step"]
94168

95169
inner_in_seqs = x
96-
inner_in_mit_sot = sum(
97-
[
98-
convert(step, carry_element)
99-
for convert, carry_element in zip(
100-
mit_sot_from_carry, carry["mit_sot"]
101-
)
102-
],
103-
[],
170+
inner_in_mit_sot = from_carry_storage(
171+
carry["mit_sot"], current_step, input_taps["mit_sot"]
104172
)
105-
inner_in_sit_sot = sum(
106-
[
107-
convert(step, carry_element)
108-
for convert, carry_element in zip(
109-
sit_sot_from_carry, carry["sit_sot"]
110-
)
111-
],
112-
[],
173+
inner_in_sit_sot = from_carry_storage(
174+
carry["sit_sot"], current_step, input_taps["sit_sot"]
113175
)
114176
inner_in_shared = carry.get("shared", [])
177+
inner_in_non_sequences = carry.get("non_sequences", [])
115178

116179
return sum(
117180
[
118181
inner_in_seqs,
119182
inner_in_mit_sot,
120183
inner_in_sit_sot,
121184
inner_in_shared,
122-
non_sequences,
185+
inner_in_non_sequences,
123186
],
124187
[],
125188
)
126189

127190
def scan_new_carry(carry, inner_outputs):
128-
"""Create a new carry value from the inner-outputs.
129-
130-
Inner-outputs are ordered as follow:
131-
- mit-mot-outputs
132-
- mit-sot-outputs
133-
- sit-sot-outputs
134-
- nit-sots (no carry)
135-
- shared-outputs
136-
[+ while-condition]
137-
138-
"""
191+
"""Create a new carry value from the values returned by the inner function (inner-outputs)."""
139192
step = carry["step"]
140-
new_carry = {"mit_sot": [], "sit_sot": [], "shared": []}
193+
new_carry = {
194+
"mit_sot": [],
195+
"sit_sot": [],
196+
"shared": [],
197+
"step": step + 1,
198+
"non_sequences": carry["non_sequences"],
199+
}
141200

142201
if "shared" in inner_output_idx:
143202
shared_inner_outputs = [
@@ -149,36 +208,22 @@ def scan_new_carry(carry, inner_outputs):
149208
mit_sot_inner_outputs = [
150209
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
151210
]
152-
new_carry["mit_sot"] = sum(
153-
[
154-
convert(step, carry_element, inner_outputs_element)
155-
for (convert, carry_element, inner_outputs_element) in zip(
156-
mit_sot_to_carry, carry["mit_sot"], mit_sot_inner_outputs
157-
)
158-
],
159-
[],
211+
new_carry["mit_sot"] = to_carry_storage(
212+
mit_sot_inner_outputs, carry["mit_sot"], step, input_taps["mit_sot"]
160213
)
161214

162215
if "sit_sot" in inner_output_idx:
163216
sit_sot_inner_outputs = [
164217
inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
165218
]
166-
new_carry["sit_sot"] = sum(
167-
[
168-
convert(step, carry_element, inner_outputs_element)
169-
for (convert, carry_element, inner_outputs_element) in zip(
170-
sit_sot_to_carry, carry["sit_sot"], sit_sot_inner_outputs
171-
)
172-
],
173-
[],
219+
new_carry["sit_sot"] = to_carry_storage(
220+
sit_sot_inner_outputs, carry["sit_sot"], step, input_taps["sit_sot"]
174221
)
175222

176-
new_carry["step"] = carry["step"] + 1
177-
178223
return new_carry
179224

180225
def scan_new_outputs(inner_outputs):
181-
"""Create a new outer-output value from the inner-output value.
226+
"""Create a new outer-output value from the outputs of the inner function.
182227
183228
Outer-outputs are ordered as follows:
184229
- mit-mot-outputs
@@ -210,7 +255,7 @@ def scan_new_outputs(inner_outputs):
210255

211256
def body_fn(carry, x):
212257
inner_in_args = scan_inner_in_args(carry, x)
213-
inner_outputs = scan_inner_func(*inner_in_args)
258+
inner_outputs = scan_inner_fn(*inner_in_args)
214259
new_carry = scan_new_carry(carry, inner_outputs)
215260
outer_outputs = scan_new_outputs(inner_outputs)
216261
return new_carry, outer_outputs

0 commit comments

Comments
 (0)