github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsteqr.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 "github.com/gonum/lapack" 15 ) 16 17 type Dsteqrer interface { 18 Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool) 19 Dorgtrer 20 } 21 22 func DsteqrTest(t *testing.T, impl Dsteqrer) { 23 rnd := rand.New(rand.NewSource(1)) 24 for _, compz := range []lapack.EVComp{lapack.OriginalEV, lapack.TridiagEV} { 25 for _, test := range []struct { 26 n, lda int 27 }{ 28 {1, 0}, 29 {4, 0}, 30 {8, 0}, 31 {10, 0}, 32 33 {2, 10}, 34 {8, 10}, 35 {10, 20}, 36 } { 37 for cas := 0; cas < 100; cas++ { 38 n := test.n 39 lda := test.lda 40 if lda == 0 { 41 lda = n 42 } 43 d := make([]float64, n) 44 for i := range d { 45 d[i] = rnd.Float64() 46 } 47 e := make([]float64, n-1) 48 for i := range e { 49 e[i] = rnd.Float64() 50 } 51 a := make([]float64, n*lda) 52 for i := range a { 53 a[i] = rnd.Float64() 54 } 55 dCopy := make([]float64, len(d)) 56 copy(dCopy, d) 57 eCopy := make([]float64, len(e)) 58 copy(eCopy, e) 59 aCopy := make([]float64, len(a)) 60 copy(aCopy, a) 61 if compz == lapack.OriginalEV { 62 // Compute triangular decomposition and orthonormal matrix. 63 uplo := blas.Upper 64 tau := make([]float64, n) 65 work := make([]float64, 1) 66 impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1) 67 work = make([]float64, int(work[0])) 68 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work)) 69 impl.Dorgtr(uplo, n, a, lda, tau, work, len(work)) 70 } else { 71 for i := 0; i < n; i++ { 72 for j := 0; j < n; j++ { 73 a[i*lda+j] = 0 74 if i == j { 75 a[i*lda+j] = 1 76 } 77 } 78 } 79 } 80 work := make([]float64, 2*n) 81 82 aDecomp := make([]float64, len(a)) 83 copy(aDecomp, a) 84 dDecomp := make([]float64, len(d)) 85 copy(dDecomp, d) 86 eDecomp := make([]float64, len(e)) 87 copy(eDecomp, e) 88 impl.Dsteqr(compz, n, d, e, a, lda, work) 89 dAns := make([]float64, len(d)) 90 copy(dAns, d) 91 92 var truth blas64.General 93 if compz == lapack.OriginalEV { 94 truth = blas64.General{ 95 Rows: n, 96 Cols: n, 97 Stride: n, 98 Data: make([]float64, n*n), 99 } 100 for i := 0; i < n; i++ { 101 for j := i; j < n; j++ { 102 v := aCopy[i*lda+j] 103 truth.Data[i*truth.Stride+j] = v 104 truth.Data[j*truth.Stride+i] = v 105 } 106 } 107 } else { 108 truth = blas64.General{ 109 Rows: n, 110 Cols: n, 111 Stride: n, 112 Data: make([]float64, n*n), 113 } 114 for i := 0; i < n; i++ { 115 truth.Data[i*truth.Stride+i] = dCopy[i] 116 if i != n-1 { 117 truth.Data[(i+1)*truth.Stride+i] = eCopy[i] 118 truth.Data[i*truth.Stride+i+1] = eCopy[i] 119 } 120 } 121 } 122 123 V := blas64.General{ 124 Rows: n, 125 Cols: n, 126 Stride: lda, 127 Data: a, 128 } 129 if !eigenDecompCorrect(d, truth, V) { 130 t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v", 131 compz == lapack.OriginalEV, n) 132 } 133 134 // Compare eigenvalues when not computing eigenvectors. 135 for i := range work { 136 work[i] = rnd.Float64() 137 } 138 impl.Dsteqr(lapack.None, n, dDecomp, eDecomp, aDecomp, lda, work) 139 if !floats.EqualApprox(d, dAns, 1e-8) { 140 t.Errorf("Eigenvalue mismatch when eigenvectors not computed") 141 } 142 } 143 } 144 } 145 } 146 147 // eigenDecompCorrect returns whether the eigen decomposition is correct. 148 // It checks if 149 // A * v ≈ λ * v 150 // where the eigenvalues λ are stored in values, and the eigenvectors are stored 151 // in the columns of v. 152 func eigenDecompCorrect(values []float64, A, V blas64.General) bool { 153 n := A.Rows 154 for i := 0; i < n; i++ { 155 lambda := values[i] 156 vector := make([]float64, n) 157 ans2 := make([]float64, n) 158 for j := range vector { 159 v := V.Data[j*V.Stride+i] 160 vector[j] = v 161 ans2[j] = lambda * v 162 } 163 v := blas64.Vector{Inc: 1, Data: vector} 164 ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)} 165 blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1) 166 if !floats.EqualApprox(ans1.Data, ans2, 1e-8) { 167 return false 168 } 169 } 170 return true 171 }