Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit d1326b8

Browse files
committed
Rewrite minimal Scan dispatch for JAX
Passes the first `xit_xot_types` test taken from the Numba test suite.
1 parent 1ab4c69 commit d1326b8

File tree

2 files changed

+135
-146
lines changed

2 files changed

+135
-146
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 66 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,159 +1,79 @@
11
import jax
2-
import jax.numpy as jnp
32

4-
from aesara.graph.fg import FunctionGraph
53
from aesara.link.jax.dispatch.basic import jax_funcify
64
from aesara.scan.op import Scan
7-
from aesara.scan.utils import ScanArgs
85

96

107
@jax_funcify.register(Scan)
11-
def jax_funcify_Scan(op, **kwargs):
12-
inner_fg = FunctionGraph(op.inputs, op.outputs)
13-
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
8+
def jax_funcify_Scan(op, node, **kwargs):
9+
scan_inner_func = jax_funcify(op.fgraph)
1410

1511
def scan(*outer_inputs):
16-
scan_args = ScanArgs(
17-
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
18-
)
19-
20-
# `outer_inputs` is a list with the following composite form:
21-
# [n_steps]
22-
# + outer_in_seqs
23-
# + outer_in_mit_mot
24-
# + outer_in_mit_sot
25-
# + outer_in_sit_sot
26-
# + outer_in_shared
27-
# + outer_in_nit_sot
28-
# + outer_in_non_seqs
29-
n_steps = scan_args.n_steps
30-
seqs = scan_args.outer_in_seqs
31-
32-
# TODO: mit_mots
33-
mit_mot_in_slices = []
34-
35-
mit_sot_in_slices = []
36-
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
37-
neg_taps = [abs(t) for t in tap if t < 0]
38-
pos_taps = [abs(t) for t in tap if t > 0]
39-
max_neg = max(neg_taps) if neg_taps else 0
40-
max_pos = max(pos_taps) if pos_taps else 0
41-
init_slice = seq[: max_neg + max_pos]
42-
mit_sot_in_slices.append(init_slice)
43-
44-
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
45-
46-
init_carry = (
47-
mit_mot_in_slices,
48-
mit_sot_in_slices,
49-
sit_sot_in_slices,
50-
scan_args.outer_in_shared,
51-
scan_args.outer_in_non_seqs,
52-
)
53-
54-
def jax_args_to_inner_scan(op, carry, x):
55-
# `carry` contains all inner-output taps, non_seqs, and shared
56-
# terms
57-
(
58-
inner_in_mit_mot,
59-
inner_in_mit_sot,
60-
inner_in_sit_sot,
61-
inner_in_shared,
62-
inner_in_non_seqs,
63-
) = carry
64-
65-
# `x` contains the in_seqs
66-
inner_in_seqs = x
6712

68-
# `inner_scan_inputs` is a list with the following composite form:
69-
# inner_in_seqs
70-
# + sum(inner_in_mit_mot, [])
71-
# + sum(inner_in_mit_sot, [])
72-
# + inner_in_sit_sot
73-
# + inner_in_shared
74-
# + inner_in_non_seqs
75-
inner_in_mit_sot_flatten = []
76-
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
77-
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
78-
79-
inner_scan_inputs = sum(
80-
[
81-
inner_in_seqs,
82-
inner_in_mit_mot,
83-
inner_in_mit_sot_flatten,
84-
inner_in_sit_sot,
85-
inner_in_shared,
86-
inner_in_non_seqs,
87-
],
88-
[],
89-
)
90-
91-
return inner_scan_inputs
92-
93-
def inner_scan_outs_to_jax_outs(
94-
op,
95-
old_carry,
96-
inner_scan_outs,
97-
):
98-
(
99-
inner_in_mit_mot,
100-
inner_in_mit_sot,
101-
inner_in_sit_sot,
102-
inner_in_shared,
103-
inner_in_non_seqs,
104-
) = old_carry
105-
106-
def update_mit_sot(mit_sot, new_val):
107-
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
108-
109-
inner_out_mit_sot = [
110-
update_mit_sot(mit_sot, new_val)
111-
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
112-
]
113-
114-
# This should contain all inner-output taps, non_seqs, and shared
115-
# terms
116-
if not inner_in_sit_sot:
117-
inner_out_sit_sot = []
13+
n_steps = outer_inputs[0]
14+
outer_in_seqs = list(op.outer_seqs(outer_inputs))
15+
outer_in_mit_mot = list(op.outer_mitmot(outer_inputs))
16+
outer_in_mit_sot = list(op.outer_mitsot(outer_inputs))
17+
outer_in_sit_sot = list(op.outer_sitsot(outer_inputs))
18+
outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
19+
outer_in_shared = list(op.outer_shared(outer_inputs))
20+
outer_in_non_seqs = list(op.outer_non_seqs(outer_inputs))
21+
if len(outer_in_mit_mot):
22+
raise NotImplementedError("mit-mot not supported")
23+
if len(outer_in_mit_sot):
24+
raise NotImplementedError("mit-sot not supported")
25+
if len(outer_in_sit_sot):
26+
raise NotImplementedError("sit-sot not supported")
27+
if len(outer_in_shared):
28+
raise NotImplementedError("shared variables not supported")
29+
if len(outer_in_non_seqs):
30+
raise NotImplementedError("non sequence are not supported")
31+
32+
init_carry = outer_in_nit_sot
33+
sequences = outer_in_seqs
34+
35+
def scan_inner_in_args(carry, x, is_dummy_sit_sot=True):
36+
"""Create an inner-input expression.
37+
38+
Inner-inputs are ordered as follows:
39+
- sequences
40+
- mit-mot inputs
41+
- mit-sot inputs
42+
- sit-sot inputs
43+
- shared-inputs
44+
- non-sequences
45+
"""
46+
47+
inner_in_seqs = x
48+
if is_dummy_sit_sot:
49+
inner_in_sit_sot = []
11850
else:
119-
inner_out_sit_sot = inner_scan_outs
120-
new_carry = (
121-
inner_in_mit_mot,
122-
inner_out_mit_sot,
123-
inner_out_sit_sot,
124-
inner_in_shared,
125-
inner_in_non_seqs,
126-
)
127-
128-
return new_carry
129-
130-
def jax_inner_func(carry, x):
131-
inner_args = jax_args_to_inner_scan(op, carry, x)
132-
inner_scan_outs = list(jax_at_inner_func(*inner_args))
133-
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
134-
return new_carry, inner_scan_outs
135-
136-
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
137-
138-
# We need to prepend the initial values so that the JAX output will
139-
# match the raw `Scan` `Op` output and, thus, work with a downstream
140-
# `Subtensor` `Op` introduced by the `scan` helper function.
141-
def append_scan_out(scan_in_part, scan_out_part):
142-
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
143-
144-
if scan_args.outer_in_mit_sot:
145-
scan_out_final = [
146-
append_scan_out(init, out)
147-
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
148-
]
149-
elif scan_args.outer_in_sit_sot:
150-
scan_out_final = [
151-
append_scan_out(init, out)
152-
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
153-
]
154-
155-
if len(scan_out_final) == 1:
156-
scan_out_final = scan_out_final[0]
157-
return scan_out_final
51+
inner_in_sit_sot = carry
52+
return sum([inner_in_seqs, inner_in_sit_sot], [])
53+
54+
def scan_new_carry(inner_outputs):
55+
"""Create a new carry expression
56+
57+
Inner-outputs are ordered as follow:
58+
- mit-mot-outputs
59+
- mit-sot-outputs
60+
- sit-sot-outputs
61+
- nit-sots
62+
- shared-outputs
63+
[+ while-condition]
64+
65+
"""
66+
carry = list(inner_outputs)
67+
return carry
68+
69+
def body_fn(carry, x):
70+
inner_in_args = scan_inner_in_args(carry, x)
71+
inner_outputs = scan_inner_func(*inner_in_args)
72+
carry = scan_new_carry(inner_outputs)
73+
return carry, *inner_outputs
74+
75+
_, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)
76+
77+
return results
15878

