Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit e528e44

Browse files
committed
Support mit-mots in the JAX backend
1 parent 988f904 commit e528e44

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

aesara/link/jax/dispatch/scan.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def assert_while_returns_last_output(fgraph, node):
5757
def jax_funcify_Scan(op, node, **kwargs):
5858
scan_inner_fn = jax_funcify(op.fgraph)
5959
input_taps = {
60+
"mit_mot": op.info.mit_mot_in_slices,
6061
"mit_sot": op.info.mit_sot_in_slices,
6162
"sit_sot": op.info.sit_sot_in_slices,
6263
"nit_sot": op.info.sit_sot_in_slices,
@@ -76,9 +77,6 @@ def parse_outer_inputs(outer_inputs):
7677
"shared": list(op.outer_shared(outer_inputs)),
7778
"non_sequences": list(op.outer_non_seqs(outer_inputs)),
7879
}
79-
if len(outer_in["mit_mot"]) > 0:
80-
raise NotImplementedError("mit-mot not supported")
81-
8280
return outer_in
8381

8482
if op.info.as_while:
@@ -364,7 +362,7 @@ def build_jax_scan_inputs(outer_in: Dict):
364362
sequences = outer_in["sequences"]
365363
init_carry = {
366364
name: outer_in[name]
367-
for name in ["mit_sot", "sit_sot", "shared", "non_sequences"]
365+
for name in ["mit_mot", "mit_sot", "sit_sot", "shared", "non_sequences"]
368366
}
369367
init_carry["step"] = 0
370368
return n_steps, sequences, init_carry
@@ -381,7 +379,7 @@ def build_inner_outputs_map(outer_in):
381379
[+ while-condition]
382380
383381
"""
384-
inner_outputs_names = ["mit_sot", "sit_sot", "nit_sot", "shared"]
382+
inner_outputs_names = ["mit_mot", "mit_sot", "sit_sot", "nit_sot", "shared"]
385383

386384
offset = 0
387385
inner_output_idx = defaultdict(list)
@@ -456,6 +454,9 @@ def scan_inner_in_args(carry, x):
456454
current_step = carry["step"]
457455

458456
inner_in_seqs = x
457+
inner_in_mit_mot = from_carry_storage(
458+
carry["mit_mot"], current_step, input_taps["mit_mot"]
459+
)
459460
inner_in_mit_sot = from_carry_storage(
460461
carry["mit_sot"], current_step, input_taps["mit_sot"]
461462
)
@@ -468,6 +469,7 @@ def scan_inner_in_args(carry, x):
468469
return sum(
469470
[
470471
inner_in_seqs,
472+
inner_in_mit_mot,
471473
inner_in_mit_sot,
472474
inner_in_sit_sot,
473475
inner_in_shared,
@@ -480,6 +482,7 @@ def scan_new_carry(carry, inner_outputs):
480482
"""Create a new carry value from the values returned by the inner function (inner-outputs)."""
481483
step = carry["step"]
482484
new_carry = {
485+
"mit_mot": [],
483486
"mit_sot": [],
484487
"sit_sot": [],
485488
"shared": [],
@@ -493,6 +496,14 @@ def scan_new_carry(carry, inner_outputs):
493496
]
494497
new_carry["shared"] = shared_inner_outputs
495498

499+
if "mit_mot" in inner_output_idx:
500+
mit_mot_inner_outputs = [
501+
inner_outputs[idx] for idx in inner_output_idx["mit_mot"]
502+
]
503+
new_carry["mit_mot"] = to_carry_storage(
504+
mit_mot_inner_outputs, carry["mit_mot"], step, input_taps["mit_mot"]
505+
)
506+
496507
if "mit_sot" in inner_output_idx:
497508
mit_sot_inner_outputs = [
498509
inner_outputs[idx] for idx in inner_output_idx["mit_sot"]
@@ -527,6 +538,10 @@ def scan_new_outputs(inner_outputs):
527538
528539
"""
529540
outer_outputs = []
541+
if "mit_mot" in inner_output_idx:
542+
outer_outputs.append(
543+
[inner_outputs[idx] for idx in inner_output_idx["mit_mot"]]
544+
)
530545
if "mit_sot" in inner_output_idx:
531546
outer_outputs.append(
532547
[inner_outputs[idx] for idx in inner_output_idx["mit_sot"]]

tests/link/jax/test_scan.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from packaging.version import parse as version_parse
44

55
import aesara.tensor as at
6-
from aesara import function
6+
from aesara import function, grad
77
from aesara.compile.mode import Mode
88
from aesara.configdefaults import config
99
from aesara.graph.fg import FunctionGraph
@@ -23,6 +23,7 @@
2323
# Disable all optimizations
2424
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
2525
jax_mode = Mode(JAXLinker(), opts)
26+
py_mode = Mode("py", opts)
2627

2728

2829
def test_while_cannnot_use_all_outputs():
@@ -66,8 +67,8 @@ def test_sit_sot():
6667
n_steps=3,
6768
)
6869

69-
jax_fn = function((a_at,), res, updates=updates, mode="JAX")
7070
fn = function((a_at,), res, updates=updates)
71+
jax_fn = function((a_at,), res, updates=updates, mode="JAX")
7172
assert np.allclose(fn(1.0), jax_fn(1.0))
7273

7374

@@ -413,3 +414,29 @@ def power_step(prior_result, x):
413414

414415
for output_jax, output in zip(jax_res, res):
415416
assert np.allclose(jax_res, res)
417+
418+
419+
def test_mitmots_basic():
420+
421+
init_x = at.dvector()
422+
seq = at.dvector()
423+
424+
def inner_fct(seq, state_old, state_current):
425+
return state_old * 2 + state_current + seq
426+
427+
out, _ = scan(
428+
inner_fct, sequences=seq, outputs_info={"initial": init_x, "taps": [-2, -1]}
429+
)
430+
431+
g_outs = grad(out.sum(), [seq, init_x])
432+
433+
out_fg = FunctionGraph([seq, init_x], g_outs)
434+
435+
seq_val = np.arange(3)
436+
init_x_val = np.r_[-2, -1]
437+
(seq_val, init_x_val)
438+
439+
fn = function(out_fg.inputs, out_fg.outputs)
440+
jax_fn = function(out_fg.inputs, out_fg.outputs, mode="JAX")
441+
print(fn(seq_val, init_x_val))
442+
print(jax_fn(seq_val, init_x_val))

0 commit comments

Comments
 (0)