Skip to content

Commit 45b127d

Browse files
committed
Merge branch 'pr/196'
2 parents 22b5eb0 + bebf376 commit 45b127d

File tree

8 files changed

+1285
-1
lines changed

8 files changed

+1285
-1
lines changed

examples/tridiagonal.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use ndarray::*;
2+
use ndarray_linalg::*;
3+
4+
// Solve `Ax=b` for tridiagonal matrix
5+
fn solve() -> Result<(), error::LinalgError> {
6+
let mut a: Array2<f64> = random((3, 3));
7+
let b: Array1<f64> = random(3);
8+
a[[0, 2]] = 0.0;
9+
a[[2, 0]] = 0.0;
10+
let _x = a.solve_tridiagonal(&b)?;
11+
Ok(())
12+
}
13+
14+
// Solve `Ax=b` for many b with fixed A
15+
fn factorize() -> Result<(), error::LinalgError> {
16+
let mut a: Array2<f64> = random((3, 3));
17+
a[[0, 2]] = 0.0;
18+
a[[2, 0]] = 0.0;
19+
let f = a.factorize_tridiagonal()?; // LU factorize A (A is *not* consumed)
20+
for _ in 0..10 {
21+
let b: Array1<f64> = random(3);
22+
let _x = f.solve_tridiagonal_into(b)?; // solve Ax=b using factorized L, U
23+
}
24+
Ok(())
25+
}
26+
27+
fn main() {
28+
solve().unwrap();
29+
factorize().unwrap();
30+
}

src/error.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ pub enum LinalgError {
1717
InvalidStride { s0: Ixs, s1: Ixs },
1818
/// Memory is not aligned continously
1919
MemoryNotCont,
20+
/// Obj cannot be made from a (rows, cols) matrix
21+
NotStandardShape {
22+
obj: &'static str,
23+
rows: i32,
24+
cols: i32,
25+
},
2026
/// Strides of the array is not supported
2127
Shape(ShapeError),
2228
}
@@ -34,6 +40,11 @@ impl fmt::Display for LinalgError {
3440
write!(f, "invalid stride: s0={}, s1={}", s0, s1)
3541
}
3642
LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"),
43+
LinalgError::NotStandardShape { obj, rows, cols } => write!(
44+
f,
45+
"{} cannot be made from a ({}, {}) matrix",
46+
obj, rows, cols
47+
),
3748
LinalgError::Shape(err) => write!(f, "Shape Error: {}", err),
3849
}
3950
}

src/lapack/mod.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod solveh;
1111
pub mod svd;
1212
pub mod svddc;
1313
pub mod triangular;
14+
pub mod tridiagonal;
1415

1516
pub use self::cholesky::*;
1617
pub use self::eig::*;
@@ -23,6 +24,7 @@ pub use self::solveh::*;
2324
pub use self::svd::*;
2425
pub use self::svddc::*;
2526
pub use self::triangular::*;
27+
pub use self::tridiagonal::*;
2628

2729
use super::error::*;
2830
use super::types::*;
@@ -31,7 +33,17 @@ pub type Pivot = Vec<i32>;
3133

3234
/// Trait for primitive types which implements LAPACK subroutines
3335
pub trait Lapack:
34-
OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eig_ + Eigh_ + Triangular_
36+
OperatorNorm_
37+
+ QR_
38+
+ SVD_
39+
+ SVDDC_
40+
+ Solve_
41+
+ Solveh_
42+
+ Cholesky_
43+
+ Eig_
44+
+ Eigh_
45+
+ Triangular_
46+
+ Tridiagonal_
3547
{
3648
}
3749

