Skip to content

Commit 70da029

Browse files
committed
Decompose Tridiagonal Solve into core steps
1 parent 148477c commit 70da029

File tree

8 files changed

+434
-30
lines changed

8 files changed

+434
-30
lines changed

pytensor/compile/mode.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,9 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
477477
"fusion",
478478
"inplace",
479479
"scan_save_mem_prealloc",
480+
# There are specific variants for the LU decompositions supported by JAX
481+
"reuse_lu_decomposition_multiple_solves",
482+
"scan_split_non_sequence_lu_decomposition_solve",
480483
],
481484
),
482485
)

pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy import ndarray
77
from scipy import linalg
88

9+
from pytensor.link.numba.dispatch import numba_funcify
910
from pytensor.link.numba.dispatch.basic import numba_njit
1011
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1112
_LAPACK,
@@ -20,6 +21,10 @@
2021
_solve_check,
2122
_trans_char_to_int,
2223
)
24+
from pytensor.tensor._linalg.solve.tridiagonal import (
25+
LUFactorTridiagonal,
26+
SolveLUFactorTridiagonal,
27+
)
2328

2429

2530
@numba_njit
@@ -297,3 +302,48 @@ def impl(
297302
return X
298303

299304
return impl
305+
306+
307+
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Build the Numba implementation of a ``LUFactorTridiagonal`` Op.

    The three ``overwrite_*`` flags are read from ``op`` once and captured by
    the jitted closure. The closure copies every diagonal whose flag is false
    before passing the diagonals to ``_gttrf``.

    Returns
    -------
    The jitted function ``lu_factor_tridiagonal(dl, d, du)`` which returns
    ``(dl, d, du, du2, ipiv)``.
    """
    may_clobber_dl = op.overwrite_dl
    may_clobber_d = op.overwrite_d
    may_clobber_du = op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # Work on private copies of the diagonals that must be preserved
        if not may_clobber_dl:
            dl = dl.copy()
        if not may_clobber_d:
            d = d.copy()
        if not may_clobber_du:
            du = du.copy()

        # The sixth value returned by _gttrf is discarded
        dl, d, du, du2, ipiv, _ = _gttrf(dl, d, du)
        return dl, d, du, du2, ipiv

    return lu_factor_tridiagonal
326+
327+
328+
@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Build the Numba implementation of a ``SolveLUFactorTridiagonal`` Op.

    ``op.overwrite_b`` and ``op.transposed`` are read once and captured by the
    jitted closure, which forwards them to ``_gttrs`` as ``overwrite_b`` and
    ``trans`` respectively.

    Returns
    -------
    The jitted function ``solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b)``
    which returns the first value produced by ``_gttrs``.
    """
    may_clobber_b = op.overwrite_b
    use_transpose = op.transposed

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        # The second value returned by _gttrs is discarded
        solution, _ = _gttrs(
            dl, d, du, du2, ipiv, b, overwrite_b=may_clobber_b, trans=use_transpose
        )
        return solution

    return solve_lu_factor_tridiagonal

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
from collections.abc import Container
22
from copy import copy
33

4+
from pytensor.compile import optdb
45
from pytensor.graph import Constant, graph_inputs
56
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
67
from pytensor.scan.op import Scan
78
from pytensor.scan.rewriting import scan_seqopt1
9+
from pytensor.tensor._linalg.solve.tridiagonal import (
10+
tridiagonal_lu_factor,
11+
tridiagonal_lu_solve,
12+
)
813
from pytensor.tensor.basic import atleast_Nd
914
from pytensor.tensor.blockwise import Blockwise
1015
from pytensor.tensor.elemwise import DimShuffle
@@ -17,18 +22,32 @@
1722
def decompose_A(A, assume_a, check_finite):
    """Return the decomposition of ``A`` selected by ``assume_a``.

    Parameters
    ----------
    A
        Matrix to decompose; handed unchanged to the factorization helper.
    assume_a
        ``"gen"`` dispatches to ``lu_factor`` and ``"tridiagonal"`` dispatches
        to ``tridiagonal_lu_factor``.
    check_finite
        Forwarded to ``lu_factor`` only.

    Returns
    -------
    Whatever the selected factorization helper returns.

    Raises
    ------
    NotImplementedError
        If ``assume_a`` is neither ``"gen"`` nor ``"tridiagonal"``.
    """
    if assume_a == "gen":
        return lu_factor(A, check_finite=check_finite)
    if assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU factorization
        return tridiagonal_lu_factor(A)
    raise NotImplementedError(
        f"LU decomposition is not implemented for assume_a={assume_a!r}"
    )
2230

2331

2432
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
    """Solve a system from an already computed decomposition of ``A``.

    Parameters
    ----------
    A_decomp
        Decomposition handed unchanged to the solver selected by
        ``core_solve_op.assume_a``.
    b
        Right-hand side.
    transposed
        Forwarded as ``trans`` to ``lu_solve`` or as ``transposed`` to
        ``tridiagonal_lu_solve``.
    core_solve_op
        Solve Op whose ``assume_a``, ``b_ndim`` and ``check_finite`` attributes
        configure the call.

    Returns
    -------
    Whatever the selected solver returns.

    Raises
    ------
    NotImplementedError
        If ``core_solve_op.assume_a`` is neither ``"gen"`` nor ``"tridiagonal"``.
    """
    assume_a = core_solve_op.assume_a
    b_ndim = core_solve_op.b_ndim
    check_finite = core_solve_op.check_finite

    if assume_a == "gen":
        return lu_solve(
            A_decomp,
            b,
            b_ndim=b_ndim,
            trans=transposed,
            check_finite=check_finite,
        )
    if assume_a == "tridiagonal":
        # We didn't implement check_finite for tridiagonal LU solve
        return tridiagonal_lu_solve(
            A_decomp,
            b,
            b_ndim=b_ndim,
            transposed=transposed,
        )
    raise NotImplementedError(
        f"Solving from an LU decomposition is not implemented for "
        f"assume_a={assume_a!r}"
    )
@@ -189,13 +208,15 @@ def _scan_split_non_sequence_lu_decomposition_solve(
189208
@register_specialize
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves(fgraph, node):
    """Delegate to ``_split_lu_solve_steps`` in non-eager mode.

    The call passes ``allowed_assume_a={"gen", "tridiagonal"}``.
    """
    supported_assume_a = {"gen", "tridiagonal"}
    return _split_lu_solve_steps(
        fgraph,
        node,
        eager=False,
        allowed_assume_a=supported_assume_a,
    )
193214

194215

195216
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
    """Delegate to ``_scan_split_non_sequence_lu_decomposition_solve``.

    The call passes ``allowed_assume_a={"gen", "tridiagonal"}``.
    """
    supported_assume_a = {"gen", "tridiagonal"}
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph,
        node,
        allowed_assume_a=supported_assume_a,
    )
200221

201222

@@ -207,3 +228,32 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node):
207228
"scan_pushout",
208229
position=2,
209230
)
231+
232+
233+
@node_rewriter([Blockwise])
def reuse_lu_decomposition_multiple_solves_jax(fgraph, node):
    """Delegate to ``_split_lu_solve_steps`` in non-eager mode, ``"gen"`` only.

    Unlike ``reuse_lu_decomposition_multiple_solves``, the allowed set passed
    here does not contain ``"tridiagonal"``.
    """
    gen_only = {"gen"}
    return _split_lu_solve_steps(
        fgraph,
        node,
        eager=False,
        allowed_assume_a=gen_only,
    )


