github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dsteqr.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"math/rand"
    11  
    12  	"github.com/gopherd/gonum/blas"
    13  	"github.com/gopherd/gonum/blas/blas64"
    14  	"github.com/gopherd/gonum/floats"
    15  	"github.com/gopherd/gonum/lapack"
    16  )
    17  
    18  type Dsteqrer interface {
    19  	Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool)
    20  	Dorgtrer
    21  }
    22  
    23  func DsteqrTest(t *testing.T, impl Dsteqrer) {
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for _, compz := range []lapack.EVComp{lapack.EVOrig, lapack.EVTridiag} {
    26  		for _, test := range []struct {
    27  			n, lda int
    28  		}{
    29  			{1, 0},
    30  			{4, 0},
    31  			{8, 0},
    32  			{10, 0},
    33  
    34  			{2, 10},
    35  			{8, 10},
    36  			{10, 20},
    37  		} {
    38  			for cas := 0; cas < 100; cas++ {
    39  				n := test.n
    40  				lda := test.lda
    41  				if lda == 0 {
    42  					lda = n
    43  				}
    44  				d := make([]float64, n)
    45  				for i := range d {
    46  					d[i] = rnd.Float64()
    47  				}
    48  				e := make([]float64, n-1)
    49  				for i := range e {
    50  					e[i] = rnd.Float64()
    51  				}
    52  				a := make([]float64, n*lda)
    53  				for i := range a {
    54  					a[i] = rnd.Float64()
    55  				}
    56  				dCopy := make([]float64, len(d))
    57  				copy(dCopy, d)
    58  				eCopy := make([]float64, len(e))
    59  				copy(eCopy, e)
    60  				aCopy := make([]float64, len(a))
    61  				copy(aCopy, a)
    62  				if compz == lapack.EVOrig {
    63  					uplo := blas.Upper
    64  					tau := make([]float64, n)
    65  					work := make([]float64, 1)
    66  					impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1)
    67  					work = make([]float64, int(work[0]))
    68  					// Reduce A to symmetric tridiagonal form.
    69  					impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
    70  					// Compute the orthogonal matrix Q.
    71  					impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
    72  				} else {
    73  					for i := 0; i < n; i++ {
    74  						for j := 0; j < n; j++ {
    75  							a[i*lda+j] = 0
    76  							if i == j {
    77  								a[i*lda+j] = 1
    78  							}
    79  						}
    80  					}
    81  				}
    82  				work := make([]float64, 2*n)
    83  
    84  				aDecomp := make([]float64, len(a))
    85  				copy(aDecomp, a)
    86  				dDecomp := make([]float64, len(d))
    87  				copy(dDecomp, d)
    88  				eDecomp := make([]float64, len(e))
    89  				copy(eDecomp, e)
    90  				impl.Dsteqr(compz, n, d, e, a, lda, work)
    91  				dAns := make([]float64, len(d))
    92  				copy(dAns, d)
    93  
    94  				var truth blas64.General
    95  				if compz == lapack.EVOrig {
    96  					truth = blas64.General{
    97  						Rows:   n,
    98  						Cols:   n,
    99  						Stride: n,
   100  						Data:   make([]float64, n*n),
   101  					}
   102  					for i := 0; i < n; i++ {
   103  						for j := i; j < n; j++ {
   104  							v := aCopy[i*lda+j]
   105  							truth.Data[i*truth.Stride+j] = v
   106  							truth.Data[j*truth.Stride+i] = v
   107  						}
   108  					}
   109  				} else {
   110  					truth = blas64.General{
   111  						Rows:   n,
   112  						Cols:   n,
   113  						Stride: n,
   114  						Data:   make([]float64, n*n),
   115  					}
   116  					for i := 0; i < n; i++ {
   117  						truth.Data[i*truth.Stride+i] = dCopy[i]
   118  						if i != n-1 {
   119  							truth.Data[(i+1)*truth.Stride+i] = eCopy[i]
   120  							truth.Data[i*truth.Stride+i+1] = eCopy[i]
   121  						}
   122  					}
   123  				}
   124  
   125  				V := blas64.General{
   126  					Rows:   n,
   127  					Cols:   n,
   128  					Stride: lda,
   129  					Data:   a,
   130  				}
   131  				if !eigenDecompCorrect(d, truth, V) {
   132  					t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v",
   133  						compz == lapack.EVOrig, n)
   134  				}
   135  
   136  				// Compare eigenvalues when not computing eigenvectors.
   137  				for i := range work {
   138  					work[i] = rnd.Float64()
   139  				}
   140  				impl.Dsteqr(lapack.EVCompNone, n, dDecomp, eDecomp, aDecomp, lda, work)
   141  				if !floats.EqualApprox(d, dAns, 1e-8) {
   142  					t.Errorf("Eigenvalue mismatch when eigenvectors not computed")
   143  				}
   144  			}
   145  		}
   146  	}
   147  }
   148  
   149  // eigenDecompCorrect returns whether the eigen decomposition is correct.
   150  // It checks if
   151  //  A * v ≈ λ * v
   152  // where the eigenvalues λ are stored in values, and the eigenvectors are stored
   153  // in the columns of v.
   154  func eigenDecompCorrect(values []float64, A, V blas64.General) bool {
   155  	n := A.Rows
   156  	for i := 0; i < n; i++ {
   157  		lambda := values[i]
   158  		vector := make([]float64, n)
   159  		ans2 := make([]float64, n)
   160  		for j := range vector {
   161  			v := V.Data[j*V.Stride+i]
   162  			vector[j] = v
   163  			ans2[j] = lambda * v
   164  		}
   165  		v := blas64.Vector{Inc: 1, Data: vector}
   166  		ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)}
   167  		blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1)
   168  		if !floats.EqualApprox(ans1.Data, ans2, 1e-8) {
   169  			return false
   170  		}
   171  	}
   172  	return true
   173  }