Missing some simple matrix algebraic simplifications #1479

Open
@ricardoV94

Description

Missing rewrites:

  • A@B + A@C = A@(B+C) (one fewer matmul)
  • (s*A) @ B = s*(A@B) (which can be computed by a single GEMM call)
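A rough operation count shows what each identity saves. The shapes below are an assumption for illustration (the issue does not fix them), and the counts are scalar multiplications of a naive dense product:

```latex
\begin{aligned}
&A \in \mathbb{R}^{m \times k}, \quad B, C \in \mathbb{R}^{k \times n}, \quad s \in \mathbb{R} \\[4pt]
&AB + AC:\; 2mkn \text{ mults} \quad\longrightarrow\quad A(B + C):\; mkn \text{ mults} + kn \text{ adds} \\[4pt]
&(sA)B:\; mk + mkn \text{ mults} \quad\longrightarrow\quad s(AB) = \alpha AB + \beta Z,\quad \alpha = s,\ \beta = 0
\end{aligned}
```

Because a BLAS GEMM computes Z ← αAB + βZ, the scalar can be absorbed into α, so the separate elementwise pass over A disappears.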
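Minimal reproducer:

```python
from pytensor.graph import rewrite_graph
import pytensor.tensor as pt

A, B, C = pt.matrices("ABC")
s = pt.scalar("s")

# Case 1: shared left operand, ideally rewritten to A @ (B + C)
o1 = A @ B + A @ C
rewrite_graph(o1, include=("fast_run",), exclude=("inplace",)).dprint()

print()
# Case 2: scaled left operand, ideally rewritten to s * (A @ B)
o2 = (s * A) @ B
rewrite_graph(o2, include=("fast_run",), exclude=("inplace",)).dprint()
```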
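Output. The first graph still computes two matrix products (`Gemm` adds `A @ C` onto `Dot22(A, B)`), and the second scales `A` elementwise before the `Dot22`:

```text
Gemm{no_inplace} [id A]
 ├─ Dot22 [id B]
 │  ├─ A [id C]
 │  └─ B [id D]
 ├─ 1.0 [id E]
 ├─ A [id C]
 ├─ C [id F]
 └─ 1.0 [id E]

Dot22 [id A]
 ├─ Mul [id B]
 │  ├─ ExpandDims{axes=[0, 1]} [id C]
 │  │  └─ s [id D]
 │  └─ A [id E]
 └─ B [id F]
```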
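To illustrate what the two missing rewrites do, here is a minimal sketch on a hypothetical toy expression tree. The `Var`, `MatMul`, `Add` and `Scale` classes and the `rewrite`, `count_matmuls` and `evaluate` functions are illustrative only, not PyTensor's graph or rewriter API. The sketch applies both identities bottom-up and checks, with small integer matrices, that the rewritten graphs compute the same values:

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical toy IR for illustration; not PyTensor's Variable/Apply classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class MatMul:
    left: "Expr"
    right: "Expr"


@dataclass(frozen=True)
class Add:
    left: "Expr"
    right: "Expr"


@dataclass(frozen=True)
class Scale:
    scalar: str  # name of a scalar variable multiplying `mat`
    mat: "Expr"


Expr = Union[Var, MatMul, Add, Scale]


def rewrite(e: Expr) -> Expr:
    """Rewrite children first, then try the two identities at this node."""
    if isinstance(e, Var):
        return e
    if isinstance(e, Scale):
        return Scale(e.scalar, rewrite(e.mat))
    left, right = rewrite(e.left), rewrite(e.right)
    if isinstance(e, Add):
        # A@B + A@C -> A@(B+C): one matmul instead of two
        if isinstance(left, MatMul) and isinstance(right, MatMul) and left.left == right.left:
            return MatMul(left.left, rewrite(Add(left.right, right.right)))
        return Add(left, right)
    # Remaining case is MatMul: (s*A)@B -> s*(A@B), i.e. one GEMM with alpha=s
    if isinstance(left, Scale):
        return Scale(left.scalar, rewrite(MatMul(left.mat, right)))
    return MatMul(left, right)


def count_matmuls(e: Expr) -> int:
    if isinstance(e, Var):
        return 0
    if isinstance(e, Scale):
        return count_matmuls(e.mat)
    return int(isinstance(e, MatMul)) + count_matmuls(e.left) + count_matmuls(e.right)


def evaluate(e: Expr, env: dict):
    """Evaluate with plain nested lists so the check needs no third-party packages."""
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Scale):
        s = env[e.scalar]
        return [[s * x for x in row] for row in evaluate(e.mat, env)]
    lhs, rhs = evaluate(e.left, env), evaluate(e.right, env)
    if isinstance(e, Add):
        return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(lhs, rhs)]
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*rhs)] for row in lhs]


A, B, C = Var("A"), Var("B"), Var("C")
o1 = Add(MatMul(A, B), MatMul(A, C))
o2 = MatMul(Scale("s", A), B)
r1, r2 = rewrite(o1), rewrite(o2)

env = {"A": [[1, 2], [3, 4]], "B": [[5, 6], [7, 8]], "C": [[1, 0], [0, 1]], "s": 3}
print(count_matmuls(o1), "->", count_matmuls(r1))  # 2 -> 1
print(evaluate(o1, env) == evaluate(r1, env), evaluate(o2, env) == evaluate(r2, env))  # True True
```

Only left-operand factoring is shown; the mirrored case A@C + B@C = (A+B)@C would follow the same pattern.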
