gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgesv.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/blas/blas64"
    16  	"gonum.org/v1/gonum/lapack"
    17  )
    18  
    19  type Dgesver interface {
    20  	Dgesv(n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int) bool
    21  
    22  	Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool
    23  }
    24  
    25  func DgesvTest(t *testing.T, impl Dgesver) {
    26  	rnd := rand.New(rand.NewSource(1))
    27  	for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 50, 100} {
    28  		for _, nrhs := range []int{0, 1, 2, 5} {
    29  			for _, lda := range []int{max(1, n), n + 5} {
    30  				for _, ldb := range []int{max(1, nrhs), nrhs + 5} {
    31  					dgesvTest(t, impl, rnd, n, nrhs, lda, ldb)
    32  				}
    33  			}
    34  		}
    35  	}
    36  }
    37  
    38  func dgesvTest(t *testing.T, impl Dgesver, rnd *rand.Rand, n, nrhs, lda, ldb int) {
    39  	const tol = 1e-15
    40  
    41  	name := fmt.Sprintf("n=%v,nrhs=%v,lda=%v,ldb=%v", n, nrhs, lda, ldb)
    42  
    43  	// Create a random system matrix A and the solution X.
    44  	a := randomGeneral(n, n, lda, rnd)
    45  	xWant := randomGeneral(n, nrhs, ldb, rnd)
    46  
    47  	// Compute the right hand side matrix B = A*X.
    48  	b := zeros(n, nrhs, ldb)
    49  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, a, xWant, 0, b)
    50  
    51  	// Allocate a slice for row pivots and fill it with invalid indices.
    52  	ipiv := make([]int, n)
    53  	for i := range ipiv {
    54  		ipiv[i] = -1
    55  	}
    56  
    57  	// Call Dgesv to solve A*X = B.
    58  	lu := cloneGeneral(a)
    59  	xGot := cloneGeneral(b)
    60  	ok := impl.Dgesv(n, nrhs, lu.Data, lu.Stride, ipiv, xGot.Data, xGot.Stride)
    61  
    62  	if !ok {
    63  		t.Errorf("%v: unexpected failure in Dgesv", name)
    64  		return
    65  	}
    66  
    67  	if n == 0 || nrhs == 0 {
    68  		return
    69  	}
    70  
    71  	// Check that all elements of ipiv have been set.
    72  	ipivSet := true
    73  	for _, ipv := range ipiv {
    74  		if ipv == -1 {
    75  			ipivSet = false
    76  			break
    77  		}
    78  	}
    79  	if !ipivSet {
    80  		t.Fatalf("%v: not all elements of ipiv set", name)
    81  		return
    82  	}
    83  
    84  	// Compute the reciprocal of the condition number of A from its LU
    85  	// decomposition before it's overwritten further below.
    86  	aInv := cloneGeneral(lu)
    87  	impl.Dgetri(n, aInv.Data, aInv.Stride, ipiv, make([]float64, n), n)
    88  	ainvnorm := dlange(lapack.MaxColumnSum, n, n, aInv.Data, aInv.Stride)
    89  	anorm := dlange(lapack.MaxColumnSum, n, n, a.Data, a.Stride)
    90  	rcond := 1 / anorm / ainvnorm
    91  
    92  	// Reconstruct matrix A from factors and compute residual.
    93  	//
    94  	// Extract L and U from lu.
    95  	l := zeros(n, n, n)
    96  	u := zeros(n, n, n)
    97  	for i := 0; i < n; i++ {
    98  		for j := 0; j < i; j++ {
    99  			l.Data[i*l.Stride+j] = lu.Data[i*lu.Stride+j]
   100  		}
   101  		l.Data[i*l.Stride+i] = 1
   102  		for j := i; j < n; j++ {
   103  			u.Data[i*u.Stride+j] = lu.Data[i*lu.Stride+j]
   104  		}
   105  	}
   106  	// Compute L*U.
   107  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, lu)
   108  	// Apply P to L*U.
   109  	for i := n - 1; i >= 0; i-- {
   110  		ip := ipiv[i]
   111  		if ip == i {
   112  			continue
   113  		}
   114  		row1 := blas64.Vector{N: n, Data: lu.Data[i*lu.Stride:], Inc: 1}
   115  		row2 := blas64.Vector{N: n, Data: lu.Data[ip*lu.Stride:], Inc: 1}
   116  		blas64.Swap(row1, row2)
   117  	}
   118  	// Compute P*L*U - A.
   119  	for i := 0; i < n; i++ {
   120  		for j := 0; j < n; j++ {
   121  			lu.Data[i*lu.Stride+j] -= a.Data[i*a.Stride+j]
   122  		}
   123  	}
   124  	// Compute the residual |P*L*U - A|.
   125  	resid := dlange(lapack.MaxColumnSum, n, n, lu.Data, lu.Stride)
   126  	resid /= float64(n) * anorm
   127  	if resid > tol || math.IsNaN(resid) {
   128  		t.Errorf("%v: residual |P*L*U - A| is too large, got %v, want <= %v", name, resid, tol)
   129  	}
   130  
   131  	// Compute residual of the computed solution.
   132  	//
   133  	// Compute B - A*X.
   134  	blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, a, xGot, 1, b)
   135  	// Compute the maximum over the number of right hand sides of |B - A*X| / (|A| * |X|).
   136  	resid = 0
   137  	for j := 0; j < nrhs; j++ {
   138  		bnorm := blas64.Asum(blas64.Vector{N: n, Data: b.Data[j:], Inc: b.Stride})
   139  		xnorm := blas64.Asum(blas64.Vector{N: n, Data: xGot.Data[j:], Inc: xGot.Stride})
   140  		resid = math.Max(resid, bnorm/anorm/xnorm)
   141  	}
   142  	if resid > tol || math.IsNaN(resid) {
   143  		t.Errorf("%v: residual |B - A*X| is too large, got %v, want <= %v", name, resid, tol)
   144  	}
   145  
   146  	// Compare the computed solution with the generated exact solution.
   147  	//
   148  	// Compute X - XWANT.
   149  	for i := 0; i < n; i++ {
   150  		for j := 0; j < nrhs; j++ {
   151  			xGot.Data[i*xGot.Stride+j] -= xWant.Data[i*xWant.Stride+j]
   152  		}
   153  	}
   154  	// Compute the maximum of |X - XWANT|/|XWANT| over all the vectors X and XWANT.
   155  	resid = 0
   156  	for j := 0; j < nrhs; j++ {
   157  		xnorm := dlange(lapack.MaxAbs, n, 1, xWant.Data[j:], xWant.Stride)
   158  		diff := dlange(lapack.MaxAbs, n, 1, xGot.Data[j:], xGot.Stride)
   159  		resid = math.Max(resid, diff/xnorm*rcond)
   160  	}
   161  	if resid > tol || math.IsNaN(resid) {
   162  		t.Errorf("%v: residual |X-XWANT| is too large, got %v, want <= %v", name, resid, tol)
   163  	}
   164  }