github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgetri.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dgetrier interface {
    17  	Dgetrfer
    18  	Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) bool
    19  }
    20  
    21  func DgetriTest(t *testing.T, impl Dgetrier) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	bi := blas64.Implementation()
    24  	for _, test := range []struct {
    25  		n, lda int
    26  	}{
    27  		{5, 0},
    28  		{5, 8},
    29  		{45, 0},
    30  		{45, 50},
    31  		{65, 0},
    32  		{65, 70},
    33  		{150, 0},
    34  		{150, 250},
    35  	} {
    36  		n := test.n
    37  		lda := test.lda
    38  		if lda == 0 {
    39  			lda = n
    40  		}
    41  		// Generate a random well conditioned matrix
    42  		perm := rnd.Perm(n)
    43  		a := make([]float64, n*lda)
    44  		for i := 0; i < n; i++ {
    45  			a[i*lda+perm[i]] = 1
    46  		}
    47  		for i := range a {
    48  			a[i] += 0.01 * rnd.Float64()
    49  		}
    50  		aCopy := make([]float64, len(a))
    51  		copy(aCopy, a)
    52  		ipiv := make([]int, n)
    53  		// Compute LU decomposition.
    54  		impl.Dgetrf(n, n, a, lda, ipiv)
    55  		// Compute inverse.
    56  		work := make([]float64, 1)
    57  		impl.Dgetri(n, a, lda, ipiv, work, -1)
    58  		work = make([]float64, int(work[0]))
    59  		lwork := len(work)
    60  
    61  		ok := impl.Dgetri(n, a, lda, ipiv, work, lwork)
    62  		if !ok {
    63  			t.Errorf("Unexpected singular matrix.")
    64  		}
    65  
    66  		// Check that A(inv) * A = I.
    67  		ans := make([]float64, len(a))
    68  		bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, aCopy, lda, a, lda, 0, ans, lda)
    69  		isEye := true
    70  		for i := 0; i < n; i++ {
    71  			for j := 0; j < n; j++ {
    72  				if i == j {
    73  					// This tolerance is so high because computing matrix inverses
    74  					// is very unstable.
    75  					if math.Abs(ans[i*lda+j]-1) > 5e-2 {
    76  						isEye = false
    77  					}
    78  				} else {
    79  					if math.Abs(ans[i*lda+j]) > 5e-2 {
    80  						isEye = false
    81  					}
    82  				}
    83  			}
    84  		}
    85  		if !isEye {
    86  			t.Errorf("Inv(A) * A != I. n = %v, lda = %v", n, lda)
    87  		}
    88  	}
    89  }