src/lapack/tridiagonal.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//! Implement linear solver using LU decomposition
2+
//! for tridiagonal matrix
3+
4+
use lapacke;
5+
use num_traits::Zero;
6+
7+
use super::NormType;
8+
use super::{into_result, Pivot, Transpose};
9+
10+
use crate::error::*;
11+
use crate::layout::MatrixLayout;
12+
use crate::opnorm::*;
13+
use crate::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};
14+
use crate::types::*;
15+
16+
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
17+
pub trait Tridiagonal_: Scalar + Sized {
18+
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
19+
/// partial pivoting with row interchanges.
20+
unsafe fn lu_tridiagonal(a: &mut Tridiagonal<Self>) -> Result<(Vec<Self>, Self::Real, Pivot)>;
21+
/// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
22+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
23+
unsafe fn solve_tridiagonal(
24+
lu: &LUFactorizedTridiagonal<Self>,
25+
bl: MatrixLayout,
26+
t: Transpose,
27+
b: &mut [Self],
28+
) -> Result<()>;
29+
}
30+
31+
macro_rules! impl_tridiagonal {
32+
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
33+
impl Tridiagonal_ for $scalar {
34+
unsafe fn lu_tridiagonal(
35+
a: &mut Tridiagonal<Self>,
36+
) -> Result<(Vec<Self>, Self::Real, Pivot)> {
37+
let (n, _) = a.l.size();
38+
let anom = a.opnorm_one()?;
39+
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
40+
let mut ipiv = vec![0; n as usize];
41+
let info = $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv);
42+
into_result(info, (du2, anom, ipiv))
43+
}
44+
45+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
46+
let (n, _) = lu.a.l.size();
47+
let ipiv = &lu.ipiv;
48+
let anorm = lu.anom;
49+
let mut rcond = Self::Real::zero();
50+
let info = $gtcon(
51+
NormType::One as u8,
52+
n,
53+
&lu.a.dl,
54+
&lu.a.d,
55+
&lu.a.du,
56+
&lu.du2,
57+
ipiv,
58+
anorm,
59+
&mut rcond,
60+
);
61+
into_result(info, rcond)
62+
}
63+
64+
unsafe fn solve_tridiagonal(
65+
lu: &LUFactorizedTridiagonal<Self>,
66+
bl: MatrixLayout,
67+
t: Transpose,
68+
b: &mut [Self],
69+
) -> Result<()> {
70+
let (n, _) = lu.a.l.size();
71+
let (_, nrhs) = bl.size();
72+
let ipiv = &lu.ipiv;
73+
let ldb = bl.lda();
74+
let info = $gttrs(
75+
lu.a.l.lapacke_layout(),
76+
t as u8,
77+
n,
78+
nrhs,
79+
&lu.a.dl,
80+
&lu.a.d,
81+
&lu.a.du,
82+
&lu.du2,
83+
ipiv,
84+
b,
85+
ldb,
86+
);
87+
into_result(info, ())
88+
}
89+
}
90+
};
91+
} // impl_tridiagonal!
92+
93+
impl_tridiagonal!(f64, lapacke::dgttrf, lapacke::dgtcon, lapacke::dgttrs);
94+
impl_tridiagonal!(f32, lapacke::sgttrf, lapacke::sgtcon, lapacke::sgttrs);
95+
impl_tridiagonal!(c64, lapacke::zgttrf, lapacke::zgtcon, lapacke::zgttrs);
96+
impl_tridiagonal!(c32, lapacke::cgttrf, lapacke::cgtcon, lapacke::cgttrs);

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//! - [General matrices](solve/index.html)
1515
//! - [Triangular matrices](triangular/index.html)
1616
//! - [Hermitian/real symmetric matrices](solveh/index.html)
17+
//! - [Tridiagonal matrices](tridiagonal/index.html)
1718
//! - [Inverse matrix computation](solve/trait.Inverse.html)
1819
//!
1920
//! Naming Convention
@@ -67,6 +68,7 @@ pub mod svd;
6768
pub mod svddc;
6869
pub mod trace;
6970
pub mod triangular;
71+
pub mod tridiagonal;
7072
pub mod types;
7173

7274
pub use assert::*;
@@ -90,4 +92,5 @@ pub use svd::*;
9092
pub use svddc::*;
9193
pub use trace::*;
9294
pub use triangular::*;
95+
pub use tridiagonal::*;
9396
pub use types::*;

src/opnorm.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
33
use ndarray::*;
44

5+
use crate::convert::*;
56
use crate::error::*;
67
use crate::layout::*;
8+
use crate::tridiagonal::Tridiagonal;
79
use crate::types::*;
810

911
pub use crate::lapack::NormType;
@@ -46,3 +48,66 @@ where
4648
Ok(unsafe { A::opnorm(t, l, a) })
4749
}
4850
}
51+
52+
impl<A> OperationNorm for Tridiagonal<A>
53+
where
54+
A: Scalar + Lapack,
55+
{
56+
type Output = A::Real;
57+
58+
fn opnorm(&self, t: NormType) -> Result<Self::Output> {
59+
// `self` is a tridiagonal matrix like,
60+
// [d0, u1, 0, ..., 0,
61+
// l1, d1, u2, ...,
62+
// 0, l2, d2,
63+
// ... ..., u{n-1},
64+
// 0, ..., l{n-1}, d{n-1},]
65+
let arr = match t {
66+
// opnorm_one() calculates muximum column sum.
67+
// Therefore, This part align the columns and make a (3 x n) matrix like,
68+
// [ 0, u1, u2, ..., u{n-1},
69+
// d0, d1, d2, ..., d{n-1},
70+
// l1, l2, l3, ..., 0,]
71+
NormType::One => {
72+
let zl: Array1<A> = Array::zeros(1);
73+
let zu: Array1<A> = Array::zeros(1);
74+
let dl = stack![Axis(0), self.dl.to_owned(), zl];
75+
let du = stack![Axis(0), zu, self.du.to_owned()];
76+
let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)];
77+
arr
78+
}
79+
// opnorm_inf() calculates muximum row sum.
80+
// Therefore, This part align the rows and make a (n x 3) matrix like,
81+
// [ 0, d0, u1,
82+
// l1, d1, u2,
83+
// l2, d2, u3,
84+
// ..., ..., ...,
85+
// l{n-1}, d{n-1}, 0,]
86+
NormType::Infinity => {
87+
let zl: Array1<A> = Array::zeros(1);
88+
let zu: Array1<A> = Array::zeros(1);
89+
let dl = stack![Axis(0), zl, self.dl.to_owned()];
90+
let du = stack![Axis(0), self.du.to_owned(), zu];
91+
let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)];
92+
arr
93+
}
94+
// opnorm_fro() calculates square root of sum of squares.
95+
// Because it is independent of the shape of matrix,
96+
// this part make a (1 x (3n-2)) matrix like,
97+
// [l1, ..., l{n-1}, d0, ..., d{n-1}, u1, ..., u{n-1}]
98+
NormType::Frobenius => {
99+
let arr = stack![
100+
Axis(1),
101+
into_row(arr1(&self.dl)),
102+
into_row(arr1(&self.d)),
103+
into_row(arr1(&self.du))
104+
];
105+
arr
106+
}
107+
};
108+
109+
let l = arr.layout()?;
110+
let a = arr.as_allocated()?;
111+
Ok(unsafe { A::opnorm(t, l, a) })
112+
}
113+
}

0 commit comments

Comments
 (0)