diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..558abb4460 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node): new_b = b else: return None + + # new condition to handle (1,1) @ (1,1) + if a.ndim == 2 and b.ndim == 2 and a.shape == (1, 1) and b.shape == (1, 1): + return [a * b] # Direct elementwise multiplication new_a = copy_stack_trace(a, new_a) new_b = copy_stack_trace(b, new_b)