11from collections import defaultdict
2+ from typing import Callable , Dict , List
23
34import jax
45
56from aesara .link .jax .dispatch .basic import jax_funcify
67from aesara .scan .op import Scan
8+ from aesara .tensor .var import TensorVariable
79
810
911@jax_funcify .register (Scan )
1012def jax_funcify_Scan (op , node , ** kwargs ):
11- scan_inner_func = jax_funcify (op .fgraph )
12- if op .info .as_while :
13- raise NotImplementedError ("While loops are not supported in the JAX backend." )
14-
15- sit_sot_input_taps = op .info .sit_sot_in_slices
16- mit_sot_input_taps = op .info .mit_sot_in_slices
17-
18- # Construct `scan_inner_func`'s arguments from the carry value and sequence
19- # element passed to `body_fn`.
20- #
21- # We need to index the storage arrays carried by `jax.lax.scan` for arguments with
22- # inputs taps.
23- def index_carry_arrays (input_taps ):
24- """Fetch the inner inputs from the values stored in the carry array"""
25- # TODO: Check and refactor this
26- storage_size = - min (input_taps )
27- offsets = [storage_size + tap for tap in input_taps ]
28-
29- def to_inner_inputs (step , carry ):
30- return [carry [step + offset ] for offset in offsets ]
31-
32- return to_inner_inputs
33-
34- sit_sot_from_carry = [index_carry_arrays (tap ) for tap in sit_sot_input_taps ]
35- mit_sot_from_carry = [index_carry_arrays (tap ) for tap in mit_sot_input_taps ]
36-
37- # Construct the new carry values from the outputs of `scan_inner_func`
38- def inner_outputs_to_carry (input_taps ):
39- """Create the new carry array from the inner output"""
40- # TODO: Check and refactor this
41- def to_new_carry (step , carry , inner_outputs ):
42- return [carry .at [step - tap ].set (inner_outputs ) for tap in input_taps ]
43-
44- return to_new_carry
45-
46- sit_sot_to_carry = [inner_outputs_to_carry (tap ) for tap in sit_sot_input_taps ]
47- mit_sot_to_carry = [inner_outputs_to_carry (tap ) for tap in mit_sot_input_taps ]
48-
49- def scan (* outer_inputs ):
50-
51- # Inputs to `aesara.scan`
13+ scan_inner_fn = jax_funcify (op .fgraph )
14+ input_taps = {
15+ "mit_sot" : op .info .mit_sot_in_slices ,
16+ "sit_sot" : op .info .sit_sot_in_slices ,
17+ }
18+
19+ # Outer-inputs are the inputs to the `Scan` apply node, built from the
20+ # variables provided by the caller to the `scan` function at construction
21+ # time.
22+ def parse_outer_inputs (outer_inputs ):
5223 outer_in = {
5324 "n_steps" : outer_inputs [0 ],
5425 "sequences" : list (op .outer_seqs (outer_inputs )),
@@ -62,22 +33,125 @@ def scan(*outer_inputs):
6233 if len (outer_in ["mit_mot" ]) > 0 :
6334 raise NotImplementedError ("mit-mot not supported" )
6435
65- # Inputs to `jax.lax.scan`
36+ return outer_in
37+
38+ if op .info .as_while :
39+ raise NotImplementedError ("While loops are not supported in the JAX backend." )
40+ else :
41+ return make_jax_scan_fn (
42+ scan_inner_fn ,
43+ parse_outer_inputs ,
44+ input_taps ,
45+ )
46+
47+
48+ def make_jax_scan_fn (
49+ scan_inner_fn : Callable ,
50+ parse_outer_inputs : Callable [[TensorVariable ], Dict [str , List [TensorVariable ]]],
51+ input_taps : Dict ,
52+ ):
53+ """Create a `jax.lax.scan` function to perform `Scan` computations.
54+
55+ `jax.lax.scan` takes an initial `carry` value and a sequence it scans over,
56+ or a number of iterations. The first output of the loop body function, the
57+ `carry`, is carried over to the next iteration. The second, the `output`, is
58+ stacked to the previous outputs. We use this to our advantage to build
59+ `Scan` outputs without having to post-process the storage arrays.
60+
61+ The JAX scan function needs to perform the following operations:
62+ 1. Extract the inner-inputs;
63+ 2. Build the initial `carry` and `sequence` values;
64+ 3. Inside the loop:
65+ 1. `carry` + sequence elements -> inner-inputs
66+ 2. inner-outputs -> `carry`
67+ 3. inner-outputs -> `output`
68+ 4. Append the last `shared` value to the stacked `output`s
69+
70+ """
71+
def build_jax_scan_inputs(outer_in: Dict):
    """Build the inputs to `jax.lax.scan` from the outer-inputs.

    Parameters
    ----------
    outer_in
        Mapping of the parsed outer-inputs. The keys ``"n_steps"``,
        ``"sequences"``, ``"mit_sot"``, ``"sit_sot"``, ``"shared"`` and
        ``"non_sequences"`` are read; a missing key raises `KeyError`.

    Returns
    -------
    tuple
        ``(n_steps, sequences, init_carry)`` where ``init_carry`` is a new
        dict holding the ``mit_sot``, ``sit_sot``, ``shared`` and
        ``non_sequences`` entries of `outer_in` plus a ``"step"`` counter
        initialised to ``0``. `outer_in` itself is not modified.

    """
    # Non-sequences travel inside the carry so that the loop body can read
    # them from its `carry` argument like every other inner-input.
    init_carry = {
        "mit_sot": outer_in["mit_sot"],
        "sit_sot": outer_in["sit_sot"],
        "shared": outer_in["shared"],
        "non_sequences": outer_in["non_sequences"],
        # Iteration counter, used to index the tap storage arrays.
        "step": 0,
    }
    return outer_in["n_steps"], outer_in["sequences"], init_carry
82+
def build_inner_outputs_map(outer_in):
    """Map each inner-output group to its positions in the inner function's output tuple.

    Inner-outputs are ordered as follows:
        - mit-mot-outputs
        - mit-sot-outputs
        - sit-sot-outputs
        - nit-sots (no carry)
        - shared-outputs
        [+ while-condition]

    Only the ``mit_sot``, ``sit_sot``, ``nit_sot`` and ``shared`` groups are
    indexed here, and positions are counted from ``0`` starting at the first
    mit-sot output.

    Parameters
    ----------
    outer_in
        Mapping of the parsed outer-inputs; the length of each group's entry
        gives the number of inner-outputs of that group.

    Returns
    -------
    collections.defaultdict
        Maps a group name to the list of its output positions. A group with
        no outputs gets no key, so ``name in result`` tells whether the group
        has any output.

    """
    group_names = ("mit_sot", "sit_sot", "nit_sot", "shared")

    inner_output_idx = defaultdict(list)
    offset = 0
    for name in group_names:
        num_outputs = len(outer_in[name])
        # Guard so that empty groups do not create a key in the defaultdict.
        if num_outputs:
            inner_output_idx[name].extend(range(offset, offset + num_outputs))
        offset += num_outputs

    return inner_output_idx
106+
def from_carry_storage(carry, step, input_taps):
    """Fetch the inner-inputs from the values stored in the carry arrays.

    `Scan` passes storage arrays as inputs, which are then read from and
    updated in the loop body. At each step we need to read from these arrays
    the inputs that will be passed to the inner function.

    This mechanism is necessary because we handle multiple-input taps within
    the `scan` instead of letting users manage the memory in the use cases
    where this is necessary.

    Parameters
    ----------
    carry
        Sequence of storage arrays, one per entry of `input_taps`.
    step
        Current iteration counter.
    input_taps
        Sequence of tap collections, paired position-wise with `carry`.
        Each collection must be non-empty (``min`` is taken over it).

    Returns
    -------
    list
        A flat list with one value per tap, grouped by storage array in the
        order of `carry` and, within a group, in the order of its taps.

    """
    inner_inputs = []
    for taps, carry_element in zip(input_taps, carry):
        # The most negative tap determines how many initial values precede
        # the first step in the storage array; shifting by that amount maps
        # the most negative tap to index `step`.
        storage_size = -min(taps)
        inner_inputs.extend(carry_element[step + storage_size + tap] for tap in taps)

    return inner_inputs
132+
def to_carry_storage(inner_outputs, carry, step, input_taps):
    """Create the new carry arrays from the inner-outputs.

    `Scan` passes storage arrays as inputs, which are then read from and
    updated in the loop body. At each step we need to update these arrays
    with the outputs of the inner function.

    Parameters
    ----------
    inner_outputs
        Sequence of values returned by the inner function, one per storage
        array.
    carry
        Sequence of storage arrays supporting the functional
        ``array.at[index].set(value)`` update.
    step
        Current iteration counter.
    input_taps
        Sequence of tap collections, paired position-wise with `carry` and
        `inner_outputs`.

    Returns
    -------
    list
        A flat list containing, for every storage array and every one of its
        taps, a copy of that array with `output` written at index
        ``step - tap``.

    """
    # NOTE(review): one updated array is produced *per tap*, so an entry with
    # several taps yields more arrays than `carry` holds -- confirm that this
    # is the intended carry layout for multiple-tap (mit-sot) outputs.
    return [
        carry_element.at[step - tap].set(output)
        for taps, carry_element, output in zip(input_taps, carry, inner_outputs)
        for tap in taps
    ]
148+
149+ def scan (* outer_inputs ):
150+
151+ outer_in = parse_outer_inputs (outer_inputs )
152+ n_steps , sequences , init_carry = build_jax_scan_inputs (outer_in )
153+ inner_output_idx = build_inner_outputs_map (outer_in )
154+
81155 def scan_inner_in_args (carry , x ):
82156 """Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
83157
@@ -90,54 +164,39 @@ def scan_inner_in_args(carry, x):
90164 - non-sequences
91165
92166 """
93- step = carry ["step" ]
167+ current_step = carry ["step" ]
94168
95169 inner_in_seqs = x
96- inner_in_mit_sot = sum (
97- [
98- convert (step , carry_element )
99- for convert , carry_element in zip (
100- mit_sot_from_carry , carry ["mit_sot" ]
101- )
102- ],
103- [],
170+ inner_in_mit_sot = from_carry_storage (
171+ carry ["mit_sot" ], current_step , input_taps ["mit_sot" ]
104172 )
105- inner_in_sit_sot = sum (
106- [
107- convert (step , carry_element )
108- for convert , carry_element in zip (
109- sit_sot_from_carry , carry ["sit_sot" ]
110- )
111- ],
112- [],
173+ inner_in_sit_sot = from_carry_storage (
174+ carry ["sit_sot" ], current_step , input_taps ["sit_sot" ]
113175 )
114176 inner_in_shared = carry .get ("shared" , [])
177+ inner_in_non_sequences = carry .get ("non_sequences" , [])
115178
116179 return sum (
117180 [
118181 inner_in_seqs ,
119182 inner_in_mit_sot ,
120183 inner_in_sit_sot ,
121184 inner_in_shared ,
122- non_sequences ,
185+ inner_in_non_sequences ,
123186 ],
124187 [],
125188 )
126189
127190 def scan_new_carry (carry , inner_outputs ):
128- """Create a new carry value from the inner-outputs.
129-
130- Inner-outputs are ordered as follow:
131- - mit-mot-outputs
132- - mit-sot-outputs
133- - sit-sot-outputs
134- - nit-sots (no carry)
135- - shared-outputs
136- [+ while-condition]
137-
138- """
191+ """Create a new carry value from the values returned by the inner function (inner-outputs)."""
139192 step = carry ["step" ]
140- new_carry = {"mit_sot" : [], "sit_sot" : [], "shared" : []}
193+ new_carry = {
194+ "mit_sot" : [],
195+ "sit_sot" : [],
196+ "shared" : [],
197+ "step" : step + 1 ,
198+ "non_sequences" : carry ["non_sequences" ],
199+ }
141200
142201 if "shared" in inner_output_idx :
143202 shared_inner_outputs = [
@@ -149,36 +208,22 @@ def scan_new_carry(carry, inner_outputs):
149208 mit_sot_inner_outputs = [
150209 inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]
151210 ]
152- new_carry ["mit_sot" ] = sum (
153- [
154- convert (step , carry_element , inner_outputs_element )
155- for (convert , carry_element , inner_outputs_element ) in zip (
156- mit_sot_to_carry , carry ["mit_sot" ], mit_sot_inner_outputs
157- )
158- ],
159- [],
211+ new_carry ["mit_sot" ] = to_carry_storage (
212+ mit_sot_inner_outputs , carry ["mit_sot" ], step , input_taps ["mit_sot" ]
160213 )
161214
162215 if "sit_sot" in inner_output_idx :
163216 sit_sot_inner_outputs = [
164217 inner_outputs [idx ] for idx in inner_output_idx ["sit_sot" ]
165218 ]
166- new_carry ["sit_sot" ] = sum (
167- [
168- convert (step , carry_element , inner_outputs_element )
169- for (convert , carry_element , inner_outputs_element ) in zip (
170- sit_sot_to_carry , carry ["sit_sot" ], sit_sot_inner_outputs
171- )
172- ],
173- [],
219+ new_carry ["sit_sot" ] = to_carry_storage (
220+ sit_sot_inner_outputs , carry ["sit_sot" ], step , input_taps ["sit_sot" ]
174221 )
175222
176- new_carry ["step" ] = carry ["step" ] + 1
177-
178223 return new_carry
179224
180225 def scan_new_outputs (inner_outputs ):
181- """Create a new outer-output value from the inner-output value .
226+ """Create a new outer-output value from the outputs of the inner function .
182227
183228 Outer-outputs are ordered as follows:
184229 - mit-mot-outputs
@@ -210,7 +255,7 @@ def scan_new_outputs(inner_outputs):
210255
211256 def body_fn (carry , x ):
212257 inner_in_args = scan_inner_in_args (carry , x )
213- inner_outputs = scan_inner_func (* inner_in_args )
258+ inner_outputs = scan_inner_fn (* inner_in_args )
214259 new_carry = scan_new_carry (carry , inner_outputs )
215260 outer_outputs = scan_new_outputs (inner_outputs )
216261 return new_carry , outer_outputs
0 commit comments