Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 5a884d8

Browse files
committed
Manage indexing of carried arrays for mit-sots
1 parent 8df4879 commit 5a884d8

File tree

2 files changed: 98 additions (+98) and 48 deletions (-48).

aesara/link/jax/dispatch/scan.py

Lines changed: 81 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
import jax
24

35
from aesara.link.jax.dispatch.basic import jax_funcify
@@ -10,7 +12,8 @@ def jax_funcify_Scan(op, node, **kwargs):
1012
if op.info.as_while:
1113
raise NotImplementedError("While loops are not supported in the JAX backend.")
1214

13-
input_taps = op.info.sit_sot_in_slices
15+
sit_sot_input_taps = op.info.sit_sot_in_slices
16+
mit_sot_input_taps = op.info.mit_sot_in_slices
1417

1518
# Construct `scan_inner_func`'s arguments from the carry value and sequence
1619
# element passed to `body_fn`.
@@ -19,69 +22,65 @@ def jax_funcify_Scan(op, node, **kwargs):
1922
# inputs taps.
2023
def index_carry_arrays(input_taps):
2124
"""Fetch the inner inputs from the values stored in the carry array"""
25+
# TODO: Check and refactor this
2226
storage_size = -min(input_taps)
2327
offsets = [storage_size + tap for tap in input_taps]
2428

25-
def to_inner_inputs(carry):
26-
return [carry[offset] for offset in offsets]
29+
def to_inner_inputs(step, carry):
30+
return [carry[step + offset] for offset in offsets]
2731

2832
return to_inner_inputs
2933

30-
sit_sot_from_carry = [index_carry_arrays(tap) for tap in input_taps]
34+
sit_sot_from_carry = [index_carry_arrays(tap) for tap in sit_sot_input_taps]
35+
mit_sot_from_carry = [index_carry_arrays(tap) for tap in mit_sot_input_taps]
3136

3237
# Construct the new carry values from the outputs of `scan_inner_func`
3338
def inner_outputs_to_carry(input_taps):
3439
"""Create the new carry array from the inner output"""
35-
storage_size = -min(input_taps)
36-
offsets = [storage_size + tap for tap in input_taps]
37-
38-
def to_new_carry(carry, inner_outputs):
39-
return [carry.at[offset].set(inner_outputs) for offset in offsets]
40+
# TODO: Check and refactor this
41+
def to_new_carry(step, carry, inner_outputs):
42+
return [carry.at[step - tap].set(inner_outputs) for tap in input_taps]
4043

4144
return to_new_carry
4245

43-
sit_sot_to_carry = [inner_outputs_to_carry(tap) for tap in input_taps]
46+
sit_sot_to_carry = [inner_outputs_to_carry(tap) for tap in sit_sot_input_taps]
47+
mit_sot_to_carry = [inner_outputs_to_carry(tap) for tap in mit_sot_input_taps]
4448

4549
def scan(*outer_inputs):
4650

47-
n_steps = outer_inputs[0]
48-
outer_in_seqs = list(op.outer_seqs(outer_inputs))
49-
outer_in_mit_mot = list(op.outer_mitmot(outer_inputs))
50-
outer_in_mit_sot = list(op.outer_mitsot(outer_inputs))
51-
outer_in_nit_sot = list(op.outer_nitsot(outer_inputs))
52-
outer_in_sit_sot = list(op.outer_sitsot(outer_inputs))
53-
outer_in_shared = list(op.outer_shared(outer_inputs))
54-
outer_in_non_seqs = list(op.outer_non_seqs(outer_inputs))
55-
if len(outer_in_mit_mot):
51+
# Inputs to `aesara.scan`
52+
outer_in = {
53+
"n_steps": outer_inputs[0],
54+
"sequences": list(op.outer_seqs(outer_inputs)),
55+
"mit_mot": list(op.outer_mitmot(outer_inputs)),
56+
"mit_sot": list(op.outer_mitsot(outer_inputs)),
57+
"nit_sot": list(op.outer_nitsot(outer_inputs)),
58+
"sit_sot": list(op.outer_sitsot(outer_inputs)),
59+
"shared": list(op.outer_shared(outer_inputs)),
60+
"non_sequences": list(op.outer_non_seqs(outer_inputs)),
61+
}
62+
if len(outer_in["mit_mot"]) > 0:
5663
raise NotImplementedError("mit-mot not supported")
57-
if len(outer_in_mit_sot):
58-
raise NotImplementedError("mit-sot not supported")
59-
60-
# These are the outer-inputs
61-
sequences = outer_in_seqs
62-
non_sequences = outer_in_non_seqs
63-
init_carry = {}
64-
for name, outputs in [
65-
("sit_sot", outer_in_sit_sot),
66-
("shared", outer_in_shared),
67-
]:
68-
if len(outputs) > 0:
69-
init_carry[name] = outputs
70-
71-
# We keep track of the kind of inner_outputs and their number
72-
from collections import defaultdict
7364

