Description
Before
Currently, some expressions involving matrix inverse and multiplication may appear in computation graphs, such as:
pt.matmul(pt.linalg.inv(A), B)
pt.matmul(A, pt.linalg.inv(B))
pt.linalg.inv(A).T @ b
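The following NumPy sketch is a quick sanity check of the algebra behind these patterns, independent of PyTensor. It compares each expression against an inverse-free formulation that uses `numpy.linalg.solve`. The random, diagonally shifted test matrices are an assumption, chosen only so that the inverses exist.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
# Assumption: shifting by n * I keeps these random test matrices well away from singular.
A = rng.normal(size=(n, n)) + n * np.eye(n)
B = rng.normal(size=(n, n)) + n * np.eye(n)
b = rng.normal(size=n)

# Each "Before" expression paired with an equivalent that never forms an inverse.
pairs = {
    "inv(A) @ B": (np.linalg.inv(A) @ B, np.linalg.solve(A, B)),
    "A @ inv(B)": (A @ np.linalg.inv(B), np.linalg.solve(B.T, A.T).T),
    # inv(A).T == inv(A.T), so the transposed pattern is a solve against A.T
    "inv(A).T @ b": (np.linalg.inv(A).T @ b, np.linalg.solve(A.T, b)),
}

for name, (with_inv, with_solve) in pairs.items():
    diff = np.max(np.abs(with_inv - with_solve))
    print(f"{name:14s} max abs diff = {diff:.2e}")
```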
While PyTensor already includes rewrites like `inv(A) @ b → solve(A, b)` for certain Dot patterns, similar optimizations may not apply consistently across all equivalent matmul or transpose-based expressions.
As a result, graphs may still contain explicit matrix inverse operations that could otherwise be rewritten into more efficient solve-based formulations.
After
Extend the existing linear algebra rewrite system to detect and optimize additional equivalent patterns involving matrix inverse, transpose, and matrix multiplication.
For example, rewrite patterns such as:
matmul(inv(A), B) → solve(A, B)
matmul(A, inv(B)) → solve(B.T, A.T).T
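The right-multiplication rule follows from transposing the product. Here is a short derivation, assuming $B$ is invertible, together with the analogous identity that covers the transpose-based pattern `pt.linalg.inv(A).T @ b`:

```latex
X = A B^{-1}
\;\Longrightarrow\; X B = A
\;\Longrightarrow\; B^{\top} X^{\top} = A^{\top}
\;\Longrightarrow\; X = \bigl(\operatorname{solve}(B^{\top}, A^{\top})\bigr)^{\top}

(A^{-1})^{\top} b = (A^{\top})^{-1} b = \operatorname{solve}(A^{\top}, b)
```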
By avoiding an explicit inverse, these rewrites would improve numerical stability, reduce computational cost, and simplify the computation graph.
Context for the issue:
PyTensor already provides several linear algebra graph rewrites in pytensor/tensor/rewriting/linalg.py, including optimizations involving matrix inverse and solve operations.
Extending rewrite coverage to additional equivalent patterns involving matmul, transpose, and inverse would further improve graph optimization and performance for linear algebra workloads.
This enhancement would align with PyTensor’s existing rewrite infrastructure and support more comprehensive optimization of symbolic linear algebra expressions.
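The following self-contained toy rewriter illustrates the kind of pattern matching involved. It works over a hypothetical mini expression tree. The classes `Var`, `Inv`, `Transpose`, `MatMul`, `Solve` and the function `rewrite` are invented for illustration and are not PyTensor's API. An actual implementation would instead hook into the existing rewrite machinery in pytensor/tensor/rewriting/linalg.py.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical toy expression nodes; these are NOT PyTensor classes.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Inv:
    x: "Expr"

@dataclass(frozen=True)
class Transpose:
    x: "Expr"

@dataclass(frozen=True)
class MatMul:
    left: "Expr"
    right: "Expr"

@dataclass(frozen=True)
class Solve:  # Solve(a, b) denotes the x satisfying a @ x == b
    a: "Expr"
    b: "Expr"

Expr = Union[Var, Inv, Transpose, MatMul, Solve]


def rewrite(e: Expr) -> Expr:
    """Rewrite children first, then replace inverse-containing matmuls with solves."""
    if isinstance(e, Var):
        return e
    if isinstance(e, (Inv, Transpose)):
        return type(e)(rewrite(e.x))
    if isinstance(e, Solve):
        return Solve(rewrite(e.a), rewrite(e.b))
    left, right = rewrite(e.left), rewrite(e.right)
    # inv(A) @ B -> solve(A, B)
    if isinstance(left, Inv):
        return Solve(left.x, right)
    # inv(A).T @ B -> solve(A.T, B), because inv(A).T == inv(A.T)
    if isinstance(left, Transpose) and isinstance(left.x, Inv):
        return Solve(Transpose(left.x.x), right)
    # A @ inv(B) -> solve(B.T, A.T).T
    if isinstance(right, Inv):
        return Transpose(Solve(Transpose(right.x), Transpose(left)))
    return MatMul(left, right)
```

For example, `rewrite(MatMul(Var("A"), Inv(Var("B"))))` produces `Transpose(Solve(Transpose(Var("B")), Transpose(Var("A"))))`. The left-inverse rule is tried first, so `inv(A) @ inv(B)` becomes `solve(A, inv(B))`; a fuller version would need a policy for such overlapping matches.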