Missing rewrites:
- A@B + A@C = A@(B+C) (one fewer matmul)
- (s*A) @ B = s*(A@B) (which can be done by a single gemm call, since gemm computes alpha*X@Y + beta*Z and s can be passed as alpha)
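A quick NumPy check of the first identity and its rough cost saving. This is a sketch that assumes dense float64 matrices and counts an (m, k) @ (k, n) product as about 2*m*k*n flops; `matmul_flops` is a hypothetical helper used only for illustration.

```python
import numpy as np


def matmul_flops(m: int, k: int, n: int) -> int:
    """Approximate flops for an (m, k) @ (k, n) product: one multiply and one add per term."""
    return 2 * m * k * n


rng = np.random.default_rng(0)
m, k, n = 64, 32, 48
A = rng.standard_normal((m, k))
B = rng.standard_normal((k, n))
C = rng.standard_normal((k, n))

# Unfactored: two matmuls plus an (m, n) elementwise add.
lhs = A @ B + A @ C
# Factored: one (k, n) elementwise add plus a single matmul.
rhs = A @ (B + C)

# Equal up to floating-point rounding.
assert np.allclose(lhs, rhs)

cost_lhs = 2 * matmul_flops(m, k, n) + m * n
cost_rhs = matmul_flops(m, k, n) + k * n
print(cost_lhs, cost_rhs)  # 396288 198144
```

The saving is roughly half the matmul work; the results can differ in the last bits because the additions happen in a different order.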
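Reproducer (neither rewrite is currently applied under fast_run):
from pytensor.graph import rewrite_graph
import pytensor.tensor as pt

A, B, C = pt.matrices("ABC")
s = pt.scalar("s")

# Case 1: expected to become A @ (B + C)
o1 = A@B + A@C
rewrite_graph(o1, include=("fast_run",), exclude=("inplace",)).dprint()
print()

# Case 2: expected to become s * (A @ B), i.e. a single scaled gemm
o2 = (s*A) @ B
rewrite_graph(o2, include=("fast_run",), exclude=("inplace",)).dprint()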
Output:
Gemm{no_inplace} [id A]
├─ Dot22 [id B]
│ ├─ A [id C]
│ └─ B [id D]
├─ 1.0 [id E]
├─ A [id C]
├─ C [id F]
└─ 1.0 [id E]

Dot22 [id A]
├─ Mul [id B]
│ ├─ ExpandDims{axes=[0, 1]} [id C]
│ │ └─ s [id D]
│ └─ A [id E]
└─ B [id F]
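The two printed graphs can be read with a small NumPy stand-in for gemm. This sketch assumes, as the dprint output suggests, that Gemm's inputs are ordered (z, alpha, x, y, beta) and computed as beta*z + alpha*(x @ y); `gemm_reference` is a hypothetical helper, not PyTensor's implementation. It shows that the first graph still performs two products, and that (s*A) @ B equals a single gemm call with alpha=s, without the separate elementwise pass over A.

```python
import numpy as np


def gemm_reference(z, alpha, x, y, beta):
    """Hypothetical NumPy stand-in for a BLAS-style gemm: beta*z + alpha*(x @ y)."""
    return beta * z + alpha * (x @ y)


rng = np.random.default_rng(1)
A = rng.standard_normal((5, 4))
B = rng.standard_normal((4, 3))
C = rng.standard_normal((4, 3))
s = 2.5

# First graph: Gemm(Dot22(A, B), 1.0, A, C, 1.0) still needs two matrix products.
current_1 = gemm_reference(A @ B, 1.0, A, C, 1.0)
assert np.allclose(current_1, A @ B + A @ C)

# Second graph: s is broadcast and multiplied into A before the product.
current_2 = (s * A) @ B
# Desired: fold s into gemm's alpha; with beta=0.0 the zero z term contributes nothing.
desired_2 = gemm_reference(np.zeros((5, 3)), s, A, B, 0.0)
assert np.allclose(current_2, desired_2)
```

In a real BLAS call with beta equal to zero the output buffer is not read, so no zero matrix would need to be materialized; the explicit zeros here only keep the reference function simple.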