65+
# Inputs to `jax.lax.scan`
66+
n_steps = outer_in["n_steps"]
67+
sequences = outer_in["sequences"]
68+
non_sequences = outer_in["non_sequences"]
69+
init_carry = {
70+
name: outer_in[name]
71+
for name in ["mit_sot", "sit_sot", "shared"]
72+
if len(outer_in[name]) > 0
73+
}
74+
init_carry["step"] = 0
75+
76+
# Map to retrieve the inner-outputs
7477
offset = 0
7578
inner_output_idx = defaultdict(list)
76-
for name, outputs in [
77-
("sit_sot", outer_in_sit_sot),
78-
("nit_sot", outer_in_nit_sot),
79-
("shared", outer_in_shared),
80-
]:
81-
if len(outputs) > 0:
82-
for i in range(len(outputs)):
83-
inner_output_idx[name].append(offset + i)
84-
offset += len(outputs)
79+
for name in ["mit_sot", "sit_sot", "nit_sot", "shared"]:
80+
num_outputs = len(outer_in[name])
81+
for i in range(num_outputs):
82+
inner_output_idx[name].append(offset + i)
83+
offset += num_outputs
8584

8685
def scan_inner_in_args(carry, x):
8786
"""Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
@@ -95,10 +94,20 @@ def scan_inner_in_args(carry, x):
9594
- non-sequences
9695
9796
"""
97+
step = carry["step"]
98+
9899
inner_in_seqs = x
100+
inner_in_mit_sot = sum(
101+
[
102+
convert(step, carry_element)
103+
for convert in mit_sot_from_carry
104+
for carry_element in carry["mit_sot"]
105+
],
106+
[],
107+
)
99108
inner_in_sit_sot = sum(
100109
[
101-
convert(carry_element)
110+
convert(step, carry_element)
102111
for convert in sit_sot_from_carry
103112
for carry_element in carry["sit_sot"]
104113
],
@@ -107,7 +116,14 @@ def scan_inner_in_args(carry, x):
107116
inner_in_shared = carry.get("shared", [])
108117

109118
return sum(
110-
[inner_in_seqs, inner_in_sit_sot, inner_in_shared, non_sequences], []
119+
[
120+
inner_in_seqs,
121+
inner_in_mit_sot,
122+
inner_in_sit_sot,
123+
inner_in_shared,
124+
non_sequences,
125+
],
126+
[],
111127
)
112128

113129
def scan_new_carry(carry, inner_outputs):
@@ -123,20 +139,36 @@ def scan_new_carry(carry, inner_outputs):
123139
124140
"""
125141
new_carry = {}
142+
step = carry["step"]
126143

127144
if "shared" in inner_output_idx:
128145
shared_inner_outputs = [
129146
inner_outputs[idx] for idx in inner_output_idx["shared"]
130147
]
131148
new_carry["shared"] = shared_inner_outputs
132149

150+
if "mit_sot" in inner_output_idx:
151+
mit_sot_inner_outputs = [
152+
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
153+
]
154+
new_carry["mit_sot"] = sum(
155+
[
156+
convert(step, carry_element, inner_outputs_element)
157+
for convert in mit_sot_to_carry
158+
for (carry_element, inner_outputs_element) in zip(
159+
carry["mit_sot"], mit_sot_inner_outputs
160+
)
161+
],
162+
[],
163+
)
164+
133165
if "sit_sot" in inner_output_idx:
134166
sit_sot_inner_outputs = [
135167
inner_outputs[idx] for idx in inner_output_idx["sit_sot"]
136168
]
137169
new_carry["sit_sot"] = sum(
138170
[
139-
convert(carry_element, inner_outputs_element)
171+
convert(step, carry_element, inner_outputs_element)
140172
for convert in sit_sot_to_carry
141173
for (carry_element, inner_outputs_element) in zip(
142174
carry["sit_sot"], sit_sot_inner_outputs
@@ -145,6 +177,8 @@ def scan_new_carry(carry, inner_outputs):
145177
[],
146178
)
147179

180+
new_carry["step"] = carry["step"] + 1
181+
148182
return new_carry
149183

150184
def body_fn(carry, x):

tests/link/jax/test_scan.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,9 @@ def test_sit_sot():
3838
assert np.allclose(fn(1.0), jax_fn(1.0))
3939

4040

41+
@pytest.mark.xfail(
42+
reason="Returns correct results but raises exception due to stucture of shared variable."
43+
)
4144
def test_nit_sot_shared():
4245
res, updates = scan(
4346
fn=lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
@@ -52,6 +55,20 @@ def test_nit_sot_shared():
5255
print(fn())
5356

5457

58+
def test_mit_sot():
59+
res, updates = scan(
60+
fn=lambda a_tm1: 2 * a_tm1,
61+
outputs_info=[
62+
{"initial": at.as_tensor([0.0, 1.0], dtype="floatX"), "taps": [-2]}
63+
],
64+
n_steps=6,
65+
)
66+
67+
jax_fn = function((), res, updates=updates, mode="JAX")
68+
fn = function((), res, updates=updates)
69+
assert np.allclose(fn(), jax_fn())
70+
71+
5572
@pytest.mark.parametrize(
5673
"fn, sequences, outputs_info, non_sequences, n_steps, input_vals, output_vals, op_check",
5774
[
@@ -96,7 +113,6 @@ def test_nit_sot_shared():
96113
# [],
97114
# 3,
98115
# [],
99-
# None,
100116
# lambda op: op.info.n_sit_sot > 0,
101117
# ),
102118
# # sit-sot, while

0 commit comments

Comments
 (0)