ENH: Extend linear algebra graph rewrites to cover additional inverse and matmul patterns #1893

@aman-coder03

Description

Before

Computation graphs can currently contain expressions that combine matrix inverse and matrix multiplication, such as:
    pt.matmul(pt.linalg.inv(A), B)
    pt.matmul(A, pt.linalg.inv(B))
    pt.linalg.inv(A).T @ b

While PyTensor already includes rewrites like `inv(A) @ b → solve(A, b)` for certain `Dot` patterns, similar optimizations may not apply consistently across all equivalent `matmul`- or transpose-based expressions.
As a result, graphs may still contain explicit matrix inverse operations that could otherwise be rewritten into more efficient, more numerically stable solve-based formulations.
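As a rough illustration of why explicit inverses are undesirable (a NumPy sketch, not PyTensor code — the matrix and sizes are arbitrary choices for demonstration), solving the system directly is both cheaper and more accurate than materializing `inv(A)` and multiplying, especially for ill-conditioned matrices:

```python
import numpy as np

# Ill-conditioned Hilbert matrix: a stress test where the difference
# between inv-then-multiply and a direct solve becomes visible.
n = 12
A = np.array([[1.0 / (i + j + 1) for j in range(n)] for i in range(n)])
x_true = np.ones(n)
b = A @ x_true

x_inv = np.linalg.inv(A) @ b     # explicit inverse, as in matmul(inv(A), b)
x_solve = np.linalg.solve(A, b)  # solve-based formulation

# Residual norms ||A x - b||: the solve-based result is typically
# orders of magnitude smaller than the inverse-based one.
res_inv = np.linalg.norm(A @ x_inv - b)
res_solve = np.linalg.norm(A @ x_solve - b)
```

The direct solve is backward stable, so its residual stays near machine precision, while the residual of the inverse-based route can grow with the condition number of `A`.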

After

Extend the existing linear algebra rewrite system to detect and optimize additional equivalent patterns involving matrix inverse, transpose, and matrix multiplication.
For example, rewrite patterns such as:

    matmul(inv(A), B) → solve(A, B)
    matmul(A, inv(B)) → solve(B.T, A.T).T

These rewrites would improve numerical stability, reduce computational cost, and simplify the computation graph.
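The proposed rewrites can be sanity-checked numerically; the following NumPy sketch (a stand-in for the symbolic PyTensor graphs, with arbitrary well-conditioned test matrices) confirms the algebraic equivalences, including the transpose pattern from the "Before" section:

```python
import numpy as np

rng = np.random.default_rng(0)
# Diagonally shifted random matrices so both are well-conditioned.
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)
B = rng.standard_normal((4, 4)) + 4 * np.eye(4)
b = rng.standard_normal(4)

# matmul(inv(A), B) == solve(A, B)
lhs1 = np.linalg.inv(A) @ B
rhs1 = np.linalg.solve(A, B)

# matmul(A, inv(B)) == solve(B.T, A.T).T, since A B^-1 = (B^-T A^T)^T
lhs2 = A @ np.linalg.inv(B)
rhs2 = np.linalg.solve(B.T, A.T).T

# inv(A).T @ b == solve(A.T, b)
lhs3 = np.linalg.inv(A).T @ b
rhs3 = np.linalg.solve(A.T, b)
```

Each pair agrees to floating-point tolerance, so the rewrites change only how the result is computed, not what is computed.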

Context for the issue:

PyTensor already provides several linear algebra graph rewrites in pytensor/tensor/rewriting/linalg.py, including optimizations involving matrix inverse and solve operations.
Extending rewrite coverage to additional equivalent patterns involving matmul, transpose, and inverse would further improve graph optimization and performance for linear algebra workloads.
This enhancement would align with PyTensor’s existing rewrite infrastructure and support more comprehensive optimization of symbolic linear algebra expressions.
