gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dsteqr.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dsteqrer interface { 19 Dsteqr(compz lapack.EVComp, n int, d, e, z []float64, ldz int, work []float64) (ok bool) 20 Dorgtrer 21 } 22 23 func DsteqrTest(t *testing.T, impl Dsteqrer) { 24 rnd := rand.New(rand.NewSource(1)) 25 for _, compz := range []lapack.EVComp{lapack.EVOrig, lapack.EVTridiag} { 26 for _, test := range []struct { 27 n, lda int 28 }{ 29 {1, 0}, 30 {4, 0}, 31 {8, 0}, 32 {10, 0}, 33 34 {2, 10}, 35 {8, 10}, 36 {10, 20}, 37 } { 38 for cas := 0; cas < 100; cas++ { 39 n := test.n 40 lda := test.lda 41 if lda == 0 { 42 lda = n 43 } 44 d := make([]float64, n) 45 for i := range d { 46 d[i] = rnd.Float64() 47 } 48 e := make([]float64, n-1) 49 for i := range e { 50 e[i] = rnd.Float64() 51 } 52 a := make([]float64, n*lda) 53 for i := range a { 54 a[i] = rnd.Float64() 55 } 56 dCopy := make([]float64, len(d)) 57 copy(dCopy, d) 58 eCopy := make([]float64, len(e)) 59 copy(eCopy, e) 60 aCopy := make([]float64, len(a)) 61 copy(aCopy, a) 62 if compz == lapack.EVOrig { 63 uplo := blas.Upper 64 tau := make([]float64, n) 65 work := make([]float64, 1) 66 impl.Dsytrd(blas.Upper, n, a, lda, d, e, tau, work, -1) 67 work = make([]float64, int(work[0])) 68 // Reduce A to symmetric tridiagonal form. 69 impl.Dsytrd(uplo, n, a, lda, d, e, tau, work, len(work)) 70 // Compute the orthogonal matrix Q. 71 impl.Dorgtr(uplo, n, a, lda, tau, work, len(work)) 72 } else { 73 for i := 0; i < n; i++ { 74 for j := 0; j < n; j++ { 75 a[i*lda+j] = 0 76 if i == j { 77 a[i*lda+j] = 1 78 } 79 } 80 } 81 } 82 work := make([]float64, 2*n) 83 84 aDecomp := make([]float64, len(a)) 85 copy(aDecomp, a) 86 dDecomp := make([]float64, len(d)) 87 copy(dDecomp, d) 88 eDecomp := make([]float64, len(e)) 89 copy(eDecomp, e) 90 impl.Dsteqr(compz, n, d, e, a, lda, work) 91 dAns := make([]float64, len(d)) 92 copy(dAns, d) 93 94 var truth blas64.General 95 if compz == lapack.EVOrig { 96 truth = blas64.General{ 97 Rows: n, 98 Cols: n, 99 Stride: n, 100 Data: make([]float64, n*n), 101 } 102 for i := 0; i < n; i++ { 103 for j := i; j < n; j++ { 104 v := aCopy[i*lda+j] 105 truth.Data[i*truth.Stride+j] = v 106 truth.Data[j*truth.Stride+i] = v 107 } 108 } 109 } else { 110 truth = blas64.General{ 111 Rows: n, 112 Cols: n, 113 Stride: n, 114 Data: make([]float64, n*n), 115 } 116 for i := 0; i < n; i++ { 117 truth.Data[i*truth.Stride+i] = dCopy[i] 118 if i != n-1 { 119 truth.Data[(i+1)*truth.Stride+i] = eCopy[i] 120 truth.Data[i*truth.Stride+i+1] = eCopy[i] 121 } 122 } 123 } 124 125 V := blas64.General{ 126 Rows: n, 127 Cols: n, 128 Stride: lda, 129 Data: a, 130 } 131 if !eigenDecompCorrect(d, truth, V) { 132 t.Errorf("Eigen reconstruction mismatch. fromFull = %v, n = %v", 133 compz == lapack.EVOrig, n) 134 } 135 136 // Compare eigenvalues when not computing eigenvectors. 137 for i := range work { 138 work[i] = rnd.Float64() 139 } 140 impl.Dsteqr(lapack.EVCompNone, n, dDecomp, eDecomp, aDecomp, lda, work) 141 if !floats.EqualApprox(d, dAns, 1e-8) { 142 t.Errorf("Eigenvalue mismatch when eigenvectors not computed") 143 } 144 } 145 } 146 } 147 } 148 149 // eigenDecompCorrect returns whether the eigen decomposition is correct. 150 // It checks if 151 // 152 // A * v ≈ λ * v 153 // 154 // where the eigenvalues λ are stored in values, and the eigenvectors are stored 155 // in the columns of v. 156 func eigenDecompCorrect(values []float64, A, V blas64.General) bool { 157 n := A.Rows 158 for i := 0; i < n; i++ { 159 lambda := values[i] 160 vector := make([]float64, n) 161 ans2 := make([]float64, n) 162 for j := range vector { 163 v := V.Data[j*V.Stride+i] 164 vector[j] = v 165 ans2[j] = lambda * v 166 } 167 v := blas64.Vector{Inc: 1, Data: vector} 168 ans1 := blas64.Vector{Inc: 1, Data: make([]float64, n)} 169 blas64.Gemv(blas.NoTrans, 1, A, v, 0, ans1) 170 if !floats.EqualApprox(ans1.Data, ans2, 1e-8) { 171 return false 172 } 173 } 174 return true 175 }