@@ -177,13 +177,50 @@ def scan_new_carry(carry, inner_outputs):
177177
178178 return new_carry
179179
180+ def scan_new_outputs (inner_outputs ):
181+ """Create a new outer-output value from the inner-output value.
182+
183+ Outer-outputs are ordered as follows:
184+ - mit-mot-outputs
185+ - mit-sot-outputs
186+ - sit-sot-outputs
187+ - nit-sots
188+ - shared-outputs
189+
190+ The shared output corresponds to the last value found in the last
191+ carry value returned by `jax.lax.scan`. It is thus not returned in
192+ the body function.
193+
194+ """
195+ outer_outputs = []
196+ if "mit_sot" in inner_output_idx :
197+ outer_outputs .append (
198+ [inner_outputs [idx ] for idx in inner_output_idx ["mit_sot" ]]
199+ )
200+ if "sit_sot" in inner_output_idx :
201+ outer_outputs .append (
202+ [inner_outputs [idx ] for idx in inner_output_idx ["sit_sot" ]]
203+ )
204+ if "nit_sot" in inner_output_idx :
205+ outer_outputs .append (
206+ [inner_outputs [idx ] for idx in inner_output_idx ["nit_sot" ]]
207+ )
208+
209+ return tuple (sum (outer_outputs , []))
210+
180211 def body_fn (carry , x ):
181212 inner_in_args = scan_inner_in_args (carry , x )
182213 inner_outputs = scan_inner_func (* inner_in_args )
183214 new_carry = scan_new_carry (carry , inner_outputs )
184- return new_carry , inner_outputs
215+ outer_outputs = scan_new_outputs (inner_outputs )
216+ return new_carry , outer_outputs
217+
218+ last_carry , results = jax .lax .scan (
219+ body_fn , init_carry , sequences , length = n_steps
220+ )
185221
186- _ , results = jax .lax .scan (body_fn , init_carry , sequences , length = n_steps )
222+ shared_output = tuple (last_carry ["shared" ])
223+ results = results + shared_output
187224
188225 if len (results ) == 1 :
189226 return results [0 ]
0 commit comments