Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit a24d46b

Browse files
committed
Manage indexing when there are several xit-sots
1 parent 5a884d8 commit a24d46b

File tree

2 files changed

+70
-17
lines changed

2 files changed

+70
-17
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,7 @@ def scan(*outer_inputs):
6666
n_steps = outer_in["n_steps"]
6767
sequences = outer_in["sequences"]
6868
non_sequences = outer_in["non_sequences"]
69-
init_carry = {
70-
name: outer_in[name]
71-
for name in ["mit_sot", "sit_sot", "shared"]
72-
if len(outer_in[name]) > 0
73-
}
69+
init_carry = {name: outer_in[name] for name in ["mit_sot", "sit_sot", "shared"]}
7470
init_carry["step"] = 0
7571

7672
# Map to retrieve the inner-outputs
@@ -100,16 +96,18 @@ def scan_inner_in_args(carry, x):
10096
inner_in_mit_sot = sum(
10197
[
10298
convert(step, carry_element)
103-
for convert in mit_sot_from_carry
104-
for carry_element in carry["mit_sot"]
99+
for convert, carry_element in zip(
100+
mit_sot_from_carry, carry["mit_sot"]
101+
)
105102
],
106103
[],
107104
)
108105
inner_in_sit_sot = sum(
109106
[
110107
convert(step, carry_element)
111-
for convert in sit_sot_from_carry
112-
for carry_element in carry["sit_sot"]
108+
for convert, carry_element in zip(
109+
sit_sot_from_carry, carry["sit_sot"]
110+
)
113111
],
114112
[],
115113
)
@@ -138,8 +136,8 @@ def scan_new_carry(carry, inner_outputs):
138136
[+ while-condition]
139137
140138
"""
141-
new_carry = {}
142139
step = carry["step"]
140+
new_carry = {"mit_sot": [], "sit_sot": [], "shared": []}
143141

144142
if "shared" in inner_output_idx:
145143
shared_inner_outputs = [
@@ -154,9 +152,8 @@ def scan_new_carry(carry, inner_outputs):
154152
new_carry["mit_sot"] = sum(
155153
[
156154
convert(step, carry_element, inner_outputs_element)
157-
for convert in mit_sot_to_carry
158-
for (carry_element, inner_outputs_element) in zip(
159-
carry["mit_sot"], mit_sot_inner_outputs
155+
for (convert, carry_element, inner_outputs_element) in zip(
156+
mit_sot_to_carry, carry["mit_sot"], mit_sot_inner_outputs
160157
)
161158
],
162159
[],
@@ -169,9 +166,8 @@ def scan_new_carry(carry, inner_outputs):
169166
new_carry["sit_sot"] = sum(
170167
[
171168
convert(step, carry_element, inner_outputs_element)
172-
for convert in sit_sot_to_carry
173-
for (carry_element, inner_outputs_element) in zip(
174-
carry["sit_sot"], sit_sot_inner_outputs
169+
for (convert, carry_element, inner_outputs_element) in zip(
170+
sit_sot_to_carry, carry["sit_sot"], sit_sot_inner_outputs
175171
)
176172
],
177173
[],
@@ -189,6 +185,9 @@ def body_fn(carry, x):
189185

190186
_, results = jax.lax.scan(body_fn, init_carry, sequences, length=n_steps)
191187

192-
return results[0]
188+
if len(results) == 1:
189+
return results[0]
190+
191+
return results
193192

194193
return scan

tests/link/jax/test_scan.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ def test_mit_sot():
6969
assert np.allclose(fn(), jax_fn())
7070

7171

72+
def test_mit_sot_2():
73+
res, updates = scan(
74+
fn=lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
75+
outputs_info=[
76+
{"initial": at.as_tensor(1.0, dtype="floatX"), "taps": [-1]},
77+
{"initial": at.as_tensor(0.5, dtype="floatX"), "taps": [-1]},
78+
],
79+
n_steps=10,
80+
)
81+
jax_fn = function((), res, updates=updates, mode="JAX")
82+
fn = function((), res, updates=updates)
83+
print(jax_fn())
84+
print(fn())
85+
assert np.allclose(fn(), jax_fn())
86+
87+
7288
@pytest.mark.parametrize(
7389
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
7490
[
@@ -126,6 +142,44 @@ def test_mit_sot():
126142
# None,
127143
# lambda op: op.info.n_sit_sot > 0,
128144
# ),
145+
# # nit-sot, shared input/output
146+
# (
147+
# lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
148+
# 0, 1, name="a"
149+
# ),
150+
# [],
151+
# [{}],
152+
# [],
153+
# 3,
154+
# [],
155+
# [np.array([-1.63408257, 0.18046406, 2.43265803])],
156+
# lambda op: op.info.n_shared_outs > 0,
157+
# ),
158+
# # mit-sot (that's also a type of sit-sot)
159+
# (
160+
# lambda a_tm1: 2 * a_tm1,
161+
# [],
162+
# [{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}],
163+
# [],
164+
# 6,
165+
# [],
166+
# None,
167+
# lambda op: op.info.n_mit_sot > 0,
168+
# ),
169+
# # mit-sot
170+
# (
171+
# lambda a_tm1, b_tm1: (2 * a_tm1, 2 * b_tm1),
172+
# [],
173+
# [
174+
# {"initial": at.as_tensor(1.0, dtype="floatX"), "taps": [-1]},
175+
# {"initial": at.as_tensor(0.3, dtype="floatX"), "taps": [-1]},
176+
# ],
177+
# [],
178+
# 10,
179+
# [],
180+
# None,
181+
# lambda op: op.info.n_mit_sot > 0,
182+
# ),
129183
],
130184
)
131185
def test_xit_xot_types(

0 commit comments

Comments
 (0)