@@ -35,14 +35,232 @@ def parse_outer_inputs(outer_inputs):
3535
3636 return outer_in
3737
38+ breakpoint ()
3839 if op .info .as_while :
39- raise NotImplementedError ("While loops are not supported in the JAX backend." )
40+ # The inner function returns a boolean as the last value.
41+ return make_jax_while_fn (scan_inner_fn , parse_outer_inputs , input_taps )
4042 else :
41- return make_jax_scan_fn (
42- scan_inner_fn ,
43- parse_outer_inputs ,
44- input_taps ,
45- )
43+ return make_jax_scan_fn (scan_inner_fn , parse_outer_inputs , input_taps )
44+
45+
def make_jax_while_fn(
    scan_inner_fn: Callable,
    parse_outer_inputs: Callable[["TensorVariable"], Dict[str, List["TensorVariable"]]],
    input_taps: Dict,
):
    """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
    is used as a while loop.

    `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
    must return a value of the same type (Pytree structure, shape and dtype of
    the leaves). Before calling `body_fun`, it calls `cond_fun` which takes the
    current value and returns a boolean that indicates whether to keep iterating
    or not.

    The JAX `while_loop` needs to perform the following operations:

    1. Extract the inner-inputs;
    2. Build the initial carry value;
    3. Inside the loop:
        1. `carry` -> inner-inputs;
        2. inner-outputs -> `carry`
    4. Post-process the `carry` storage and return outputs

    Parameters
    ----------
    scan_inner_fn
        Callable invoked as ``scan_inner_fn(*inner_inputs)``. Its return value
        is indexed positionally and its last element is used as the
        termination condition.
    parse_outer_inputs
        Callable that receives the tuple of outer-inputs and returns a mapping
        with (at least) the keys ``"mit_sot"``, ``"sit_sot"``, ``"nit_sot"``,
        ``"shared"``, ``"sequences"`` and ``"non_sequences"``.
    input_taps
        Mapping read at the keys ``"mit_sot"`` and ``"sit_sot"`` (and
        ``"nit_sot"`` when there are nit-sot outputs); each value holds one
        collection of taps per storage array.

    Returns
    -------
    A function ``while_loop(*outer_inputs)`` that runs the loop and returns
    the final `carry` dictionary (post-processing is still a TODO).

    """

    def build_while_carry(outer_in):
        """Build the initial `carry` of `jax.lax.while_loop` from the outer-inputs."""
        carry_names = (
            "mit_sot",
            "sit_sot",
            "nit_sot",
            "shared",
            "sequences",
            "non_sequences",
        )
        init_carry = {name: outer_in[name] for name in carry_names}
        init_carry["step"] = 0
        return init_carry

    def build_inner_outputs_map(outer_in):
        """Map the inner-output variables to their position in the tuple returned by the inner function.

        TODO: Copied from the scan builder

        Inner-outputs are ordered as follow:
        - mit-mot-outputs
        - mit-sot-outputs
        - sit-sot-outputs
        - nit-sots (no carry)
        - shared-outputs
        [+ while-condition]

        NOTE(review): positions are counted from 0 starting at the mit-sots,
        i.e. mit-mot outputs are not accounted for here.

        """
        offset = 0
        inner_output_idx = defaultdict(list)
        for name in ("mit_sot", "sit_sot", "nit_sot", "shared"):
            num_outputs = len(outer_in[name])
            # Only create the key when there is something to map: the
            # callers rely on `name in inner_output_idx` membership tests.
            if num_outputs:
                inner_output_idx[name].extend(range(offset, offset + num_outputs))
            offset += num_outputs

        return inner_output_idx

    def from_carry_storage(carry, step, input_taps):
        """Fetch the inner inputs from the values stored in the carry array.

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to read from this array
        the inputs that will be passed to the inner function.

        This mechanism is necessary because we handle multiple-input taps within
        the `scan` instead of letting users manage the memory in the use cases
        where this is necessary.

        TODO: Copied from the scan builder

        """
        inner_inputs = []
        for taps, carry_element in zip(input_taps, carry):
            # The most negative tap fixes how many initial values sit at the
            # start of the storage array.
            storage_size = -min(taps)
            inner_inputs.extend(
                carry_element[step + storage_size + tap] for tap in taps
            )

        return inner_inputs

    def to_carry_storage(inner_outputs, carry, step, input_taps):
        """Create the new carry array from the inner output

        `Scan` passes storage arrays as inputs, which are then read from and
        updated in the loop body. At each step we need to update this array
        with the outputs of the inner function

        Exactly one updated array is returned per storage array so that the
        Pytree structure of the carry is the same before and after the body,
        as `jax.lax.while_loop` requires. The value is written at
        ``step - min(taps)``, the slot that `from_carry_storage` would read
        for a tap of 0 at the current step.

        """
        return [
            carry_element.at[step - min(taps)].set(output)
            for taps, carry_element, output in zip(input_taps, carry, inner_outputs)
        ]

    def while_loop(*outer_inputs):
        """Run the `Scan` as a while loop and return the final carry."""

        outer_in = parse_outer_inputs(outer_inputs)
        init_carry = build_while_carry(outer_in)
        inner_output_idx = build_inner_outputs_map(outer_in)

        def inner_inputs_from_carry(carry):
            """Get inner-inputs from the arguments passed to the `jax.lax.while_loop` body function.

            Inner-inputs are ordered as follows:
            - sequences
            - mit-mot inputs
            - mit-sot inputs
            - sit-sot inputs
            - shared-inputs
            - non-sequences

            """
            current_step = carry["step"]

            # Take the current element of *each* sequence; indexing the
            # container itself would select a whole sequence instead.
            inner_in_seqs = [seq[current_step] for seq in carry["sequences"]]
            inner_in_mit_sot = from_carry_storage(
                carry["mit_sot"], current_step, input_taps["mit_sot"]
            )
            inner_in_sit_sot = from_carry_storage(
                carry["sit_sot"], current_step, input_taps["sit_sot"]
            )
            inner_in_shared = carry.get("shared", [])
            inner_in_non_sequences = carry.get("non_sequences", [])

            return [
                *inner_in_seqs,
                *inner_in_mit_sot,
                *inner_in_sit_sot,
                *inner_in_shared,
                *inner_in_non_sequences,
            ]

        def carry_from_inner_outputs(inner_outputs, carry):
            """Build the next carry from the inner-outputs and the current carry."""
            step = carry["step"]
            new_carry = {
                "mit_sot": [],
                "sit_sot": [],
                "nit_sot": [],
                "shared": [],
                "step": step + 1,
                "sequences": carry["sequences"],
                "non_sequences": carry["non_sequences"],
                # NOTE(review): assumes the last inner-output is `Scan`'s
                # `until` condition (truthy means "stop"), hence the
                # negation -- confirm against the reference implementation.
                "do_continue": jax.numpy.logical_not(inner_outputs[-1]),
            }

            if "shared" in inner_output_idx:
                new_carry["shared"] = [
                    inner_outputs[idx] for idx in inner_output_idx["shared"]
                ]

            # NOTE(review): the nit-sot storage arrays are not built yet (see
            # the TODO below); they are routed through the same storage
            # update as the other outputs, as before.
            for name in ("mit_sot", "sit_sot", "nit_sot"):
                if name in inner_output_idx:
                    outputs = [inner_outputs[idx] for idx in inner_output_idx[name]]
                    new_carry[name] = to_carry_storage(
                        outputs, carry[name], step, input_taps[name]
                    )

            return new_carry

        def cond_fn(carry):
            # The inner-function of `Scan` returns a boolean as the last
            # value, which `carry_from_inner_outputs` stores in `carry`.
            # TODO: Will it return `False` if the number of steps is exceeded?
            return carry["do_continue"]

        def body_fn(carry):
            inner_inputs = inner_inputs_from_carry(carry)
            inner_outputs = scan_inner_fn(*inner_inputs)
            return carry_from_inner_outputs(inner_outputs, carry)

        # TODO
        # The `Scan` implementation in the C backend will execute the
        # function once before checking the termination condition, while
        # `jax.lax.while_loop` checks the condition first. We thus need to call
        # `body_fn` once before calling `jax.lax.while_loop`. This allows us,
        # along with `n_steps`, to build the storage array for the `nit-sot`s
        # since there is no way to know their shape and dtype before executing
        # the function.
        carry = body_fn(init_carry)
        # `jax.lax.while_loop` takes `(cond_fun, body_fun, init_val)`.
        carry = jax.lax.while_loop(cond_fn, body_fn, carry)

        # TODO: Post-process the storage arrays
        outer_outputs = carry

        return outer_outputs

    return while_loop
