Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 0932c8e

Browse files
committed
Support mit-mots in the JAX backend
1 parent 2fd8122 commit 0932c8e

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def assert_while_returns_last_output(fgraph, node):
5757
def jax_funcify_Scan(op, node, **kwargs):
5858
scan_inner_fn = jax_funcify(op.fgraph)
5959
input_taps = {
60+
"mit_mot": op.info.mit_mot_in_slices,
6061
"mit_sot": op.info.mit_sot_in_slices,
6162
"sit_sot": op.info.sit_sot_in_slices,
6263
"nit_sot": op.info.sit_sot_in_slices,
@@ -76,9 +77,6 @@ def parse_outer_inputs(outer_inputs):
7677
"shared": list(op.outer_shared(outer_inputs)),
7778
"non_sequences": list(op.outer_non_seqs(outer_inputs)),
7879
}
79-
if len(outer_in["mit_mot"]) > 0:
80-
raise NotImplementedError("mit-mot not supported")
81-
8280
return outer_in
8381

8482
if op.info.as_while:
@@ -364,7 +362,7 @@ def build_jax_scan_inputs(outer_in: Dict):
364362
sequences = outer_in["sequences"]
365363
init_carry = {
366364
name: outer_in[name]
367-
for name in ["mit_sot", "sit_sot", "shared", "non_sequences"]
365+
for name in ["mit_mot", "mit_sot", "sit_sot", "shared", "non_sequences"]
368366
}
369367
init_carry["step"] = 0
370368
return n_steps, sequences, init_carry
@@ -381,7 +379,7 @@ def build_inner_outputs_map(outer_in):
381379
[+ while-condition]
382380
383381
"""
384-
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
382+
inner_outputs_names = ["mit_mot", "mit_sot", "sit_sot", "nit_sot", "shared"]
385383

386384
offset = 0
387385
inner_output_idx = defaultdict(list)
@@ -456,6 +454,9 @@ def scan_inner_in_args(carry, x):
456454
current_step = carry["step"]
457455

458456
inner_in_seqs = x
457+
inner_in_mit_mot = from_carry_storage(
458+
carry["mit_mot"], current_step, input_taps["mit_mot"]
459+
)
459460
inner_in_mit_sot = from_carry_storage(
460461
carry["mit_sot"], current_step, input_taps["mit_sot"]
461462
)
@@ -468,6 +469,7 @@ def scan_inner_in_args(carry, x):
468469
return sum(
469470
[
470471
inner_in_seqs,
472+
inner_in_mit_mot,
471473
inner_in_mit_sot,
472474
inner_in_sit_sot,
473475
inner_in_shared,
@@ -480,6 +482,7 @@ def scan_new_carry(carry, inner_outputs):
480482
"""Create a new carry value from the values returned by the inner function (inner-outputs)."""
481483
step = carry["step"]
482484
new_carry = {
485+
"mit_mot": [],
483486
"mit_sot": [],
484487
"sit_sot": [],
485488
"shared": [],
@@ -493,6 +496,14 @@ def scan_new_carry(carry, inner_outputs):
493496
]
494497
new_carry["shared"] = shared_inner_outputs
495498

499+
if "mit_mot" in inner_output_idx:
500+
mit_mot_inner_outputs = [
501+
inner_outputs[idx] for idx in inner_output_idx["mit_mot"]
502+
]
503+
new_carry["mit_mot"] = to_carry_storage(
504+
mit_mot_inner_outputs, carry["mit_mot"], step, input_taps["mit_mot"]
505+
)
506+
496507
if "mit_sot" in inner_output_idx:
497508
mit_sot_inner_outputs = [
498509
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
@@ -527,6 +538,10 @@ def scan_new_outputs(inner_outputs):
527538
528539
"""
529540
outer_outputs = []
541+
if "mit_mot" in inner_output_idx:
542+
outer_outputs.append(
543+
[inner_outputs[idx] for idx in inner_output_idx["mit_mot"]]
544+
)
530545
if "mit_sot" in inner_output_idx:
531546
outer_outputs.append(
532547
[inner_outputs[idx] for idx in inner_output_idx["mit_sot"]]

tests/link/jax/test_scan.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from packaging.version import parse as version_parse
44

55
import aesara.tensor as at
6-
from aesara import function
6+
from aesara import function, grad
77
from aesara.compile.mode import Mode
88
from aesara.configdefaults import config
99
from aesara.graph.fg import FunctionGraph
@@ -27,11 +27,10 @@
2727

2828

2929
def test_while_cannnot_use_all_outputs():
30-
"""The JAX backend cannot return all the outputs of a while loop.
30+
"""The JAX backend cannot use all the outputs of a while loop.
3131
3232
Indeed, JAX has fundamental limitations that prevent it from returning
3333
all the intermediate results computed in a `jax.lax.while_loop` loop.
34-
3534
"""
3635
res, updates = scan(
3736
fn=lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
@@ -233,6 +232,16 @@ def test_sequence_opt():
233232
# [],
234233
# 3,
235234
# [],
235+
# lambda op: op.info.n_sit_sot > 0,
236+
# ),
237+
# # sit-sot, while
238+
# (
239+
# lambda a_tm1: (a_tm1 + 1, until(a_tm1 > 2)),
240+
# [],
241+
# [{"initial": at.as_tensor(1, dtype=np.int64), "taps": [-1]}],
242+
# [],
243+
# 3,
244+
# [],
236245
# None,
237246
# lambda op: op.info.n_sit_sot > 0,
238247
# ),
@@ -478,3 +487,30 @@ def power_step(prior_result, x):
478487

479488
for output_jax, output in zip(jax_res, res):
480489
assert np.allclose(jax_res, res)
490+
491+
492+
@pytest.mark.xfail(reason="Fails for reasons unrelated to `Scan`")
def test_mitmots_basic():
    """Compare the JAX and default backends on the gradient of a two-tap `Scan`.

    A `Scan` whose single output uses the taps ``[-2, -1]`` is built, and the
    gradient of the summed output is taken with respect to both the sequence
    and the initial state.  Judging by the test name, that gradient graph is
    meant to exercise mit-mot support in the JAX backend.

    The gradient graph is compiled twice, once with the default mode and once
    with ``mode="JAX"``, and the two sets of results must agree.
    """
    init_x = at.dvector()
    seq = at.dvector()

    def inner_fct(seq, state_old, state_current):
        # `state_old` is the tap at -2 and `state_current` the tap at -1.
        return state_old * 2 + state_current + seq

    out, _ = scan(
        inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
    )

    g_outs = grad(out.sum(), [seq, init_x])

    out_fg = FunctionGraph([seq, init_x], g_outs)

    seq_val = np.arange(3)
    init_x_val = np.r_[-2, -1]

    fn = function(out_fg.inputs, out_fg.outputs)
    jax_fn = function(out_fg.inputs, out_fg.outputs, mode="JAX")

    expected = fn(seq_val, init_x_val)
    actual = jax_fn(seq_val, init_x_val)

    # Compare output by output instead of printing, so that the test actually
    # checks something once the unrelated failure is resolved.
    assert len(actual) == len(expected)
    for jax_out, py_out in zip(actual, expected):
        np.testing.assert_allclose(jax_out, py_out)

0 commit comments

Comments
 (0)