gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dsytd2.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 ) 16 17 type Dsytd2er interface { 18 Dsytd2(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau []float64) 19 } 20 21 func Dsytd2Test(t *testing.T, impl Dsytd2er) { 22 const tol = 1e-14 23 24 rnd := rand.New(rand.NewSource(1)) 25 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 26 for _, test := range []struct { 27 n, lda int 28 }{ 29 {3, 0}, 30 {4, 0}, 31 {5, 0}, 32 33 {3, 10}, 34 {4, 10}, 35 {5, 10}, 36 } { 37 n := test.n 38 lda := test.lda 39 if lda == 0 { 40 lda = n 41 } 42 a := make([]float64, n*lda) 43 for i := range a { 44 a[i] = rnd.NormFloat64() 45 } 46 aCopy := make([]float64, len(a)) 47 copy(aCopy, a) 48 49 d := make([]float64, n) 50 for i := range d { 51 d[i] = math.NaN() 52 } 53 e := make([]float64, n-1) 54 for i := range e { 55 e[i] = math.NaN() 56 } 57 tau := make([]float64, n-1) 58 for i := range tau { 59 tau[i] = math.NaN() 60 } 61 62 impl.Dsytd2(uplo, n, a, lda, d, e, tau) 63 64 // Construct Q 65 qMat := blas64.General{ 66 Rows: n, 67 Cols: n, 68 Stride: n, 69 Data: make([]float64, n*n), 70 } 71 qCopy := blas64.General{ 72 Rows: n, 73 Cols: n, 74 Stride: n, 75 Data: make([]float64, len(qMat.Data)), 76 } 77 // Set Q to I. 78 for i := 0; i < n; i++ { 79 qMat.Data[i*qMat.Stride+i] = 1 80 } 81 for i := 0; i < n-1; i++ { 82 hMat := blas64.General{ 83 Rows: n, 84 Cols: n, 85 Stride: n, 86 Data: make([]float64, n*n), 87 } 88 // Set H to I. 89 for i := 0; i < n; i++ { 90 hMat.Data[i*hMat.Stride+i] = 1 91 } 92 var vi blas64.Vector 93 if uplo == blas.Upper { 94 vi = blas64.Vector{ 95 Inc: 1, 96 Data: make([]float64, n), 97 } 98 for j := 0; j < i; j++ { 99 vi.Data[j] = a[j*lda+i+1] 100 } 101 vi.Data[i] = 1 102 } else { 103 vi = blas64.Vector{ 104 Inc: 1, 105 Data: make([]float64, n), 106 } 107 vi.Data[i+1] = 1 108 for j := i + 2; j < n; j++ { 109 vi.Data[j] = a[j*lda+i] 110 } 111 } 112 blas64.Ger(-tau[i], vi, vi, hMat) 113 copy(qCopy.Data, qMat.Data) 114 115 // Multiply q by the new h. 116 if uplo == blas.Upper { 117 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, qCopy, 0, qMat) 118 } else { 119 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, hMat, 0, qMat) 120 } 121 } 122 123 if resid := residualOrthogonal(qMat, false); resid > tol { 124 t.Errorf("Q is not orthogonal; resid=%v, want<=%v", resid, tol) 125 } 126 127 // Compute Qᵀ * A * Q. 128 aMat := blas64.General{ 129 Rows: n, 130 Cols: n, 131 Stride: n, 132 Data: make([]float64, len(a)), 133 } 134 135 for i := 0; i < n; i++ { 136 for j := i; j < n; j++ { 137 v := aCopy[i*lda+j] 138 if uplo == blas.Lower { 139 v = aCopy[j*lda+i] 140 } 141 aMat.Data[i*aMat.Stride+j] = v 142 aMat.Data[j*aMat.Stride+i] = v 143 } 144 } 145 146 tmp := blas64.General{ 147 Rows: n, 148 Cols: n, 149 Stride: n, 150 Data: make([]float64, n*n), 151 } 152 153 ans := blas64.General{ 154 Rows: n, 155 Cols: n, 156 Stride: n, 157 Data: make([]float64, n*n), 158 } 159 160 blas64.Gemm(blas.Trans, blas.NoTrans, 1, qMat, aMat, 0, tmp) 161 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, qMat, 0, ans) 162 163 // Compare with T. 164 tMat := blas64.General{ 165 Rows: n, 166 Cols: n, 167 Stride: n, 168 Data: make([]float64, n*n), 169 } 170 for i := 0; i < n-1; i++ { 171 tMat.Data[i*tMat.Stride+i] = d[i] 172 tMat.Data[i*tMat.Stride+i+1] = e[i] 173 tMat.Data[(i+1)*tMat.Stride+i] = e[i] 174 } 175 tMat.Data[(n-1)*tMat.Stride+n-1] = d[n-1] 176 177 same := true 178 for i := 0; i < n; i++ { 179 for j := 0; j < n; j++ { 180 if math.Abs(ans.Data[i*ans.Stride+j]-tMat.Data[i*tMat.Stride+j]) > tol { 181 same = false 182 } 183 } 184 } 185 if !same { 186 t.Errorf("Matrix answer mismatch") 187 } 188 } 189 } 190 }