Skip to content

Commit

Permalink
Relax type bounds for LeastSquaresSvd family
Browse files Browse the repository at this point in the history
  • Loading branch information
janmarthedal committed Apr 5, 2021
1 parent 28c7860 commit a23224f
Showing 1 changed file with 67 additions and 32 deletions.
99 changes: 67 additions & 32 deletions ndarray-linalg/src/least_squares.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,13 @@ where

/// Solve least squares for immutable references and a single
/// column vector as a right-hand side.
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvd<D, E, Ix1> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: Data<Elem = E>,
D1: Data<Elem = E>,
D2: Data<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(&rhs)`, where `rhs` is a
Expand All @@ -163,7 +164,7 @@ where
/// `A` and `rhs` must have the same layout, i.e. they must
/// be both either row- or column-major format, otherwise a
/// `IncompatibleShape` error is raised.
fn least_squares(&self, rhs: &ArrayBase<D, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
fn least_squares(&self, rhs: &ArrayBase<D2, Ix1>) -> Result<LeastSquaresResult<E, Ix1>> {
let a = self.to_owned();
let b = rhs.to_owned();
a.least_squares_into(b)
Expand All @@ -172,12 +173,13 @@ where

/// Solve least squares for immutable references and matrix
/// (=mulitipe vectors) as a right-hand side.
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvd<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvd<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: Data<Elem = E>,
D1: Data<Elem = E>,
D2: Data<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(&rhs)`, where `rhs` is
Expand All @@ -186,7 +188,7 @@ where
/// `A` and `rhs` must have the same layout, i.e. they must
/// be both either row- or column-major format, otherwise a
/// `IncompatibleShape` error is raised.
fn least_squares(&self, rhs: &ArrayBase<D, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
fn least_squares(&self, rhs: &ArrayBase<D2, Ix2>) -> Result<LeastSquaresResult<E, Ix2>> {
let a = self.to_owned();
let b = rhs.to_owned();
a.least_squares_into(b)
Expand All @@ -199,10 +201,11 @@ where
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInto<D, E, Ix1> for ArrayBase<D, Ix2>
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -213,7 +216,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_into(
mut self,
mut rhs: ArrayBase<D, Ix1>,
mut rhs: ArrayBase<D2, Ix1>,
) -> Result<LeastSquaresResult<E, Ix1>> {
self.least_squares_in_place(&mut rhs)
}
Expand All @@ -223,12 +226,13 @@ where
/// as a right-hand side. The matrix and the RHS matrix
/// are consumed.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInto<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInto<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -239,7 +243,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_into(
mut self,
mut rhs: ArrayBase<D, Ix2>,
mut rhs: ArrayBase<D2, Ix2>,
) -> Result<LeastSquaresResult<E, Ix2>> {
self.least_squares_in_place(&mut rhs)
}
Expand All @@ -249,12 +253,13 @@ where
/// as a right-hand side. Both values are overwritten in the
/// call.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix1> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix1> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -265,7 +270,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_in_place(
&mut self,
rhs: &mut ArrayBase<D, Ix1>,
rhs: &mut ArrayBase<D2, Ix1>,
) -> Result<LeastSquaresResult<E, Ix1>> {
if self.shape()[0] != rhs.shape()[0] {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
Expand Down Expand Up @@ -331,12 +336,13 @@ fn compute_residual_scalar<E: Scalar, D: Data<Elem = E>>(
/// as a right-hand side. Both values are overwritten in the
/// call.
///
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D` can be any
/// valid representation for `ArrayBase`.
impl<E, D> LeastSquaresSvdInPlace<D, E, Ix2> for ArrayBase<D, Ix2>
/// `E` is one of `f32`, `f64`, `c32`, `c64`. `D1`, `D2` can be any
/// valid representation for `ArrayBase` (over `E`).
impl<E, D1, D2> LeastSquaresSvdInPlace<D2, E, Ix2> for ArrayBase<D1, Ix2>
where
E: Scalar + Lapack + LeastSquaresSvdDivideConquer_,
D: DataMut<Elem = E>,
D1: DataMut<Elem = E>,
D2: DataMut<Elem = E>,
{
/// Solve a least squares problem of the form `Ax = rhs`
/// by calling `A.least_squares(rhs)`, where `rhs` is a
Expand All @@ -347,7 +353,7 @@ where
/// `IncompatibleShape` error is raised.
fn least_squares_in_place(
&mut self,
rhs: &mut ArrayBase<D, Ix2>,
rhs: &mut ArrayBase<D2, Ix2>,
) -> Result<LeastSquaresResult<E, Ix2>> {
if self.shape()[0] != rhs.shape()[0] {
return Err(ShapeError::from_kind(ErrorKind::IncompatibleShape).into());
Expand Down Expand Up @@ -425,7 +431,7 @@ mod tests {
use ndarray::*;

//
// Test that the different lest squares traits work as intended on the
// Test that the different least squares traits work as intended on the
// different array types.
//
// | least_squares | ls_into | ls_in_place |
Expand All @@ -437,9 +443,9 @@ mod tests {
// ArrayViewMut | yes | no | yes |
//

fn assert_result<D: Data<Elem = f64>>(
a: &ArrayBase<D, Ix2>,
b: &ArrayBase<D, Ix1>,
fn assert_result<D1: Data<Elem = f64>, D2: Data<Elem = f64>>(
a: &ArrayBase<D1, Ix2>,
b: &ArrayBase<D2, Ix1>,
res: &LeastSquaresResult<f64, Ix1>,
) {
assert_eq!(res.rank, 2);
Expand Down Expand Up @@ -487,6 +493,15 @@ mod tests {
assert_result(&av, &bv, &res);
}

#[test]
fn on_cow_view() {
let a = CowArray::from(array![[1., 2.], [4., 5.], [3., 4.]]);
let b: Array1<f64> = array![1., 2., 3.];
let bv = b.view();
let res = a.least_squares(&bv).unwrap();
assert_result(&a, &bv, &res);
}

#[test]
fn into_on_owned() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
Expand Down Expand Up @@ -517,6 +532,16 @@ mod tests {
assert_result(&a, &b, &res);
}

#[test]
fn into_on_owned_cow() {
let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]];
let b = CowArray::from(array![1., 2., 3.]);
let ac = a.clone();
let b2 = b.clone();
let res = ac.least_squares_into(b2).unwrap();
assert_result(&a, &b, &res);
}

#[test]
fn in_place_on_owned() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
Expand Down Expand Up @@ -549,6 +574,16 @@ mod tests {
assert_result(&a, &b, &res);
}

#[test]
fn in_place_on_owned_cow() {
let a = array![[1., 2.], [4., 5.], [3., 4.]];
let b = CowArray::from(array![1., 2., 3.]);
let mut a2 = a.clone();
let mut b2 = b.clone();
let res = a2.least_squares_in_place(&mut b2).unwrap();
assert_result(&a, &b, &res);
}

//
// Testing error cases
//
Expand Down

0 comments on commit a23224f

Please sign in to comment.