22from typing import Callable , Dict , List
33
44import jax
5+ import jax .numpy as jnp
56
67from aesara .link .jax .dispatch .basic import jax_funcify
78from aesara .scan .op import Scan
@@ -14,6 +15,7 @@ def jax_funcify_Scan(op, node, **kwargs):
1415 input_taps = {
1516 "mit_sot" : op .info .mit_sot_in_slices ,
1617 "sit_sot" : op .info .sit_sot_in_slices ,
18+ "nit_sot" : op .info .sit_sot_in_slices ,
1719 }
1820
1921 # Outer-inputs are the inputs to the `Scan` apply node, built from the
@@ -36,13 +38,258 @@ def parse_outer_inputs(outer_inputs):
3638 return outer_in
3739
3840 if op .info .as_while :
39- raise NotImplementedError ("While loops are not supported in the JAX backend." )
41+ # We can only compile a `Scan` node that acts as a `while` loop to JAX
42+ # if only the last computed value is ever used in the outer function.
43+ # TODO: Determine if that's the case
44+ # TODO: Rewrite the graph if that's the case
45+ # TODO: Implement a simple `while` loop that returns the last step
46+ return make_jax_while_fn (scan_inner_fn , parse_outer_inputs , input_taps )
4047 else :
41- return make_jax_scan_fn (
42- scan_inner_fn ,
43- parse_outer_inputs ,
44- input_taps ,
45- )
48+ return make_jax_scan_fn (scan_inner_fn , parse_outer_inputs , input_taps )
49+
50+
51+ def make_jax_while_fn (
52+ scan_inner_fn : Callable ,
53+ parse_outer_inputs : Callable [[TensorVariable ], Dict [str , List [TensorVariable ]]],
54+ input_taps : Dict ,
55+ ):
56+ """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
57+ is used as while loop.
58+
59+ `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
60+ must return a value of the same type (Pytree structure, shape and dtype of
61+ the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
62+ current value and returns a boolean that indicates whether to keep iterating
63+ or not.
64+
65+ The JAX `while_loop` needs to perform the following operations:
66+
67+ 1. Extract the inner-inputs;
68+ 2. Build the initial carry value;
69+ 3. Inside the loop:
70+ 1. `carry` -> inner-inputs;
71+ 2. inner-outputs -> `carry`
72+ 4. Post-process the `carry` storage and return outputs
73+ """
74+
75+ def build_while_carry (outer_in ):
76+ """Build the inputs to `jax.lax.scan` from the outer-inputs."""
77+ init_carry = {
78+ "mit_sot" : [],
79+ "mit_sot_storage" : outer_in ["mit_sot" ],
80+ "sit_sot" : [],
81+ "sit_sot_storage" : outer_in ["sit_sot" ],
82+ "shared" : outer_in ["shared" ],
83+ "sequences" : outer_in ["sequences" ],
84+ "non_sequences" : outer_in ["non_sequences" ],
85+ }
86+ init_carry ["step" ] = 0
87+ init_carry ["do_stop" ] = False
88+ return init_carry
89+
90+ def build_inner_outputs_map (outer_in ):
91+ """Map the inner-output variables to their position in the tuple returned by the inner function.
92+
93+ TODO: Copied from the scan builder
94+
95+ Inner-outputs are ordered as follow:
96+ - mit-mot-outputs
97+ - mit-sot-outputs
98+ - sit-sot-outputs
99+ - nit-sots (no carry)
100+ - shared-outputs
101+ [+ while-condition]
102+
103+ """
104+ inner_outputs_names = ["mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
105+
106+ offset = 0
107+ inner_output_idx = defaultdict (list )
108+ for name in inner_outputs_names :
109+ num_outputs = len (outer_in [name ])
110+ for i in range (num_outputs ):
111+ inner_output_idx [name ].append (offset + i )
112+ offset += num_outputs
113+
114+ return inner_output_idx
115+
116+ def from_carry_storage (carry , step , input_taps ):
117+ """Fetch the inner inputs from the values stored in the carry array.
118+
119+ `Scan` passes storage arrays as inputs, which are then read from and
120+ updated in the loop body. At each step we need to read from this array
121+ the inputs that will be passed to the inner function.
122+
123+ This mechanism is necessary because we handle multiple-input taps within
124+ the `scan` instead of letting users manage the memory in the use cases
125+ where this is necessary.
126+
127+ TODO: Copied from the scan builder
128+
129+ """
130+
131+ def fetch (carry , step , offset ):
132+ return carry [step + offset ]
133+
134+ inner_inputs = []
135+ for taps , carry_element in zip (input_taps , carry ):
136+ storage_size = - min (taps )
137+ offsets = [storage_size + tap for tap in taps ]
138+ inner_inputs .append (
139+ [fetch (carry_element , step , offset ) for offset in offsets ]
140+ )
141+
142+ return sum (inner_inputs , [])
143+
144+ def to_carry_storage (inner_outputs , carry , step , input_taps ):
145+ """Create the new carry array from the inner output
146+
147+ `Scan` passes storage arrays as inputs, which are then read from and
148+ updated in the loop body. At each step we need to update this array
149+ with the outputs of the inner function
150+
151+ TODO: Copied from the scan builder
152+
153+ """
154+ new_carry_element = []
155+ for taps , carry_element , output in zip (input_taps , carry , inner_outputs ):
156+ new_carry_element .append (
157+ [carry_element .at [step - tap ].set (output ) for tap in taps ]
158+ )
159+
160+ return sum (new_carry_element , [])
161+
162+ def while_loop (* outer_inputs ):
163+
164+ outer_in = parse_outer_inputs (outer_inputs )
165+ init_carry = build_while_carry (outer_in )
166+ inner_output_idx = build_inner_outputs_map (outer_in )
167+
168+ def inner_inputs_from_carry (carry ):
169+ """Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.
170+
171+ Inner-inputs are ordered as follows:
172+ - sequences
173+ - mit-mot inputs
174+ - mit-sot inputs
175+ - sit-sot inputs
176+ - shared-inputs
177+ - non-sequences
178+
179+ """
180+ current_step = carry ["step" ]
181+
182+ inner_in_mit_sot = from_carry_storage (
183+ carry ["mit_sot_storage" ], current_step , input_taps ["mit_sot" ]
184+ )
185+ inner_in_sit_sot = from_carry_storage (
186+ carry ["sit_sot_storage" ], current_step , input_taps ["sit_sot" ]
187+ )
188+ inner_in_shared = carry .get ("shared" , [])
189+ inner_in_non_sequences = carry .get ("non_sequences" , [])
190+
191+ return sum (
192+ [
193+ inner_in_mit_sot ,
194+ inner_in_sit_sot ,
195+ inner_in_shared ,
196+ inner_in_non_sequences ,
197+ ],
198+ [],
199+ )
200+
201+ def carry_from_inner_outputs (carry , inner_outputs ):
202+ step = carry ["step" ]
203+ new_carry = {
204+ "mit_sot" : [],
205+ "sit_sot" : [],
206+ "sit_sot_storage" : [],
207+ "nit_sot" : [],
208+ "mit_sot_storage" : [],
209+ "shared" : [],
210+ "step" : step + 1 ,
211+ "sequences" : carry ["sequences" ],
212+ "non_sequences" : carry ["non_sequences" ],
213+ "do_stop" : inner_outputs [- 1 ],
214+ }
215+
216+ if "shared" in inner_output_idx :
217+ shared_inner_outputs = [
218+ inner_outputs [idx ] for idx in inner_output_idx ["shared" ]
219+ ]
220+ new_carry ["shared" ] = shared_inner_outputs
221+
222+ if "mit_sot" in inner_output_idx :
223+ mit_sot_inner_outputs = [
224+ inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]
225+ ]
226+ new_carry ["mit_sot" ] = mit_sot_inner_outputs
227+ new_carry ["mit_sot_storage" ] = to_carry_storage (
228+ mit_sot_inner_outputs ,
229+ carry ["mit_sot_storage" ],
230+ step ,
231+ input_taps ["mit_sot" ],
232+ )
233+
234+ if "sit_sot" in inner_output_idx :
235+ sit_sot_inner_outputs = [
236+ inner_outputs [idx ] for idx in inner_output_idx ["sit_sot" ]
237+ ]
238+ new_carry ["sit_sot" ] = sit_sot_inner_outputs
239+ new_carry ["sit_sot_storage" ] = to_carry_storage (
240+ sit_sot_inner_outputs ,
241+ carry ["sit_sot_storage" ],
242+ step ,
243+ input_taps ["sit_sot" ],
244+ )
245+
246+ if "nit_sot" in inner_output_idx :
247+ nit_sot_inner_outputs = [
248+ inner_outputs [idx ] for idx in inner_output_idx ["nit_sot" ]
249+ ]
250+ new_carry ["nit_sot" ] = nit_sot_inner_outputs
251+
252+ return new_carry
253+
254+ def cond_fn (carry ):
255+ # The inner-function of `Scan` returns a boolean as the last
256+ # value. This needs to be included in `carry`.
257+ # TODO: Will it return `False` if the number of steps is exceeded?
258+ return ~ carry ["do_stop" ]
259+
260+ def body_fn (carry ):
261+ inner_inputs = inner_inputs_from_carry (carry )
262+ inner_outputs = scan_inner_fn (* inner_inputs )
263+ new_carry = carry_from_inner_outputs (carry , inner_outputs )
264+ return new_carry
265+
266+ # The `Scan` implementation in the C backend will execute the
267+ # function once before checking the termination condition, while
268+ # `jax.lax.while_loop` checks the condition first. We thus need to call
269+ # `body_fn` once before calling `jax.lax.while_loop`. This allows us,
270+ # along with `n_steps`, to build the storage array for the `nit-sot`s
271+ # since there is no way to know their shape and dtype before executing
272+ # the function.
273+ inner_inputs = inner_inputs_from_carry (init_carry )
274+ inner_outputs = scan_inner_fn (* inner_inputs )
275+ carry = carry_from_inner_outputs (init_carry , inner_outputs )
276+ carry = jax .lax .while_loop (cond_fn , body_fn , carry )
277+
278+ # Post-process the storage arrays
279+ # We make sure that the outputs are not scalars in case an array
280+ # is expected downstream since `Scan` is supposed to always return arrays
281+ carry ["sit_sot" ] = [jnp .atleast_1d (element ) for element in carry ["sit_sot" ]]
282+ carry ["mit_sot" ] = [jnp .atleast_1d (element ) for element in carry ["mit_sot" ]]
283+ carry ["nit_not" ] = [jnp .atleast_1d (element ) for element in carry ["nit_sot" ]]
284+
285+ outer_outputs = ["mit_sot" , "sit_sot" , "nit_sot" , "shared" ]
286+ results = sum ([carry [output ] for output in outer_outputs ], [])
287+ if len (results ) == 1 :
288+ return results [0 ]
289+ else :
290+ return results
291+
292+ return while_loop
46293
47294
48295def make_jax_scan_fn (
@@ -58,7 +305,8 @@ def make_jax_scan_fn(
58305 stacked to the previous outputs. We use this to our advantage to build
59306 `Scan` outputs without having to post-process the storage arrays.
60307
61- The JAX scan function needs to perform the following operations:
308+ The JAX `scan` function needs to perform the following operations:
309+
62310 1. Extract the inner-inputs;
63311 2. Build the initial `carry` and `sequence` values;
64312 3. Inside the loop:
@@ -265,11 +513,11 @@ def body_fn(carry, x):
265513 )
266514
267515 shared_output = tuple (last_carry ["shared" ])
268- results = results + shared_output
516+ outer_outputs = results + shared_output
269517
270- if len (results ) == 1 :
271- return results [0 ]
518+ if len (outer_outputs ) == 1 :
519+ return outer_outputs [0 ]
272520
273- return results
521+ return outer_outputs
274522
275523 return scan
0 commit comments