Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 8df4879

Browse files
committed
Support shared inputs in JAX's Scan dispatch
1 parent 6a18241 commit 8df4879

File tree

2 files changed

+93
-19
lines changed

2 files changed

+93
-19
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,40 @@ def scan(*outer_inputs):
4848
outer_in_seqs = list(op.outer_seqs(outer_inputs))
4949
outer_in_mit_mot = list(op.outer_mitmot(outer_inputs))
5050
outer_in_mit_sot = list(op.outer_mitsot(outer_inputs))
51+
outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
5152
outer_in_sit_sot = list(op.outer_sitsot(outer_inputs))
5253
outer_in_shared = list(op.outer_shared(outer_inputs))
5354
outer_in_non_seqs = list(op.outer_non_seqs(outer_inputs))
5455
if len(outer_in_mit_mot):
5556
raise NotImplementedError("mit-mot not supported")
5657
if len(outer_in_mit_sot):
5758
raise NotImplementedError("mit-sot not supported")
58-
if len(outer_in_shared):
59-
raise NotImplementedError("shared variables not supported")
6059

61-
init_carry = outer_in_sit_sot
60+
# These are the outer-inputs
6261
sequences = outer_in_seqs
6362
non_sequences = outer_in_non_seqs
63+
init_carry = {}
64+
for name, outputs in [
65+
("sit_sot", outer_in_sit_sot),
66+
("shared", outer_in_shared),
67+
]:
68+
if len(outputs) > 0:
69+
init_carry[name] = outputs
70+
71+
# We keep track of the kind of inner_outputs and their number
72+
from collections import defaultdict
73+
74+
offset = 0
75+
inner_output_idx = defaultdict(list)
76+
for name, outputs in [
77+
("sit_sot", outer_in_sit_sot),
78+
("nit_sot", outer_in_nit_sot),
79+
("shared", outer_in_shared),
80+
]:
81+
if len(outputs) > 0:
82+
for i in range(len(outputs)):
83+
inner_output_idx[name].append(offset + i)
84+
offset += len(outputs)
6485

6586
def scan_inner_in_args(carry, x):
6687
"""Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
@@ -72,18 +93,22 @@ def scan_inner_in_args(carry, x):
7293
- sit-sot inputs
7394
- shared-inputs
7495
- non-sequences
96+
7597
"""
7698
inner_in_seqs = x
7799
inner_in_sit_sot = sum(
78100
[
79101
convert(carry_element)
80102
for convert in sit_sot_from_carry
81-
for carry_element in carry
103+
for carry_element in carry["sit_sot"]
82104
],
83105
[],
84106
)
107+
inner_in_shared = carry.get("shared", [])
85108

86-
return sum([inner_in_seqs, inner_in_sit_sot, non_sequences], [])
109+
return sum(
110+
[inner_in_seqs, inner_in_sit_sot, inner_in_shared, non_sequences], []
111+
)
87112

88113
def scan_new_carry(carry, inner_outputs):
89114
"""Create a new carry value from the inner-outputs.
@@ -97,27 +122,39 @@ def scan_new_carry(carry, inner_outputs):
97122
[+ while-condition]
98123
99124
"""
100-
carry = sum(
101-
[
102-
convert(carry_element, inner_outputs_element)
103-
for convert in sit_sot_to_carry
104-
for (carry_element, inner_outputs_element) in zip(
105-
carry, inner_outputs
106-
)
107-
],
108-
[],
109-
)
110-
return carry
125+
new_carry = {}
126+
127+
if "shared" in inner_output_idx:
128+
shared_inner_outputs = [
129+
inner_outputs[idx] for idx in inner_output_idx["shared"]
130+
]
131+
new_carry["shared"] = shared_inner_outputs
132+
133+
if "sit_sot" in inner_output_idx:
134+
sit_sot_inner_outputs = [
135+
inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
136+
]
137+
new_carry["sit_sot"] = sum(
138+
[
139+
convert(carry_element, inner_outputs_element)
140+
for convert in sit_sot_to_carry
141+
for (carry_element, inner_outputs_element) in zip(
142+
carry["sit_sot"], sit_sot_inner_outputs
143+
)
144+
],
145+
[],
146+
)
147+
148+
return new_carry
111149

112150
def body_fn(carry, x):
113151
inner_in_args = scan_inner_in_args(carry, x)
114152
inner_outputs = scan_inner_func(*inner_in_args)
115153
new_carry = scan_new_carry(carry, inner_outputs)
116-
return new_carry, *inner_outputs
154+
return new_carry, inner_outputs
117155

118-
print(init_carry)
119156
_, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)
120157

121-
return results
158+
return results[0]
122159

123160
return scan

tests/link/jax/test_scan.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from aesara.scan.basic import scan
1313
from aesara.scan.op import Scan
1414
from aesara.tensor.math import gammaln, log
15+
from aesara.tensor.random.utils import RandomStream
1516
from aesara.tensor.type import ivector, lscalar, scalar
1617
from tests.link.jax.test_basic import compare_jax_and_py
1718

@@ -37,6 +38,20 @@ def test_sit_sot():
3738
assert np.allclose(fn(1.0), jax_fn(1.0))
3839

3940

41+
def test_nit_sot_shared():
42+
res, updates = scan(
43+
fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
44+
0, 1, name="a"
45+
),
46+
n_steps=3,
47+
)
48+
49+
jax_fn = function((), res, updates=updates, mode="JAX")
50+
print(jax_fn())
51+
fn = function((), res, updates=updates)
52+
print(fn())
53+
54+
4055
@pytest.mark.parametrize(
4156
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
4257
[
@@ -73,6 +88,28 @@ def test_sit_sot():
7388
None,
7489
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
7590
),
91+
# # sit-sot
92+
# (
93+
# lambda a_tm1: 2 * a_tm1,
94+
# [],
95+
# [{"initial": at.as_tensor(0.0, dtype="floatX"), "taps": [-1]}],
96+
# [],
97+
# 3,
98+
# [],
99+
# None,
100+
# lambda op: op.info.n_sit_sot > 0,
101+
# ),
102+
# # sit-sot, while
103+
# (
104+
# lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
105+
# [],
106+
# [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
107+
# [],
108+
# 3,
109+
# [],
110+
# None,
111+
# lambda op: op.info.n_sit_sot > 0,
112+
# ),
76113
],
77114
)
78115
def test_xit_xot_types(

0 commit comments

Comments
 (0)