1+ from collections import defaultdict
2+
13import jax
24
35from aesara .link .jax .dispatch .basic import jax_funcify
@@ -10,7 +12,8 @@ def jax_funcify_Scan(op, node, **kwargs):
1012 if op .info .as_while :
1113 raise NotImplementedError ("While loops are not supported in the JAX backend." )
1214
13- input_taps = op .info .sit_sot_in_slices
15+ sit_sot_input_taps = op .info .sit_sot_in_slices
16+ mit_sot_input_taps = op .info .mit_sot_in_slices
1417
1518 # Construct `scan_inner_func`'s arguments from the carry value and sequence
1619 # element passed to `body_fn`.
@@ -19,69 +22,65 @@ def jax_funcify_Scan(op, node, **kwargs):
1922 # inputs taps.
2023 def index_carry_arrays (input_taps ):
2124 """Fetch the inner inputs from the values stored in the carry array"""
25+ # TODO: Check and refactor this
2226 storage_size = - min (input_taps )
2327 offsets = [storage_size + tap for tap in input_taps ]
2428
25- def to_inner_inputs (carry ):
26- return [carry [offset ] for offset in offsets ]
29+ def to_inner_inputs (step , carry ):
30+ return [carry [step + offset ] for offset in offsets ]
2731
2832 return to_inner_inputs
2933
30- sit_sot_from_carry = [index_carry_arrays (tap ) for tap in input_taps ]
34+ sit_sot_from_carry = [index_carry_arrays (tap ) for tap in sit_sot_input_taps ]
35+ mit_sot_from_carry = [index_carry_arrays (tap ) for tap in mit_sot_input_taps ]
3136
3237 # Construct the new carry values from the outputs of `scan_inner_func`
3338 def inner_outputs_to_carry (input_taps ):
3439 """Create the new carry array from the inner output"""
35- storage_size = - min (input_taps )
36- offsets = [storage_size + tap for tap in input_taps ]
37-
38- def to_new_carry (carry , inner_outputs ):
39- return [carry .at [offset ].set (inner_outputs ) for offset in offsets ]
40+ # TODO: Check and refactor this
41+ def to_new_carry (step , carry , inner_outputs ):
42+ return [carry .at [step - tap ].set (inner_outputs ) for tap in input_taps ]
4043
4144 return to_new_carry
4245
43- sit_sot_to_carry = [inner_outputs_to_carry (tap ) for tap in input_taps ]
46+ sit_sot_to_carry = [inner_outputs_to_carry (tap ) for tap in sit_sot_input_taps ]
47+ mit_sot_to_carry = [inner_outputs_to_carry (tap ) for tap in mit_sot_input_taps ]
4448
4549 def scan (* outer_inputs ):
4650
47- n_steps = outer_inputs [0 ]
48- outer_in_seqs = list (op .outer_seqs (outer_inputs ))
49- outer_in_mit_mot = list (op .outer_mitmot (outer_inputs ))
50- outer_in_mit_sot = list (op .outer_mitsot (outer_inputs ))
51- outer_in_nit_sot = list (op .outer_nitsot (outer_inputs ))
52- outer_in_sit_sot = list (op .outer_sitsot (outer_inputs ))
53- outer_in_shared = list (op .outer_shared (outer_inputs ))
54- outer_in_non_seqs = list (op .outer_non_seqs (outer_inputs ))
55- if len (outer_in_mit_mot ):
51+ # Inputs to `aesara.scan`
52+ outer_in = {
53+ "n_steps" : outer_inputs [0 ],
54+ "sequences" : list (op .outer_seqs (outer_inputs )),
55+ "mit_mot" : list (op .outer_mitmot (outer_inputs )),
56+ "mit_sot" : list (op .outer_mitsot (outer_inputs )),
57+ "nit_sot" : list (op .outer_nitsot (outer_inputs )),
58+ "sit_sot" : list (op .outer_sitsot (outer_inputs )),
59+ "shared" : list (op .outer_shared (outer_inputs )),
60+ "non_sequences" : list (op .outer_non_seqs (outer_inputs )),
61+ }
62+ if len (outer_in ["mit_mot" ]) > 0 :
5663 raise NotImplementedError ("mit-mot not supported" )
57- if len (outer_in_mit_sot ):
58- raise NotImplementedError ("mit-sot not supported" )
59-
60- # These are the outer-inputs
61- sequences = outer_in_seqs
62- non_sequences = outer_in_non_seqs
63- init_carry = {}
64- for name , outputs in [
65- ("sit_sot" , outer_in_sit_sot ),
66- ("shared" , outer_in_shared ),
67- ]:
68- if len (outputs ) > 0 :
69- init_carry [name ] = outputs
70-
71- # We keep track of the kind of inner_outputs and their number
72- from collections import defaultdict
7364
65+ # Inputs to `jax.lax.scan`
66+ n_steps = outer_in ["n_steps" ]
67+ sequences = outer_in ["sequences" ]
68+ non_sequences = outer_in ["non_sequences" ]
69+ init_carry = {
70+ name : outer_in [name ]
71+ for name in ["mit_sot" , "sit_sot" , "shared" ]
72+ if len (outer_in [name ]) > 0
73+ }
74+ init_carry ["step" ] = 0
75+
76+ # Map to retrieve the inner-outputs
7477 offset = 0
7578 inner_output_idx = defaultdict (list )
76- for name , outputs in [
77- ("sit_sot" , outer_in_sit_sot ),
78- ("nit_sot" , outer_in_nit_sot ),
79- ("shared" , outer_in_shared ),
80- ]:
81- if len (outputs ) > 0 :
82- for i in range (len (outputs )):
83- inner_output_idx [name ].append (offset + i )
84- offset += len (outputs )
79+ for name in ["mit_sot" , "sit_sot" , "nit_sot" , "shared" ]:
80+ num_outputs = len (outer_in [name ])
81+ for i in range (num_outputs ):
82+ inner_output_idx [name ].append (offset + i )
83+ offset += num_outputs
8584
8685 def scan_inner_in_args (carry , x ):
8786 """Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
@@ -95,10 +94,20 @@ def scan_inner_in_args(carry, x):
9594 - non-sequences
9695
9796 """
97+ step = carry ["step" ]
98+
9899 inner_in_seqs = x
100+ inner_in_mit_sot = sum (
101+ [
102+ convert (step , carry_element )
103+ for convert in mit_sot_from_carry
104+ for carry_element in carry ["mit_sot" ]
105+ ],
106+ [],
107+ )
99108 inner_in_sit_sot = sum (
100109 [
101- convert (carry_element )
110+ convert (step , carry_element )
102111 for convert in sit_sot_from_carry
103112 for carry_element in carry ["sit_sot" ]
104113 ],
@@ -107,7 +116,14 @@ def scan_inner_in_args(carry, x):
107116 inner_in_shared = carry .get ("shared" , [])
108117
109118 return sum (
110- [inner_in_seqs , inner_in_sit_sot , inner_in_shared , non_sequences ], []
119+ [
120+ inner_in_seqs ,
121+ inner_in_mit_sot ,
122+ inner_in_sit_sot ,
123+ inner_in_shared ,
124+ non_sequences ,
125+ ],
126+ [],
111127 )
112128
113129 def scan_new_carry (carry , inner_outputs ):
@@ -123,20 +139,36 @@ def scan_new_carry(carry, inner_outputs):
123139
124140 """
125141 new_carry = {}
142+ step = carry ["step" ]
126143
127144 if "shared" in inner_output_idx :
128145 shared_inner_outputs = [
129146 inner_outputs [idx ] for idx in inner_output_idx ["shared" ]
130147 ]
131148 new_carry ["shared" ] = shared_inner_outputs
132149
150+ if "mit_sot" in inner_output_idx :
151+ mit_sot_inner_outputs = [
152+ inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]
153+ ]
154+ new_carry ["mit_sot" ] = sum (
155+ [
156+ convert (step , carry_element , inner_outputs_element )
157+ for convert in mit_sot_to_carry
158+ for (carry_element , inner_outputs_element ) in zip (
159+ carry ["mit_sot" ], mit_sot_inner_outputs
160+ )
161+ ],
162+ [],
163+ )
164+
133165 if "sit_sot" in inner_output_idx :
134166 sit_sot_inner_outputs = [
135167 inner_outputs [idx ] for idx in inner_output_idx ["sit_sot" ]
136168 ]
137169 new_carry ["sit_sot" ] = sum (
138170 [
139- convert (carry_element , inner_outputs_element )
171+ convert (step , carry_element , inner_outputs_element )
140172 for convert in sit_sot_to_carry
141173 for (carry_element , inner_outputs_element ) in zip (
142174 carry ["sit_sot" ], sit_sot_inner_outputs
@@ -145,6 +177,8 @@ def scan_new_carry(carry, inner_outputs):
145177 [],
146178 )
147179
180+ new_carry ["step" ] = carry ["step" ] + 1
181+
148182 return new_carry
149183
150184 def body_fn (carry , x ):
0 commit comments