gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dsytrd.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dsytrder interface { 19 Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int) 20 21 Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) 22 Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) 23 } 24 25 func DsytrdTest(t *testing.T, impl Dsytrder) { 26 const tol = 1e-14 27 28 rnd := rand.New(rand.NewSource(1)) 29 for tc, test := range []struct { 30 n, lda int 31 }{ 32 {1, 0}, 33 {2, 0}, 34 {3, 0}, 35 {4, 0}, 36 {10, 0}, 37 {50, 0}, 38 {100, 0}, 39 {150, 0}, 40 {300, 0}, 41 42 {1, 3}, 43 {2, 3}, 44 {3, 7}, 45 {4, 9}, 46 {10, 20}, 47 {50, 70}, 48 {100, 120}, 49 {150, 170}, 50 {300, 320}, 51 } { 52 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 53 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { 54 n := test.n 55 lda := test.lda 56 if lda == 0 { 57 lda = n 58 } 59 a := randomGeneral(n, n, lda, rnd) 60 for i := 1; i < n; i++ { 61 for j := 0; j < i; j++ { 62 a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i] 63 } 64 } 65 aCopy := cloneGeneral(a) 66 67 d := nanSlice(n) 68 e := nanSlice(n - 1) 69 tau := nanSlice(n - 1) 70 71 var lwork int 72 switch wl { 73 case minimumWork: 74 lwork = 1 75 case mediumWork: 76 work := make([]float64, 1) 77 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) 78 lwork = (int(work[0]) + 1) / 2 79 lwork = max(1, lwork) 80 case optimumWork: 81 work := make([]float64, 1) 82 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) 83 lwork = int(work[0]) 84 } 85 work := make([]float64, lwork) 86 87 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork) 88 89 prefix := fmt.Sprintf("Case #%v: uplo=%c,n=%v,lda=%v,work=%v", 90 tc, uplo, n, lda, wl) 91 92 if !generalOutsideAllNaN(a) { 93 t.Errorf("%v: out-of-range write to A", prefix) 94 } 95 96 // Extract Q by doing what Dorgtr does. 97 q := cloneGeneral(a) 98 if uplo == blas.Upper { 99 for j := 0; j < n-1; j++ { 100 for i := 0; i < j; i++ { 101 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1] 102 } 103 q.Data[(n-1)*q.Stride+j] = 0 104 } 105 for i := 0; i < n-1; i++ { 106 q.Data[i*q.Stride+n-1] = 0 107 } 108 q.Data[(n-1)*q.Stride+n-1] = 1 109 if n > 1 { 110 work = make([]float64, n-1) 111 impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work)) 112 } 113 } else { 114 for j := n - 1; j > 0; j-- { 115 q.Data[j] = 0 116 for i := j + 1; i < n; i++ { 117 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1] 118 } 119 } 120 q.Data[0] = 1 121 for i := 1; i < n; i++ { 122 q.Data[i*q.Stride] = 0 123 } 124 if n > 1 { 125 work = make([]float64, n-1) 126 impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work)) 127 } 128 } 129 if resid := residualOrthogonal(q, false); resid > tol*float64(n) { 130 t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n)) 131 } 132 133 // Construct symmetric tridiagonal T from d and e. 134 tMat := zeros(n, n, n) 135 for i := 0; i < n; i++ { 136 tMat.Data[i*tMat.Stride+i] = d[i] 137 } 138 if uplo == blas.Upper { 139 for j := 1; j < n; j++ { 140 tMat.Data[(j-1)*tMat.Stride+j] = e[j-1] 141 tMat.Data[j*tMat.Stride+j-1] = e[j-1] 142 } 143 } else { 144 for j := 0; j < n-1; j++ { 145 tMat.Data[(j+1)*tMat.Stride+j] = e[j] 146 tMat.Data[j*tMat.Stride+j+1] = e[j] 147 } 148 } 149 // Compute Qᵀ*A*Q - T. 150 qa := zeros(n, n, n) 151 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, qa) 152 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qa, q, -1, tMat) 153 // Check that |Qᵀ*A*Q - T| is small. 154 resid := dlange(lapack.MaxColumnSum, n, n, tMat.Data, tMat.Stride) 155 if resid > tol*float64(n) { 156 t.Errorf("%v: |Qᵀ*A*Q - T|=%v, want<=%v", prefix, resid, tol*float64(n)) 157 } 158 } 159 } 160 } 161 }