# Registered with the single tag "jax"; the database name is not added as a tag.
optdb["specialize"].register(
    reuse_lu_decomposition_multiple_solves_jax.__name__,
    in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
)
244+
245+
246+
@node_rewriter([Scan])
def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node):
    """Delegate to ``_scan_split_non_sequence_lu_decomposition_solve``, ``"gen"`` only.

    Unlike ``scan_split_non_sequence_lu_decomposition_solve``, the allowed set
    passed here does not contain ``"tridiagonal"``.
    """
    gen_only = {"gen"}
    return _scan_split_non_sequence_lu_decomposition_solve(
        fgraph,
        node,
        allowed_assume_a=gen_only,
    )


# Registered with the single tag "jax" at position 2; the database name is not
# added as a tag.
scan_seqopt1.register(
    scan_split_non_sequence_lu_decomposition_solve_jax.__name__,
    in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True),
    "jax",
    use_db_name_as_tag=False,
    position=2,
)
pytensor/tensor/_linalg/solve/tridiagonal.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import numpy as np
2+
from scipy.linalg import get_lapack_funcs
3+
4+
from pytensor.graph import Apply, Op
5+
from pytensor.tensor.basic import as_tensor, diagonal
6+
from pytensor.tensor.blockwise import Blockwise
7+
from pytensor.tensor.type import tensor, vector
8+
9+
10+
class LUFactorTridiagonal(Op):
    """Compute LU factorization of a tridiagonal matrix (lapack gttrf).

    The Op takes three vectors ``dl``, ``d`` and ``du`` and produces five
    vectors ``dl``, ``d``, ``du``, ``du2`` and ``ipiv``, i.e. the first five
    values returned by scipy's ``gttrf`` wrapper. The sixth value returned by
    ``gttrf`` is discarded.
    """

    __props__ = (
        "overwrite_dl",
        "overwrite_d",
        "overwrite_du",
    )
    gufunc_signature = "(dl),(d),(dl)->(dl),(d),(dl),(du2),(d)"

    def __init__(self, overwrite_dl=False, overwrite_d=False, overwrite_du=False):
        overwrite_flags = (overwrite_dl, overwrite_d, overwrite_du)
        # Output i is declared to destroy input i when that input may be overwritten
        self.destroy_map = {
            idx: [idx] for idx, flag in enumerate(overwrite_flags) if flag
        }
        self.overwrite_dl = overwrite_dl
        self.overwrite_d = overwrite_d
        self.overwrite_du = overwrite_du
        super().__init__()

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return a new Op that overwrites exactly the inputs listed as allowed."""
        may_overwrite_dl, may_overwrite_d, may_overwrite_du = (
            idx in allowed_inplace_inputs for idx in range(3)
        )
        return type(self)(
            overwrite_dl=may_overwrite_dl,
            overwrite_d=may_overwrite_d,
            overwrite_du=may_overwrite_du,
        )

    def make_node(self, dl, d, du):
        """Build the Apply node, inferring static output shapes where possible.

        Raises
        ------
        ValueError
            If any of the three inputs is not a vector.
        """
        dl, d, du = (as_tensor(diag) for diag in (dl, d, du))
        diagonals = (dl, d, du)

        if any(diag.type.ndim != 1 for diag in diagonals):
            raise ValueError("Diagonals must be vectors")

        # n comes from the first diagonal with a known static length; the
        # off-diagonals are treated as one entry shorter than the main diagonal
        n = None
        for diag, shortfall in zip(diagonals, (1, 0, 1)):
            static_len = diag.type.shape[-1]
            if static_len is not None:
                n = static_len + shortfall
                break
        off_diag_len = None if n is None else n - 1
        du2_len = None if n is None else n - 2

        # Zero-dimensional placeholders let scipy choose the routine's dtype
        dummy_arrays = [np.zeros((), dtype=diag.type.dtype) for diag in diagonals]
        out_dtype = get_lapack_funcs("gttrf", dummy_arrays).dtype
        outputs = [
            vector(shape=(off_diag_len,), dtype=out_dtype),
            vector(shape=(n,), dtype=out_dtype),
            vector(shape=(off_diag_len,), dtype=out_dtype),
            vector(shape=(du2_len,), dtype=out_dtype),
            vector(shape=(n,), dtype=np.int32),
        ]
        return Apply(self, [dl, d, du], outputs)

    def perform(self, node, inputs, output_storage):
        """Run scipy's ``gttrf`` and store its first five results as outputs."""
        gttrf = get_lapack_funcs("gttrf", dtype=node.outputs[0].type.dtype)
        # The trailing value returned by gttrf is not an output of this Op
        *factors, _ = gttrf(
            *inputs,
            overwrite_dl=self.overwrite_dl,
            overwrite_d=self.overwrite_d,
            overwrite_du=self.overwrite_du,
        )
        for storage, factor in zip(output_storage, factors):
            storage[0] = factor
