@@ -36,13 +36,61 @@ def parse_outer_inputs(outer_inputs):
3636 return outer_in
3737
3838 if op .info .as_while :
39- raise NotImplementedError ("While loops are not supported in the JAX backend." )
39+ # The inner function returns a boolean as the last value.
40+ return make_jax_while_fn (scan_inner_fn , parse_outer_inputs , input_taps )
4041 else :
41- return make_jax_scan_fn (
42- scan_inner_fn ,
43- parse_outer_inputs ,
44- input_taps ,
45- )
42+ return make_jax_scan_fn (scan_inner_fn , parse_outer_inputs , input_taps )
43+
44+
45+ def make_jax_while_fn (
46+ scan_inner_fn : Callable ,
47+ parse_outer_inputs : Callable [[TensorVariable ], Dict [str , List [TensorVariable ]]],
48+ input_taps : Dict ,
49+ ):
50+ """Create a `jax.lax.while_loop` function to perform `Scan` computations when it
51+ is used as while loop.
52+
53+ `jax.lax.while_loop` iterates by passing a value `carry` to a `body_fun` that
54+ must return a value of the same type (Pytree structure, shape and dtype of
55+ the leaves). Before calling `body_fn`, it calls `cond_fn` which takes the
56+ current value and returns a boolean that indicates whether to keep iterating
57+ or not.
58+
59+ The JAX `while_loop` needs to perform the following operations:
60+
61+ 1. Extract the inner-inputs;
62+ 2. Build the initial carry value;
63+ 3. Inside the loop:
64+ 1. `carry` -> inner-inputs;
65+ 2. inner-outputs -> `carry`
66+ 4. Post-process the `carry` storage and return outputs
67+ """
68+
69+ def build_while_carry (outer_in ):
70+ return outer_in
71+
72+ def while_loop (* outer_inputs ):
73+
74+ outer_in = parse_outer_inputs (outer_inputs )
75+ init_carry = build_while_carry (outer_in )
76+
77+ def cond_fn (carry ):
78+ # The inner-function of `Scan` returns a boolean as the last
79+ # value. This needs to be included in `carry`.
80+ # TODO: Will it return `False` if the number of steps is exceeded?
81+ return carry ["do_continue" ]
82+
83+ def body_fn (carry ):
84+ return carry
85+
86+ carry = jax .lax .while_loop (body_fn , cond_fn , init_carry )
87+
88+ # Here post-process the result
89+ outer_outputs = carry
90+
91+ return outer_outputs
92+
93+ return while_loop
4694
4795
4896def make_jax_scan_fn (
@@ -58,7 +106,8 @@ def make_jax_scan_fn(
58106 stacked to the previous outputs. We use this to our advantage to build
59107 `Scan` outputs without having to post-process the storage arrays.
60108
61- The JAX scan function needs to perform the following operations:
109+ The JAX `scan` function needs to perform the following operations:
110+
62111 1. Extract the inner-inputs;
63112 2. Build the initial `carry` and `sequence` values;
64113 3. Inside the loop:
@@ -265,11 +314,11 @@ def body_fn(carry, x):
265314 )
266315
267316 shared_output = tuple (last_carry ["shared" ])
268- results = results + shared_output
317+ outer_outputs = results + shared_output
269318
270- if len (results ) == 1 :
271- return results [0 ]
319+ if len (outer_outputs ) == 1 :
320+ return outer_outputs [0 ]
272321
273- return results
322+ return outer_outputs
274323
275324 return scan
0 commit comments