We're missing this simple rewrite of `log(gamma(x))` into `gammaln(x)`:
```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
out = pt.log(pt.gamma(x))
new_out = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
new_out.dprint()
```
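To show why this rewrite matters numerically, here is the same computation in plain Python (standard library only, not PyTensor). `math.gamma` overflows just above x ≈ 171.6, so evaluating `log(gamma(x))` literally loses the result there, while `math.lgamma`, the standard library's log-gamma, stays finite. `naive_log_gamma` is a hypothetical helper name used only for this illustration.

```python
import math


def naive_log_gamma(x: float) -> float:
    """log(gamma(x)) evaluated literally: gamma first, then log."""
    try:
        return math.log(math.gamma(x))
    except OverflowError:
        # math.gamma raises instead of returning inf once gamma(x)
        # exceeds the largest double (just above x = 171.6).
        return math.inf


for x in (5.0, 100.0, 171.0, 172.0, 1000.0):
    print(f"x={x:7.1f}  log(gamma(x))={naive_log_gamma(x):.6f}  lgamma(x)={math.lgamma(x):.6f}")
```

For x up to 171 both columns agree; from 172 on, only `lgamma` gives a usable value.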
This can be done easily with `PatternNodeRewriter`, as in `pytensor/pytensor/tensor/rewriting/math.py` (lines 3646 to 3651 in 911c6a3):
```python
local_polygamma_to_tri_gamma = PatternNodeRewriter(
    (polygamma, 1, "x"),
    (tri_gamma, "x"),
    allow_multiple_clients=True,
    name="local_polygamma_to_tri_gamma",
)
```
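`PatternNodeRewriter` describes a rewrite as an input and an output pattern written as nested `(op, *inputs)` tuples, with strings acting as pattern variables, as in `local_polygamma_to_tri_gamma`. The rewrite requested here would presumably use `(log, (gamma, "x"))` as the input pattern and `(gammaln, "x")` as the output. To illustrate how such a tuple pattern matches, below is a toy matcher over plain Python tuples. It is not PyTensor's implementation: ops are represented by hypothetical name strings rather than `Op` instances, and `match`, `substitute`, `rewrite`, and `RULES` are illustrative names.

```python
def match(pattern, expr, bindings):
    """Match `expr` against `pattern`; return the variable bindings or None.

    Patterns are (op_name, *sub_patterns) tuples; a bare string is a
    pattern variable that binds to any subexpression."""
    if isinstance(pattern, str):
        if pattern in bindings and bindings[pattern] != expr:
            return None  # a repeated variable must bind the same subexpression
        return {**bindings, pattern: expr}
    if not isinstance(pattern, tuple):
        return bindings if pattern == expr else None  # literal constant, e.g. 1
    if not (isinstance(expr, tuple) and len(expr) == len(pattern) and expr[0] == pattern[0]):
        return None
    for sub_pattern, sub_expr in zip(pattern[1:], expr[1:]):
        bindings = match(sub_pattern, sub_expr, bindings)
        if bindings is None:
            return None
    return bindings


def substitute(template, bindings):
    """Build the output expression, replacing pattern variables by their bindings."""
    if isinstance(template, str):
        return bindings[template]
    if isinstance(template, tuple):
        return (template[0], *(substitute(t, bindings) for t in template[1:]))
    return template


def rewrite(expr, rules):
    """Apply the first matching rule at every node, bottom-up, in a single pass."""
    if isinstance(expr, tuple):
        expr = (expr[0], *(rewrite(e, rules) for e in expr[1:]))
    for in_pattern, out_pattern in rules:
        bindings = match(in_pattern, expr, {})
        if bindings is not None:
            return substitute(out_pattern, bindings)
    return expr


RULES = [
    (("log", ("gamma", "x")), ("gammaln", "x")),        # the missing rewrite
    (("polygamma", 1, "x"), ("tri_gamma", "x")),        # mirrors the existing one
]

print(rewrite(("log", ("gamma", "a")), RULES))  # ('gammaln', 'a')
```

A fixed pattern like this only fires on the exact shape `log(gamma(...))`; it cannot see through products or quotients of gammas.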
We could also add rewrites for common combinatorics expressions, such as turning the naive log-beta into its stable form:

```python
naive_betaln = pt.log((pt.gamma(x) * pt.gamma(y)) / pt.gamma(x + y))
betaln = pt.gammaln(x) + pt.gammaln(y) - pt.gammaln(x + y)
```
`betaln` already uses the stable form in `pytensor/pytensor/tensor/special.py` (lines 799 to 804 in ad55b69):
```python
def betaln(a, b):
    """
    Log beta function.
    """
    return gammaln(a) + gammaln(b) - gammaln(a + b)
```
Or for `log(poch)`, where `poch` is defined in `pytensor/pytensor/tensor/special.py` (lines 767 to 772 in ad55b69):
```python
def poch(z, m):
    """
    Pochhammer symbol (rising factorial) function.
    """
    return gamma(z + m) / gamma(z)
```
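Both identities, `log(B(a, b)) = gammaln(a) + gammaln(b) - gammaln(a + b)` and `log(poch(z, m)) = gammaln(z + m) - gammaln(z)`, can be sanity-checked with the Python standard library (this is not PyTensor code). `naive_log_ratio`, `betaln`, and `log_poch` below are illustrative helpers, with `math.lgamma` standing in for `gammaln`: where the naive forms are finite they agree with the log-gamma forms, but they fail as soon as any intermediate `gamma` overflows.

```python
import math


def naive_log_ratio(num, den):
    """log(prod(gamma(a) for a in num) / prod(gamma(b) for b in den)), computed literally."""
    try:
        value = math.prod(math.gamma(a) for a in num) / math.prod(math.gamma(b) for b in den)
    except OverflowError:
        return math.nan  # math.gamma itself overflowed
    if not math.isfinite(value) or value <= 0.0:
        return math.nan  # the product or quotient overflowed or underflowed
    return math.log(value)


def betaln(a, b):
    # Same formula as the PyTensor betaln above, using the stdlib log-gamma.
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_poch(z, m):
    # log(gamma(z + m) / gamma(z)) as a difference of log-gammas.
    return math.lgamma(z + m) - math.lgamma(z)


print(naive_log_ratio([2.5, 3.5], [6.0]), betaln(2.5, 3.5))      # agree
print(naive_log_ratio([100.0, 100.0], [200.0]), betaln(100.0, 100.0))  # nan vs finite
print(naive_log_ratio([300.0], [250.0]), log_poch(250.0, 50.0))  # nan vs finite
```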
For these more general cases we can probably use something more flexible than `PatternNodeRewriter`: the rewrite should apply whenever all the terms inside the `log` are known to be factorials, gammas, or exps (positive quantities that easily overflow). This is a narrower, easier subset of #177.
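A minimal sketch of that more flexible approach, under stated assumptions: expressions are toy `(op, *args)` tuples rather than PyTensor graphs, and `log` is pushed through products and quotients only when every factor is a `gamma`, `factorial`, or `exp` term whose argument is assumed positive (gamma can be negative for some x < 0, so a real rewrite would need that guarantee). The op names and the helpers `log_of_positive` and `rewrite_log` are hypothetical.

```python
from typing import Optional, Union

# Toy expression tree: a leaf (variable name or number) or a tuple (op, *args).
Expr = Union[str, float, tuple]


def log_of_positive(expr: Expr) -> Optional[Expr]:
    """Return an expression equal to log(expr) that never materializes expr,
    or None if expr is not a product/quotient of known-positive terms."""
    if not isinstance(expr, tuple):
        return None  # bare leaf: positivity unknown, so give up
    op, *args = expr
    if op == "gamma":
        return ("gammaln", args[0])  # assumes the argument is > 0
    if op == "factorial":
        return ("gammaln", ("add", args[0], 1))  # n! == gamma(n + 1)
    if op == "exp":
        return args[0]  # log(exp(a)) == a
    if op in ("mul", "div"):
        parts = [log_of_positive(arg) for arg in args]
        if any(part is None for part in parts):
            return None  # one unknown factor blocks the whole rewrite
        # log(a * b) == log(a) + log(b); log(a / b) == log(a) - log(b)
        return ("add" if op == "mul" else "sub", *parts)
    return None


def rewrite_log(expr: Expr) -> Expr:
    """Rewrite ("log", inner) into a stable form when every factor of inner
    is a gamma/factorial/exp term; otherwise return expr unchanged."""
    if isinstance(expr, tuple) and expr[0] == "log":
        stable = log_of_positive(expr[1])
        if stable is not None:
            return stable
    return expr


naive_betaln = ("log", ("div", ("mul", ("gamma", "x"), ("gamma", "y")), ("gamma", ("add", "x", "y"))))
print(rewrite_log(naive_betaln))
```

With this, the naive log-beta and `log(poch)` both collapse to sums and differences of `gammaln` terms, while anything containing a factor of unknown sign is left untouched. In PyTensor itself this logic would presumably live in a custom node rewriter on `log` rather than in a fixed pattern.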