78+
79+
80+
class SolveLUFactorTridiagonal(Op):
    """Solve a system of linear equations with a tridiagonal coefficient matrix (lapack gttrs).

    The Op takes the five vectors ``dl``, ``d``, ``du``, ``du2`` and ``ipiv``
    followed by the right-hand side ``b``, and returns the first value produced
    by scipy's ``gttrs`` wrapper.

    Parameters
    ----------
    b_ndim
        Number of dimensions of ``b``; must be 1 or 2.
    transposed
        When true, ``gttrs`` is called with ``trans="T"``, otherwise ``"N"``.
    overwrite_b
        When true, the output is declared to destroy input ``b`` and
        ``overwrite_b=True`` is passed to ``gttrs``.
    """

    __props__ = ("b_ndim", "overwrite_b", "transposed")

    def __init__(self, b_ndim: int, transposed: bool, overwrite_b=False):
        if b_ndim not in (1, 2):
            raise ValueError("b_ndim must be 1 or 2")
        if b_ndim == 1:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d)->(d)"
        else:
            self.gufunc_signature = "(dl),(d),(dl),(du2),(d),(d,rhs)->(d,rhs)"
        if overwrite_b:
            # Output 0 reuses the buffer of input 5 (b)
            self.destroy_map = {0: [5]}
        self.b_ndim = b_ndim
        self.transposed = transposed
        self.overwrite_b = overwrite_b
        super().__init__()

    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        """Return an Op that overwrites ``b`` when input 5 may be destroyed."""
        if 5 not in allowed_inplace_inputs:
            return self

        props = self._props_dict()
        props["overwrite_b"] = True
        return type(self)(**props)

    def make_node(self, dl, d, du, du2, ipiv, b):
        """Build the Apply node, inferring the static output shape where possible.

        Raises
        ------
        ValueError
            If ``b`` does not have ``b_ndim`` dimensions, or if any of the
            other five inputs is not a vector.
        """
        dl, d, du, du2, ipiv, b = map(as_tensor, (dl, d, du, du2, ipiv, b))
        factors = (dl, d, du, du2, ipiv)

        if b.type.ndim != self.b_ndim:
            raise ValueError("Wrong number of dimensions for input b.")

        if not all(inp.type.ndim == 1 for inp in factors):
            raise ValueError("Inputs must be vectors")

        # n comes from the first factor with a known static length, corrected
        # by how much shorter than n that factor is treated as being. The
        # leading dimension of b is the fallback when no factor length is known.
        n = b.type.shape[0]
        for inp, shortfall in zip(factors, (1, 0, 1, 2, 0)):
            static_len = inp.type.shape[-1]
            if static_len is not None:
                n = static_len + shortfall
                break

        # Zero-dimensional placeholders let scipy choose the routine's dtype
        dummy_arrays = [np.zeros((), dtype=inp.type.dtype) for inp in factors]
        # NOTE(review): the result seems to always be float64. Presumably this
        # is because the integer ipiv takes part in the lookup while b does
        # not - confirm whether that is intended.
        out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
        if self.b_ndim == 1:
            output_shape = (n,)
        else:
            output_shape = (n, b.type.shape[-1])

        outputs = [tensor(shape=output_shape, dtype=out_dtype)]
        return Apply(self, [dl, d, du, du2, ipiv, b], outputs)

    def perform(self, node, inputs, output_storage):
        """Run scipy's ``gttrs`` and store its first result as the output."""
        gttrs = get_lapack_funcs("gttrs", dtype=node.outputs[0].type.dtype)
        # The second value returned by gttrs is discarded
        x, _ = gttrs(
            *inputs,
            overwrite_b=self.overwrite_b,
            trans="T" if self.transposed else "N",
        )
        output_storage[0][0] = x