46264
47265
48266def make_jax_scan_fn (
@@ -58,7 +276,8 @@ def make_jax_scan_fn(
58276 stacked to the previous outputs. We use this to our advantage to build
59277 `Scan` outputs without having to post-process the storage arrays.
60278
61- The JAX scan function needs to perform the following operations:
279+ The JAX `scan` function needs to perform the following operations:
280+
62281 1. Extract the inner-inputs;
63282 2. Build the initial `carry` and `sequence` values;
64283 3. Inside the loop:
@@ -151,7 +370,6 @@ def scan(*outer_inputs):
151370 outer_in = parse_outer_inputs (outer_inputs )
152371 n_steps , sequences , init_carry = build_jax_scan_inputs (outer_in )
153372 inner_output_idx = build_inner_outputs_map (outer_in )
154-
155373 def scan_inner_in_args (carry , x ):
156374 """Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
157375
@@ -265,11 +483,11 @@ def body_fn(carry, x):
265483 )
266484
267485 shared_output = tuple (last_carry ["shared" ])
268- results = results + shared_output
486+ outer_outputs = results + shared_output
269487
270- if len (results ) == 1 :
271- return results [0 ]
488+ if len (outer_outputs ) == 1 :
489+ return outer_outputs [0 ]
272490
273- return results
491+ return outer_outputs
274492
275493 return scan
0 commit comments