     register_stabilize,
 )
 from pytensor.tensor.shape import Reshape
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
 
 
 @node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
 
     Reshape is tricky to vectorize eagerly, because a graph like
     `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
-    that must be vectorized before we arrize at the reshape operation.
+    that must be vectorized before we arrive at the reshape operation.
 
-    For the square Reshape case, we must wait for all the intemediate
+    For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
     if not isinstance(node.op.core_op, Reshape):
@@ -234,6 +239,26 @@ def local_blockwise_reshape(fgraph, node):
     return [new_out]
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where the only batch dimensions are the inputs."""
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
 @node_rewriter(tracks=[Blockwise], inplace=True)
 def blockwise_inplace(fgraph, node):
     blockwise_op = node.op
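
As context for the docstring edited above, a minimal sketch of the kind of graph it describes, assuming pytensor's `vectorize_graph` helper (variable names here are illustrative only):

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.matrix("x")
# The shape graph feeding the reshape (Shape_i, Mul) must itself be
# vectorized before any rewrite can reason about the Reshape node.
out = x.reshape([x.shape[0] * x.shape[1], -1])

batch_x = pt.tensor3("batch_x")  # x with a leading batch dimension
vec_out = vectorize_graph(out, replace={x: batch_x})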
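And a hedged sketch of what the new `local_blockwise_of_subtensor` rewrite enables, assuming vectorization initially falls back to a `Blockwise(Subtensor)` node for this graph:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.matrix("x")
core_out = x[0]  # Subtensor whose index inputs carry no batch dimensions

batch_x = pt.tensor3("batch_x")  # x with a leading batch dimension
vec_out = vectorize_graph(core_out, replace={x: batch_x})
# Where the vectorized graph contains Blockwise(Subtensor), the rewrite
# prepends one empty slice per batch dimension and indexes x directly,
# yielding the equivalent of batch_x[:, 0].

The `idx.squeeze()` calls in the rewrite drop the broadcastable batch dimensions from any symbolic index inputs, so that `indices_from_subtensor` can rebuild the core index tuple before the batch slices are prepended.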