github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsytd2.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 ) 15 16 type Dsytd2er interface { 17 Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64) 18 } 19 20 func Dsytd2Test(t *testing.T, impl Dsytd2er) { 21 rnd := rand.New(rand.NewSource(1)) 22 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 23 for _, test := range []struct { 24 n, lda int 25 }{ 26 {3, 0}, 27 {4, 0}, 28 {5, 0}, 29 30 {3, 10}, 31 {4, 10}, 32 {5, 10}, 33 } { 34 n := test.n 35 lda := test.lda 36 if lda == 0 { 37 lda = n 38 } 39 a := make([]float64, n*lda) 40 for i := range a { 41 a[i] = rnd.NormFloat64() 42 } 43 aCopy := make([]float64, len(a)) 44 copy(aCopy, a) 45 46 d := make([]float64, n) 47 for i := range d { 48 d[i] = math.NaN() 49 } 50 e := make([]float64, n-1) 51 for i := range e { 52 e[i] = math.NaN() 53 } 54 tau := make([]float64, n-1) 55 for i := range tau { 56 tau[i] = math.NaN() 57 } 58 59 impl.Dsytd2(uplo, n, a, lda, d, e, tau) 60 61 // Construct Q 62 qMat := blas64.General{ 63 Rows: n, 64 Cols: n, 65 Stride: n, 66 Data: make([]float64, n*n), 67 } 68 qCopy := blas64.General{ 69 Rows: n, 70 Cols: n, 71 Stride: n, 72 Data: make([]float64, len(qMat.Data)), 73 } 74 // Set Q to I. 75 for i := 0; i < n; i++ { 76 qMat.Data[i*qMat.Stride+i] = 1 77 } 78 for i := 0; i < n-1; i++ { 79 hMat := blas64.General{ 80 Rows: n, 81 Cols: n, 82 Stride: n, 83 Data: make([]float64, n*n), 84 } 85 // Set H to I. 86 for i := 0; i < n; i++ { 87 hMat.Data[i*hMat.Stride+i] = 1 88 } 89 var vi blas64.Vector 90 if uplo == blas.Upper { 91 vi = blas64.Vector{ 92 Inc: 1, 93 Data: make([]float64, n), 94 } 95 for j := 0; j < i; j++ { 96 vi.Data[j] = a[j*lda+i+1] 97 } 98 vi.Data[i] = 1 99 } else { 100 vi = blas64.Vector{ 101 Inc: 1, 102 Data: make([]float64, n), 103 } 104 vi.Data[i+1] = 1 105 for j := i + 2; j < n; j++ { 106 vi.Data[j] = a[j*lda+i] 107 } 108 } 109 blas64.Ger(-tau[i], vi, vi, hMat) 110 copy(qCopy.Data, qMat.Data) 111 112 // Multiply q by the new h. 113 if uplo == blas.Upper { 114 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat) 115 } else { 116 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat) 117 } 118 } 119 // Check that Q is orthonormal 120 othonormal := true 121 for i := 0; i < n; i++ { 122 for j := i; j < n; j++ { 123 dot := blas64.Dot(n, 124 blas64.Vector{Inc: 1, Data: qMat.Data[i*qMat.Stride:]}, 125 blas64.Vector{Inc: 1, Data: qMat.Data[j*qMat.Stride:]}, 126 ) 127 if i == j { 128 if math.Abs(dot-1) > 1e-10 { 129 othonormal = false 130 } 131 } else { 132 if math.Abs(dot) > 1e-10 { 133 othonormal = false 134 } 135 } 136 } 137 } 138 if !othonormal { 139 t.Errorf("Q not orthonormal") 140 } 141 142 // Compute Q^T * A * Q. 143 aMat := blas64.General{ 144 Rows: n, 145 Cols: n, 146 Stride: n, 147 Data: make([]float64, len(a)), 148 } 149 150 for i := 0; i < n; i++ { 151 for j := i; j < n; j++ { 152 v := aCopy[i*lda+j] 153 if uplo == blas.Lower { 154 v = aCopy[j*lda+i] 155 } 156 aMat.Data[i*aMat.Stride+j] = v 157 aMat.Data[j*aMat.Stride+i] = v 158 } 159 } 160 161 tmp := blas64.General{ 162 Rows: n, 163 Cols: n, 164 Stride: n, 165 Data: make([]float64, n*n), 166 } 167 168 ans := blas64.General{ 169 Rows: n, 170 Cols: n, 171 Stride: n, 172 Data: make([]float64, n*n), 173 } 174 175 blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp) 176 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans) 177 178 // Compare with T. 179 tMat := blas64.General{ 180 Rows: n, 181 Cols: n, 182 Stride: n, 183 Data: make([]float64, n*n), 184 } 185 for i := 0; i < n-1; i++ { 186 tMat.Data[i*tMat.Stride+i] = d[i] 187 tMat.Data[i*tMat.Stride+i+1] = e[i] 188 tMat.Data[(i+1)*tMat.Stride+i] = e[i] 189 } 190 tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1] 191 192 same := true 193 for i := 0; i < n; i++ { 194 for j := 0; j < n; j++ { 195 if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > 1e-10 { 196 same = false 197 } 198 } 199 } 200 if !same { 201 t.Errorf("Matrix answer mismatch") 202 } 203 } 204 } 205 }