github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsteqr.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dsteqrer interface {
    18  	Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool)
    19  	Dorgtrer
    20  }
    21  
    22  func DsteqrTest(t *testing.T, impl Dsteqrer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, compz := range []lapack.EVComp{lapack.OriginalEV, lapack.TridiagEV} {
    25  		for _, test := range []struct {
    26  			n, lda int
    27  		}{
    28  			{1, 0},
    29  			{4, 0},
    30  			{8, 0},
    31  			{10, 0},
    32  
    33  			{2, 10},
    34  			{8, 10},
    35  			{10, 20},
    36  		} {
    37  			for cas := 0; cas < 100; cas++ {
    38  				n := test.n
    39  				lda := test.lda
    40  				if lda == 0 {
    41  					lda = n
    42  				}
    43  				d := make([]float64, n)
    44  				for i := range d {
    45  					d[i] = rnd.Float64()
    46  				}
    47  				e := make([]float64, n-1)
    48  				for i := range e {
    49  					e[i] = rnd.Float64()
    50  				}
    51  				a := make([]float64, n*lda)
    52  				for i := range a {
    53  					a[i] = rnd.Float64()
    54  				}
    55  				dCopy := make([]float64, len(d))
    56  				copy(dCopy, d)
    57  				eCopy := make([]float64, len(e))
    58  				copy(eCopy, e)
    59  				aCopy := make([]float64, len(a))
    60  				copy(aCopy, a)
    61  				if compz == lapack.OriginalEV {
    62  					// Compute triangular decomposition and orthonormal matrix.
    63  					uplo := blas.Upper
    64  					tau := make([]float64, n)
    65  					work := make([]float64, 1)
    66  					impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1)
    67  					work = make([]float64, int(work[0]))
    68  					impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work))
    69  					impl.Dorgtr(uplo, n, a, lda, tau, work, len(work))
    70  				} else {
    71  					for i := 0; i < n; i++ {
    72  						for j := 0; j < n; j++ {
    73  							a[i*lda+j] = 0
    74  							if i == j {
    75  								a[i*lda+j] = 1
    76  							}
    77  						}
    78  					}
    79  				}
    80  				work := make([]float64, 2*n)
    81  
    82  				aDecomp := make([]float64, len(a))
    83  				copy(aDecomp, a)
    84  				dDecomp := make([]float64, len(d))
    85  				copy(dDecomp, d)
    86  				eDecomp := make([]float64, len(e))
    87  				copy(eDecomp, e)
    88  				impl.Dsteqr(compz, n, d, e, a, lda, work)
    89  				dAns := make([]float64, len(d))
    90  				copy(dAns, d)
    91  
    92  				var truth blas64.General
    93  				if compz == lapack.OriginalEV {
    94  					truth = blas64.General{
    95  						Rows:   n,
    96  						Cols:   n,
    97  						Stride: n,
    98  						Data:   make([]float64, n*n),
    99  					}
   100  					for i := 0; i < n; i++ {
   101  						for j := i; j < n; j++ {
   102  							v := aCopy[i*lda+j]
   103  							truth.Data[i*truth.Stride+j] = v
   104  							truth.Data[j*truth.Stride+i] = v
   105  						}
   106  					}
   107  				} else {
   108  					truth = blas64.General{
   109  						Rows:   n,
   110  						Cols:   n,
   111  						Stride: n,
   112  						Data:   make([]float64, n*n),
   113  					}
   114  					for i := 0; i < n; i++ {
   115  						truth.Data[i*truth.Stride+i] = dCopy[i]
   116  						if i != n-1 {
   117  							truth.Data[(i+1)*truth.Stride+i] = eCopy[i]
   118  							truth.Data[i*truth.Stride+i+1] = eCopy[i]
   119  						}
   120  					}
   121  				}
   122  
   123  				V := blas64.General{
   124  					Rows:   n,
   125  					Cols:   n,
   126  					Stride: lda,
   127  					Data:   a,
   128  				}
   129  				if !eigenDecompCorrect(d, truth, V) {
   130  					t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v",
   131  						compz == lapack.OriginalEV, n)
   132  				}
   133  
   134  				// Compare eigenvalues when not computing eigenvectors.
   135  				for i := range work {
   136  					work[i] = rnd.Float64()
   137  				}
   138  				impl.Dsteqr(lapack.None, n, dDecomp, eDecomp, aDecomp, lda, work)
   139  				if !floats.EqualApprox(d, dAns, 1e-8) {
   140  					t.Errorf("Eigenvalue mismatch when eigenvectors not computed")
   141  				}
   142  			}
   143  		}
   144  	}
   145  }
   146  
   147  // eigenDecompCorrect returns whether the eigen decomposition is correct.
   148  // It checks if
   149  //  A * v ≈ λ * v
   150  // where the eigenvalues λ are stored in values, and the eigenvectors are stored
   151  // in the columns of v.
   152  func eigenDecompCorrect(values []float64, A, V blas64.General) bool {
   153  	n := A.Rows
   154  	for i := 0; i < n; i++ {
   155  		lambda := values[i]
   156  		vector := make([]float64, n)
   157  		ans2 := make([]float64, n)
   158  		for j := range vector {
   159  			v := V.Data[j*V.Stride+i]
   160  			vector[j] = v
   161  			ans2[j] = lambda * v
   162  		}
   163  		v := blas64.Vector{Inc: 1, Data: vector}
   164  		ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)}
   165  		blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1)
   166  		if !floats.EqualApprox(ans1.Data, ans2, 1e-8) {
   167  			return false
   168  		}
   169  	}
   170  	return true
   171  }