
Commit 7df44a5

Fix missing broadcast dimension sums in Elemwise, BroadcastTo gradients
Closes #1089
1 parent 9b7021c commit 7df44a5
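
Background for the fix: when an Elemwise input is broadcast along a dimension, the gradient with respect to that input has to be summed along that dimension so it keeps the input's shape; previously those sums were only emitted when the broadcasting was visible in the static type/shape information. A minimal sketch of the affected case, mirroring the new test added below (variable names and shapes are illustrative, not part of the commit):

import aesara
import aesara.tensor as at
import numpy as np

# X broadcasts along its first dimension when added to Y, so the gradient of
# the sum with respect to X must be summed back down to X's shape (1, 5).
X = at.matrix("X")
Y = at.matrix("Y")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

X_grad_fn = aesara.function([X, Y], X_grad)
res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
print(res)  # expected after the fix: [[5. 5. 5. 5. 5.]], i.e. shape (1, 5)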

File tree

aesara/tensor/elemwise.py
aesara/tensor/extra_ops.py
tests/tensor/test_elemwise.py
tests/tensor/test_extra_ops.py

4 files changed (+59, -28 lines)


aesara/tensor/elemwise.py

Lines changed: 3 additions & 13 deletions
@@ -542,7 +542,6 @@ def connection_pattern(self, node):
         return [[True for output in node.outputs] for ipt in node.inputs]
 
     def L_op(self, inputs, outs, ograds):
-        from aesara.tensor.math import sum as at_sum
 
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
@@ -573,18 +572,9 @@ def L_op(self, inputs, outs, ograds):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
                 continue
 
-            # List of all the dimensions that are broadcastable for input[i] so
-            # we can sum over them
-            # TODO: only count dimensions that were effectively broadcasted
-            to_sum = [
-                j
-                for j, bcast in enumerate(ipt.type.broadcastable)
-                if bcast and not outs[0].broadcastable[j]
-            ]
-
-            if to_sum:
-                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
-                rval[i] = sr
+            rval[i] = aesara.tensor.extra_ops.sum_broadcastable_dims(
+                rval[i], ipt.shape, outs[0].shape
+            )
 
         return rval
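
For intuition, the new sum_broadcastable_dims call that replaces the static broadcastable check can be modeled in plain NumPy: a gradient axis is summed (kept as a length-1 axis) whenever the input has length 1 there and the output does not. This is a sketch of the semantics only; the actual helper, added in aesara/tensor/extra_ops.py below, performs the comparison symbolically so it also works when shapes are unknown until run time.

import numpy as np

def sum_broadcastable_dims_np(value, shape_1, shape_2):
    # NumPy model of the helper: sum `value` along every axis where
    # `shape_1` has length 1 but `shape_2` does not.
    for i, (s1, s2) in enumerate(zip(shape_1, shape_2)):
        if s1 == 1 and s2 != 1:
            value = value.sum(axis=i, keepdims=True)
    return value

# A gradient with the broadcasted output shape (5, 5) is reduced back to the
# input shape (1, 5).
grad = np.ones((5, 5))
print(sum_broadcastable_dims_np(grad, (1, 5), (5, 5)).shape)  # (1, 5)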

aesara/tensor/extra_ops.py

Lines changed: 30 additions & 14 deletions
@@ -1,6 +1,6 @@
 from collections.abc import Collection
 from functools import reduce
-from typing import Iterable, Set, Tuple, Union
+from typing import Iterable, Sequence, Set, Tuple, Union
 
 import numpy as np
 import numpy.core.numeric
@@ -1665,19 +1665,8 @@ def grad(self, inputs, outputs_gradients):
 
         d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
 
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
+        # Determine the dimensions that were broadcast and sum them
+        d_wrt_a = sum_broadcastable_dims(d_wrt_a, a.shape, shape[-a.ndim :])
 
         return [d_wrt_a] + [
             grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
@@ -1804,6 +1793,33 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
     return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
 
 
+def sum_broadcastable_dims(
+    value: TensorVariable,
+    shape_1: Sequence[Variable],
+    shape_2: Sequence[Variable],
+) -> TensorVariable:
+    """Sum dimensions in `value` that are broadcasted between `shape_1` and `shape_2`."""
+    from aesara.ifelse import ifelse
+
+    for i, (s1, s2) in enumerate(zip(shape_1, shape_2)):
+        dummy_s1 = aes.get_scalar_type(dtype=s1.type.dtype)()
+        dummy_s2 = aes.get_scalar_type(dtype=s2.type.dtype)()
+        cond_op = Composite(
+            [dummy_s1, dummy_s2],
+            [
+                aesara.scalar.and_(
+                    aesara.scalar.eq(dummy_s1, 1), aesara.scalar.neq(dummy_s2, 1)
+                )
+            ],
+        )
+        value = ifelse(
+            cond_op(at.scalar_from_tensor(s1), at.scalar_from_tensor(s2)),
+            at_sum(value, axis=i, keepdims=True),
+            value,
+        )
+    return value
+
+
 __all__ = [
     "searchsorted",
     "cumsum",

tests/tensor/test_elemwise.py

Lines changed: 25 additions & 1 deletion
@@ -9,7 +9,7 @@
 import aesara
 import aesara.scalar as aes
 import tests.unittest_tools as utt
-from aesara.compile.mode import Mode
+from aesara.compile.mode import Mode, get_default_mode
 from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Variable
 from aesara.graph.fg import FunctionGraph
@@ -889,6 +889,30 @@ def test_invalid_static_shape(self):
         ):
             x + y
 
+    def test_grad_sum_bcast_input_dims(self):
+        """Make sure broadcasted dimensions in the gradients are summed when static shape information isn't available."""
+        Y = matrix("Y")
+        X = matrix("X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        mode = get_default_mode().including("fast_run")
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        # When the shapes are known at compile-time, the compiled graph should
+        # simplify
+        Y = tensor(np.float64, shape=(5, None), name="Y")
+        X = tensor(np.float64, shape=(1, 5), name="X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        assert X_grad_fn.maker.fgraph.apply_nodes
+
 
 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
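
The second half of the new test covers the case where static shape information is available; per its comment, the compiled gradient graph is then expected to simplify. One way to inspect this locally (a sketch, not part of the commit; debugprint output varies by version) is to print the optimized graph of the compiled function:

import aesara
import aesara.tensor as at
import numpy as np
from aesara.printing import debugprint

# X has a fully static (1, 5) shape, so the broadcast condition is decidable
# at compile time and no run-time check should remain in the graph.
X = at.tensor(np.float64, shape=(1, 5), name="X")
Y = at.tensor(np.float64, shape=(5, None), name="Y")
X_grad = aesara.grad((X + Y).sum(), wrt=X)

X_grad_fn = aesara.function([X, Y], X_grad)
debugprint(X_grad_fn)  # inspect the optimized gradient graph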

tests/tensor/test_extra_ops.py

Lines changed: 1 addition & 0 deletions
@@ -1312,6 +1312,7 @@ def test_memory_leak(self):
     [
         [lambda x: broadcast_to(x, (1,)), (1,)],
         [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
+        [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
         [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
         [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
     ],