15979
return scan

tests/link/jax/test_scan.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,85 @@
33
from packaging.version import parse as version_parse
44

55
import aesara.tensor as at
6+
from aesara import function
7+
from aesara.compile.mode import Mode
68
from aesara.configdefaults import config
79
from aesara.graph.fg import FunctionGraph
10+
from aesara.graph.rewriting.db import RewriteDatabaseQuery
11+
from aesara.link.jax.linker import JAXLinker
812
from aesara.scan.basic import scan
13+
from aesara.scan.op import Scan
914
from aesara.tensor.math import gammaln, log
1015
from aesara.tensor.type import ivector, lscalar, scalar
1116
from tests.link.jax.test_basic import compare_jax_and_py
1217

1318

1419
jax = pytest.importorskip("jax")
1520

21+
# Disable all optimizations
22+
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
23+
jax_mode = Mode(JAXLinker(), opts)
24+
25+
26+
@pytest.mark.parametrize(
27+
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
28+
[
29+
# sequences
30+
(
31+
lambda a_t: 2 * a_t,
32+
[at.dvector("a")],
33+
[{}],
34+
[],
35+
None,
36+
[np.arange(10)],
37+
None,
38+
lambda op: op.info.n_seqs > 0,
39+
),
40+
],
41+
)
42+
def test_xit_xot_types(
43+
fn,
44+
sequences,
45+
outputs_info,
46+
non_sequences,
47+
n_steps,
48+
input_vals,
49+
output_vals,
50+
op_check,
51+
):
52+
"""Test basic xit-xot configurations."""
53+
res, updates = scan(
54+
fn,
55+
sequences=sequences,
56+
outputs_info=outputs_info,
57+
non_sequences=non_sequences,
58+
n_steps=n_steps,
59+
strict=True,
60+
mode=Mode(linker="py", optimizer=None),
61+
)
62+
63+
if not isinstance(res, list):
64+
res = [res]
65+
66+
# Get rid of any `Subtensor` indexing on the `Scan` outputs
67+
res = [r.owner.inputs[0] if not isinstance(r.owner.op, Scan) else r for r in res]
68+
69+
scan_op = res[0].owner.op
70+
assert isinstance(scan_op, Scan)
71+
72+
_ = op_check(scan_op)
73+
74+
if output_vals is None:
75+
compare_jax_and_py(
76+
((sequences + non_sequences), res), input_vals, updates=updates
77+
)
78+
else:
79+
jax_fn = function(
80+
(sequences + non_sequences), res, mode=jax_mode, updates=updates
81+
)
82+
res_vals = jax_fn(*input_vals)
83+
assert np.allclose(res_vals, output_vals)
84+
1685

1786
@pytest.mark.xfail(
1887
version_parse(jax.__version__) >= version_parse("0.2.12"),

0 commit comments

Comments
 (0)