2323 graph_inputs ,
2424 io_toposort ,
2525 is_in_ancestors ,
26+ replace_nominals_with_dummies ,
2627)
2728from aesara .graph .destroyhandler import DestroyHandler
2829from aesara .graph .features import ReplaceValidate
@@ -82,6 +83,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
8283 """
8384 if not isinstance (node .op , Scan ):
8485 return False
86+
8587 op = node .op
8688 op_info = op .info
8789 # We only need to take care of sequences and other arguments
@@ -92,8 +94,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
9294 st += op_info .n_sit_sot
9395 st += op_info .n_shared_outs
9496
95- op_ins = op .inner_inputs
96- op_outs = op .inner_outputs
97+ op_ins , op_outs = replace_nominals_with_dummies (op .inner_inputs , op .inner_outputs )
9798
9899 # Corresponds to the initial states, which should stay untouched.
99100 # We put those variables aside, and put them back at the end.
@@ -189,6 +190,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
189190 allow_gc = op .allow_gc ,
190191 )
191192 nw_outs = nwScan (* nw_outer , return_list = True )
193+
192194 return dict ([("remove" , [node ])] + list (zip (node .outputs , nw_outs )))
193195 else :
194196 return False
@@ -207,7 +209,9 @@ def push_out_non_seq_scan(fgraph, node):
207209 if not isinstance (node .op , Scan ):
208210 return False
209211
210- node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
212+ node_inputs , node_outputs = replace_nominals_with_dummies (
213+ node .op .inner_inputs , node .op .inner_outputs
214+ )
211215
212216 local_fgraph_topo = io_toposort (node_inputs , node_outputs )
213217 local_fgraph_outs_set = set (node_outputs )
@@ -417,7 +421,9 @@ def push_out_seq_scan(fgraph, node):
417421 if not isinstance (node .op , Scan ):
418422 return False
419423
420- node_inputs , node_outputs = node .op .inner_inputs , node .op .inner_outputs
424+ node_inputs , node_outputs = replace_nominals_with_dummies (
425+ node .op .inner_inputs , node .op .inner_outputs
426+ )
421427
422428 local_fgraph_topo = io_toposort (node_inputs , node_outputs )
423429 local_fgraph_outs_set = set (node_outputs )
@@ -832,9 +838,10 @@ def push_out_add_scan(fgraph, node):
832838
833839 # Use `ScanArgs` to parse the inputs and outputs of scan for ease of
834840 # use
835- args = ScanArgs (
836- node . inputs , node . outputs , op .inner_inputs , op .inner_outputs , op . info
841+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
842+ op .inner_inputs , op .inner_outputs
837843 )
844+ args = ScanArgs (node .inputs , node .outputs , inner_inputs , inner_outputs , op .info )
838845
839846 clients = {}
840847 local_fgraph_topo = io_toposort (
@@ -1694,6 +1701,8 @@ def merge(self, nodes):
16941701 inner_outs = [[] for nd in nodes ]
16951702 outer_outs = []
16961703
1704+ # inner_inputs, inner_outputs = replace_nominals_with_dummies(nd.op.inner_inputs, nd.op.inner_outputs)
1705+
16971706 def rename (ls , suffix ):
16981707 for k in ls :
16991708 if k .name :
@@ -1967,11 +1976,16 @@ def scan_merge_inouts(fgraph, node):
19671976 # Do a first pass to merge identical external inputs.
19681977 # Equivalent inputs will be stored in inp_equiv, then a new
19691978 # scan node created without duplicates.
1979+
1980+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
1981+ node .op .inner_inputs , node .op .inner_outputs
1982+ )
1983+
19701984 a = ScanArgs (
19711985 node .inputs ,
19721986 node .outputs ,
1973- node . op . inner_inputs ,
1974- node . op . inner_outputs ,
1987+ inner_inputs ,
1988+ inner_outputs ,
19751989 node .op .info ,
19761990 )
19771991
@@ -2173,10 +2187,15 @@ def push_out_dot1_scan(fgraph, node):
21732187 # Note that this works when only you need X[-1] in the end
21742188 # and assumes dimshuffle are applied to vectors before calling dot
21752189 op = node .op
2176- sitsot_ins = op .inner_sitsot (op .inner_inputs )
2177- sitsot_outs = op .inner_sitsot_outs (op .inner_outputs )
2190+
2191+ inner_inputs , inner_outputs = replace_nominals_with_dummies (
2192+ op .inner_inputs , op .inner_outputs
2193+ )
2194+
2195+ sitsot_ins = op .inner_sitsot (inner_inputs )
2196+ sitsot_outs = op .inner_sitsot_outs (inner_outputs )
21782197 outer_sitsot = op .outer_sitsot_outs (node .outputs )
2179- seqs = op .inner_seqs (op . inner_inputs )
2198+ seqs = op .inner_seqs (inner_inputs )
21802199 for inp , out , outer_out in zip (sitsot_ins , sitsot_outs , outer_sitsot ):
21812200
21822201 if (
@@ -2218,23 +2237,23 @@ def push_out_dot1_scan(fgraph, node):
22182237 # First let us split all arguments according to their
22192238 # corresponding categories
22202239
2221- inner_seqs = op .inner_seqs (op . inner_inputs )
2240+ inner_seqs = op .inner_seqs (inner_inputs )
22222241 outer_seqs = op .outer_seqs (node .inputs )
2223- inner_mitmot = op .inner_mitmot (op . inner_inputs )
2242+ inner_mitmot = op .inner_mitmot (inner_inputs )
22242243 outer_mitmot = op .outer_mitmot (node .inputs )
2225- inner_mitmot_outs = op .inner_mitmot_outs (op . inner_outputs )
2226- inner_mitsot = op .inner_mitsot (op . inner_inputs )
2244+ inner_mitmot_outs = op .inner_mitmot_outs (inner_outputs )
2245+ inner_mitsot = op .inner_mitsot (inner_inputs )
22272246 outer_mitsot = op .outer_mitsot (node .inputs )
2228- inner_mitsot_outs = op .inner_mitsot_outs (op . inner_outputs )
2229- inner_sitsot = op .inner_sitsot (op . inner_inputs )
2247+ inner_mitsot_outs = op .inner_mitsot_outs (inner_outputs )
2248+ inner_sitsot = op .inner_sitsot (inner_inputs )
22302249 outer_sitsot = op .outer_sitsot (node .inputs )
2231- inner_sitsot_outs = op .inner_sitsot_outs (op . inner_outputs )
2250+ inner_sitsot_outs = op .inner_sitsot_outs (inner_outputs )
22322251 outer_nitsot = op .outer_nitsot (node .inputs )
2233- inner_nitsot_outs = op .inner_nitsot_outs (op . inner_outputs )
2234- inner_shared = op .inner_shared (op . inner_inputs )
2252+ inner_nitsot_outs = op .inner_nitsot_outs (inner_outputs )
2253+ inner_shared = op .inner_shared (inner_inputs )
22352254 outer_shared = op .outer_shared (node .inputs )
2236- inner_shared_outs = op .inner_shared_outs (op . inner_outputs )
2237- inner_non_seqs = op .inner_non_seqs (op . inner_inputs )
2255+ inner_shared_outs = op .inner_shared_outs (inner_outputs )
2256+ inner_non_seqs = op .inner_non_seqs (inner_inputs )
22382257 outer_non_seqs = op .outer_non_seqs (node .inputs )
22392258
22402259 new_info = dataclasses .replace (
0 commit comments