@@ -57,6 +57,7 @@ def assert_while_returns_last_output(fgraph, node):
5757def jax_funcify_Scan (op , node , ** kwargs ):
5858 scan_inner_fn = jax_funcify (op .fgraph )
5959 input_taps = {
60+ "mit_mot" : op .info .mit_mot_in_slices ,
6061 "mit_sot" : op .info .mit_sot_in_slices ,
6162 "sit_sot" : op .info .sit_sot_in_slices ,
6263 "nit_sot" : op .info .sit_sot_in_slices ,
@@ -76,9 +77,6 @@ def parse_outer_inputs(outer_inputs):
7677 "shared" : list (op .outer_shared (outer_inputs )),
7778 "non_sequences" : list (op .outer_non_seqs (outer_inputs )),
7879 }
79- if len (outer_in ["mit_mot" ]) > 0 :
80- raise NotImplementedError ("mit-mot not supported" )
81-
8280 return outer_in
8381
8482 if op .info .as_while :
@@ -364,7 +362,7 @@ def build_jax_scan_inputs(outer_in: Dict):
364362 sequences = outer_in ["sequences" ]
365363 init_carry = {
366364 name : outer_in [name ]
367- for name in ["mit_sot" , "sit_sot" , "shared" , "non_sequences" ]
365+ for name in ["mit_mot" , " mit_sot" , "sit_sot" , "shared" , "non_sequences" ]
368366 }
369367 init_carry ["step" ] = 0
370368 return n_steps , sequences , init_carry
@@ -381,7 +379,7 @@ def build_inner_outputs_map(outer_in):
381379 [+ while-condition]
382380
383381 """
384- inner_outputs_names = ["mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
382+ inner_outputs_names = ["mit_mot" , " mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
385383
386384 offset = 0
387385 inner_output_idx = defaultdict (list )
@@ -456,6 +454,9 @@ def scan_inner_in_args(carry, x):
456454 current_step = carry ["step" ]
457455
458456 inner_in_seqs = x
457+ inner_in_mit_mot = from_carry_storage (
458+ carry ["mit_mot" ], current_step , input_taps ["mit_mot" ]
459+ )
459460 inner_in_mit_sot = from_carry_storage (
460461 carry ["mit_sot" ], current_step , input_taps ["mit_sot" ]
461462 )
@@ -468,6 +469,7 @@ def scan_inner_in_args(carry, x):
468469 return sum (
469470 [
470471 inner_in_seqs ,
472+ inner_in_mit_mot ,
471473 inner_in_mit_sot ,
472474 inner_in_sit_sot ,
473475 inner_in_shared ,
@@ -480,6 +482,7 @@ def scan_new_carry(carry, inner_outputs):
480482 """Create a new carry value from the values returned by the inner function (inner-outputs)."""
481483 step = carry ["step" ]
482484 new_carry = {
485+ "mit_mot" : [],
483486 "mit_sot" : [],
484487 "sit_sot" : [],
485488 "shared" : [],
@@ -493,6 +496,14 @@ def scan_new_carry(carry, inner_outputs):
493496 ]
494497 new_carry ["shared" ] = shared_inner_outputs
495498
499+ if "mit_mot" in inner_output_idx :
500+ mit_mot_inner_outputs = [
501+ inner_outputs [idx ] for idx in inner_output_idx ["mit_mot" ]
502+ ]
503+ new_carry ["mit_mot" ] = to_carry_storage (
504+ mit_mot_inner_outputs , carry ["mit_mot" ], step , input_taps ["mit_mot" ]
505+ )
506+
496507 if "mit_sot" in inner_output_idx :
497508 mit_sot_inner_outputs = [
498509 inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]
@@ -527,6 +538,10 @@ def scan_new_outputs(inner_outputs):
527538
528539 """
529540 outer_outputs = []
541+ if "mit_mot" in inner_output_idx :
542+ outer_outputs .append (
543+ [inner_outputs [idx ] for idx in inner_output_idx ["mit_mot" ]]
544+ )
530545 if "mit_sot" in inner_output_idx :
531546 outer_outputs .append (
532547 [inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]]
0 commit comments