158+
159+
160+
def tridiagonal_lu_factor(a):
    """Return the decomposition of ``a`` implied by a tridiagonal solve.

    The diagonals at offsets -1, 0 and 1 of the last two axes of ``a`` are
    extracted and passed to a ``Blockwise``-wrapped ``LUFactorTridiagonal``.

    Returns
    -------
    The tuple ``(dl, d, du, du2, ipiv)`` of that Op's five outputs.
    """
    lower = diagonal(a, offset=-1, axis1=-2, axis2=-1)
    main = diagonal(a, offset=0, axis1=-2, axis2=-1)
    upper = diagonal(a, offset=1, axis1=-2, axis2=-1)

    factor_op = Blockwise(LUFactorTridiagonal())
    dl, d, du, du2, ipiv = factor_op(lower, main, upper)
    return dl, d, du, du2, ipiv
165+
166+
167+
def tridiagonal_lu_solve(a_diagonals, b, *, b_ndim: int, transposed: bool = False):
    """Solve against ``b`` using the five arrays in ``a_diagonals``.

    Parameters
    ----------
    a_diagonals
        Iterable unpacked as ``(dl, d, du, du2, ipiv)``.
    b
        Right-hand side, passed as the last input of the Op.
    b_ndim, transposed
        Forwarded to the ``SolveLUFactorTridiagonal`` constructor.

    Returns
    -------
    The result of calling the ``Blockwise``-wrapped ``SolveLUFactorTridiagonal``.
    """
    dl, d, du, du2, ipiv = a_diagonals
    core_op = SolveLUFactorTridiagonal(b_ndim=b_ndim, transposed=transposed)
    return Blockwise(core_op)(dl, d, du, du2, ipiv, b)

0 commit comments

Comments
 (0)