Skip to content

Commit acd7858

Browse files
authored
Merge pull request #346 from rust-ndarray/lax-solve-impl
Merge `Solve_`, `Solveh_` and `Cholesky_` into `Lapack` trait
2 parents 07ab31d + 7e61539 commit acd7858

File tree

4 files changed

+494
-341
lines changed

4 files changed

+494
-341
lines changed

Diff for: lax/src/cholesky.rs

+63-72
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,25 @@
1+
//! Factorize positive-definite symmetric/Hermitian matrices using Cholesky algorithm
2+
13
use super::*;
24
use crate::{error::*, layout::*};
35
use cauchy::*;
46

5-
#[cfg_attr(doc, katexit::katexit)]
6-
/// Solve symmetric/hermite positive-definite linear equations using Cholesky decomposition
7-
///
8-
/// For a given positive definite matrix $A$,
9-
/// Cholesky decomposition is described as $A = U^T U$ or $A = LL^T$ where
7+
/// Compute Cholesky decomposition according to [UPLO]
108
///
11-
/// - $L$ is lower matrix
12-
/// - $U$ is upper matrix
9+
/// LAPACK correspondance
10+
/// ----------------------
1311
///
14-
/// This is designed as two step computation according to LAPACK API
12+
/// | f32 | f64 | c32 | c64 |
13+
/// |:-------|:-------|:-------|:-------|
14+
/// | spotrf | dpotrf | cpotrf | zpotrf |
1515
///
16-
/// 1. Factorize input matrix $A$ into $L$ or $U$
17-
/// 2. Solve linear equation $Ax = b$ or compute inverse matrix $A^{-1}$
18-
/// using $U$ or $L$.
19-
pub trait Cholesky_: Sized {
20-
/// Compute Cholesky decomposition $A = U^T U$ or $A = L L^T$ according to [UPLO]
21-
///
22-
/// LAPACK correspondance
23-
/// ----------------------
24-
///
25-
/// | f32 | f64 | c32 | c64 |
26-
/// |:-------|:-------|:-------|:-------|
27-
/// | spotrf | dpotrf | cpotrf | zpotrf |
28-
///
16+
pub trait CholeskyImpl: Scalar {
2917
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
30-
31-
/// Compute inverse matrix $A^{-1}$ using $U$ or $L$
32-
///
33-
/// LAPACK correspondance
34-
/// ----------------------
35-
///
36-
/// | f32 | f64 | c32 | c64 |
37-
/// |:-------|:-------|:-------|:-------|
38-
/// | spotri | dpotri | cpotri | zpotri |
39-
///
40-
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
41-
42-
/// Solve linear equation $Ax = b$ using $U$ or $L$
43-
///
44-
/// LAPACK correspondance
45-
/// ----------------------
46-
///
47-
/// | f32 | f64 | c32 | c64 |
48-
/// |:-------|:-------|:-------|:-------|
49-
/// | spotrs | dpotrs | cpotrs | zpotrs |
50-
///
51-
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
5218
}
5319

54-
macro_rules! impl_cholesky {
55-
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
56-
impl Cholesky_ for $scalar {
20+
macro_rules! impl_cholesky_ {
21+
($s:ty, $trf:path) => {
22+
impl CholeskyImpl for $s {
5723
fn cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
5824
let (n, _) = l.size();
5925
if matches!(l, MatrixLayout::C { .. }) {
@@ -69,7 +35,30 @@ macro_rules! impl_cholesky {
6935
}
7036
Ok(())
7137
}
38+
}
39+
};
40+
}
41+
impl_cholesky_!(c64, lapack_sys::zpotrf_);
42+
impl_cholesky_!(c32, lapack_sys::cpotrf_);
43+
impl_cholesky_!(f64, lapack_sys::dpotrf_);
44+
impl_cholesky_!(f32, lapack_sys::spotrf_);
45+
46+
/// Compute inverse matrix using Cholesky factroization result
47+
///
48+
/// LAPACK correspondance
49+
/// ----------------------
50+
///
51+
/// | f32 | f64 | c32 | c64 |
52+
/// |:-------|:-------|:-------|:-------|
53+
/// | spotri | dpotri | cpotri | zpotri |
54+
///
55+
pub trait InvCholeskyImpl: Scalar {
56+
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()>;
57+
}
7258

59+
macro_rules! impl_inv_cholesky {
60+
($s:ty, $tri:path) => {
61+
impl InvCholeskyImpl for $s {
7362
fn inv_cholesky(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<()> {
7463
let (n, _) = l.size();
7564
if matches!(l, MatrixLayout::C { .. }) {
@@ -85,7 +74,30 @@ macro_rules! impl_cholesky {
8574
}
8675
Ok(())
8776
}
77+
}
78+
};
79+
}
80+
impl_inv_cholesky!(c64, lapack_sys::zpotri_);
81+
impl_inv_cholesky!(c32, lapack_sys::cpotri_);
82+
impl_inv_cholesky!(f64, lapack_sys::dpotri_);
83+
impl_inv_cholesky!(f32, lapack_sys::spotri_);
8884

85+
/// Solve linear equation using Cholesky factroization result
86+
///
87+
/// LAPACK correspondance
88+
/// ----------------------
89+
///
90+
/// | f32 | f64 | c32 | c64 |
91+
/// |:-------|:-------|:-------|:-------|
92+
/// | spotrs | dpotrs | cpotrs | zpotrs |
93+
///
94+
pub trait SolveCholeskyImpl: Scalar {
95+
fn solve_cholesky(l: MatrixLayout, uplo: UPLO, a: &[Self], b: &mut [Self]) -> Result<()>;
96+
}
97+
98+
macro_rules! impl_solve_cholesky {
99+
($s:ty, $trs:path) => {
100+
impl SolveCholeskyImpl for $s {
89101
fn solve_cholesky(
90102
l: MatrixLayout,
91103
mut uplo: UPLO,
@@ -123,29 +135,8 @@ macro_rules! impl_cholesky {
123135
}
124136
}
125137
};
126-
} // end macro_rules
127-
128-
impl_cholesky!(
129-
f64,
130-
lapack_sys::dpotrf_,
131-
lapack_sys::dpotri_,
132-
lapack_sys::dpotrs_
133-
);
134-
impl_cholesky!(
135-
f32,
136-
lapack_sys::spotrf_,
137-
lapack_sys::spotri_,
138-
lapack_sys::spotrs_
139-
);
140-
impl_cholesky!(
141-
c64,
142-
lapack_sys::zpotrf_,
143-
lapack_sys::zpotri_,
144-
lapack_sys::zpotrs_
145-
);
146-
impl_cholesky!(
147-
c32,
148-
lapack_sys::cpotrf_,
149-
lapack_sys::cpotri_,
150-
lapack_sys::cpotrs_
151-
);
138+
}
139+
impl_solve_cholesky!(c64, lapack_sys::zpotrs_);
140+
impl_solve_cholesky!(c32, lapack_sys::cpotrs_);
141+
impl_solve_cholesky!(f64, lapack_sys::dpotrs_);
142+
impl_solve_cholesky!(f32, lapack_sys::spotrs_);

0 commit comments

Comments
 (0)