@@ -48,19 +48,40 @@ def scan(*outer_inputs):
4848 outer_in_seqs = list (op .outer_seqs (outer_inputs ))
4949 outer_in_mit_mot = list (op .outer_mitmot (outer_inputs ))
5050 outer_in_mit_sot = list (op .outer_mitsot (outer_inputs ))
51+ outer_in_nit_sot = list (op .outer_nitsot (outer_inputs ))
5152 outer_in_sit_sot = list (op .outer_sitsot (outer_inputs ))
5253 outer_in_shared = list (op .outer_shared (outer_inputs ))
5354 outer_in_non_seqs = list (op .outer_non_seqs (outer_inputs ))
5455 if len (outer_in_mit_mot ):
5556 raise NotImplementedError ("mit-mot not supported" )
5657 if len (outer_in_mit_sot ):
5758 raise NotImplementedError ("mit-sot not supported" )
58- if len (outer_in_shared ):
59- raise NotImplementedError ("shared variables not supported" )
6059
61- init_carry = outer_in_sit_sot
60+ # These are the outer-inputs
6261 sequences = outer_in_seqs
6362 non_sequences = outer_in_non_seqs
63+ init_carry = {}
64+ for name , outputs in [
65+ ("sit_sot" , outer_in_sit_sot ),
66+ ("shared" , outer_in_shared ),
67+ ]:
68+ if len (outputs ) > 0 :
69+ init_carry [name ] = outputs
70+
71+ # We keep track of the kind of inner_outputs and their number
72+ from collections import defaultdict
73+
74+ offset = 0
75+ inner_output_idx = defaultdict (list )
76+ for name , outputs in [
77+ ("sit_sot" , outer_in_sit_sot ),
78+ ("nit_sot" , outer_in_nit_sot ),
79+ ("shared" , outer_in_shared ),
80+ ]:
81+ if len (outputs ) > 0 :
82+ for i in range (len (outputs )):
83+ inner_output_idx [name ].append (offset + i )
84+ offset += len (outputs )
6485
6586 def scan_inner_in_args (carry , x ):
6687 """Get inner-inputs from the arguments passed to the `jax.lax.scan` body function.
@@ -72,18 +93,22 @@ def scan_inner_in_args(carry, x):
7293 - sit-sot inputs
7394 - shared-inputs
7495 - non-sequences
96+
7597 """
7698 inner_in_seqs = x
7799 inner_in_sit_sot = sum (
78100 [
79101 convert (carry_element )
80102 for convert in sit_sot_from_carry
81- for carry_element in carry
103+ for carry_element in carry [ "sit_sot" ]
82104 ],
83105 [],
84106 )
107+ inner_in_shared = carry .get ("shared" , [])
85108
86- return sum ([inner_in_seqs , inner_in_sit_sot , non_sequences ], [])
109+ return sum (
110+ [inner_in_seqs , inner_in_sit_sot , inner_in_shared , non_sequences ], []
111+ )
87112
88113 def scan_new_carry (carry , inner_outputs ):
89114 """Create a new carry value from the inner-outputs.
@@ -97,27 +122,39 @@ def scan_new_carry(carry, inner_outputs):
97122 [+ while-condition]
98123
99124 """
100- carry = sum (
101- [
102- convert (carry_element , inner_outputs_element )
103- for convert in sit_sot_to_carry
104- for (carry_element , inner_outputs_element ) in zip (
105- carry , inner_outputs
106- )
107- ],
108- [],
109- )
110- return carry
125+ new_carry = {}
126+
127+ if "shared" in inner_output_idx :
128+ shared_inner_outputs = [
129+ inner_outputs [idx ] for idx in inner_output_idx ["shared" ]
130+ ]
131+ new_carry ["shared" ] = shared_inner_outputs
132+
133+ if "sit_sot" in inner_output_idx :
134+ sit_sot_inner_outputs = [
135+ inner_outputs [idx ] for idx in inner_output_idx ["sit_sot" ]
136+ ]
137+ new_carry ["sit_sot" ] = sum (
138+ [
139+ convert (carry_element , inner_outputs_element )
140+ for convert in sit_sot_to_carry
141+ for (carry_element , inner_outputs_element ) in zip (
142+ carry ["sit_sot" ], sit_sot_inner_outputs
143+ )
144+ ],
145+ [],
146+ )
147+
148+ return new_carry
111149
112150 def body_fn (carry , x ):
113151 inner_in_args = scan_inner_in_args (carry , x )
114152 inner_outputs = scan_inner_func (* inner_in_args )
115153 new_carry = scan_new_carry (carry , inner_outputs )
116- return new_carry , * inner_outputs
154+ return new_carry , inner_outputs
117155
118- print (init_carry )
119156 _ , results = jax .lax .scan (body_fn , init_carry , sequences , length = n_steps )
120157
121- return results
158+ return results [ 0 ]
122159
123160 return scan
0 commit comments