
Commit 61a8581

Rewrite away blockwise Subtensor in gradient of Blockwise(Conv1d)
1 parent 42ca403 commit 61a8581

2 files changed: +52, -5 lines

pytensor/tensor/rewriting/blockwise.py (+28, -3)
@@ -14,7 +14,12 @@
     register_stabilize,
 )
 from pytensor.tensor.shape import Reshape
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
 
 
 @node_rewriter([Blockwise])
@@ -216,9 +221,9 @@ def local_blockwise_reshape(fgraph, node):
 
     Reshape is tricky to vectorize eagerly, because a graph like
     `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
-    that must be vectorized before we arrize at the reshape operation.
+    that must be vectorized before we arrive at the reshape operation.
 
-    For the square Reshape case, we must wait for all the intemediate
+    For the square Reshape case, we must wait for all the intermediate
     operations to be lifted as Allocs
     """
     if not isinstance(node.op.core_op, Reshape):
@@ -234,6 +239,26 @@ def local_blockwise_reshape(fgraph, node):
     return [new_out]
 
 
+@register_stabilize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_of_subtensor(fgraph, node):
+    """Rewrite Blockwise of Subtensor, where only the indexed tensor has batch dimensions."""
+    if not isinstance(node.op.core_op, Subtensor):
+        return
+
+    x, *idxs = node.inputs
+    if not all(all(idx.type.broadcastable) for idx in idxs):
+        return
+
+    core_idxs = indices_from_subtensor(
+        [idx.squeeze() for idx in idxs], node.op.core_op.idx_list
+    )
+    # Add empty slices for the batch dims
+    none_slices = (slice(None),) * node.op.batch_ndim(node)
+    return [x[(*none_slices, *core_idxs)]]
+
+
 @node_rewriter(tracks=[Blockwise], inplace=True)
 def blockwise_inplace(fgraph, node):
     blockwise_op = node.op
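For intuition, the equivalence this new local_blockwise_of_subtensor rewrite exploits can be sketched in plain NumPy (a hand-rolled illustration, not part of the commit; the shapes and index tuple below are made up):

import numpy as np

# When the index inputs carry no batch dimensions, a Blockwise(Subtensor)
# applies one and the same core index to every batch entry. The rewrite
# replaces that implicit loop with a single direct index, prepending one
# full slice per batch dimension.
x = np.random.default_rng(0).normal(size=(4, 5, 6))  # batch of 4 matrices
core_idx = (slice(1, None), 2)                       # a "core" Subtensor index

blockwise_style = np.stack([xi[core_idx] for xi in x])  # loop over the batch
rewritten = x[(slice(None), *core_idx)]                 # i.e. x[:, 1:, 2]
np.testing.assert_array_equal(blockwise_style, rewritten)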

tests/tensor/signal/test_conv.py (+24, -2)
@@ -4,9 +4,11 @@
 import pytest
 from scipy.signal import convolve as scipy_convolve
 
-from pytensor import config, function
+from pytensor import config, function, grad
+from pytensor.graph import ancestors, rewrite_graph
 from pytensor.tensor import matrix, vector
-from pytensor.tensor.signal.conv import convolve1d
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.signal.conv import Conv1d, convolve1d
 from tests import unittest_tools as utt
 
 
@@ -60,3 +62,23 @@ def test_convolve1d_batch_same():
 
     res = out.eval({x: x_test, y: y_test})
     assert res.shape == (2, 8)
+
+
+@pytest.mark.parametrize("mode", ("full", "valid", "same"))
+def test_convolve1d_batch_graph(mode):
+    """Test that we don't have slow Blockwise Subtensors in the graph of a batched convolve1d."""
+    x = matrix("x")
+    y = matrix("y")
+    out = convolve1d(x, y, mode=mode)
+    grads = grad(out.sum(), wrt=[x, y])
+    final_grads = rewrite_graph(
+        grads, include=("ShapeOpt", "canonicalize", "stabilize", "specialize")
+    )
+
+    blockwise_nodes = [
+        var.owner
+        for var in ancestors(final_grads)
+        if var.owner is not None and isinstance(var.owner.op, Blockwise)
+    ]
+    # Check that any remaining Blockwise is just Conv1d
+    assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
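To exercise just the new check locally, the standard pytest invocation (assuming a development checkout of pytensor) would be:

pytest tests/tensor/signal/test_conv.py -k test_convolve1d_batch_graph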
