github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsytrd.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 ) 15 16 type Dsytrder interface { 17 Dsytrd(uplo blas.Uplo, n int, a []float64, lda int, d, e, tau, work []float64, lwork int) 18 19 Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) 20 Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int) 21 } 22 23 func DsytrdTest(t *testing.T, impl Dsytrder) { 24 const tol = 1e-13 25 rnd := rand.New(rand.NewSource(1)) 26 for tc, test := range []struct { 27 n, lda int 28 }{ 29 {1, 0}, 30 {2, 0}, 31 {3, 0}, 32 {4, 0}, 33 {10, 0}, 34 {50, 0}, 35 {100, 0}, 36 {150, 0}, 37 {300, 0}, 38 39 {1, 3}, 40 {2, 3}, 41 {3, 7}, 42 {4, 9}, 43 {10, 20}, 44 {50, 70}, 45 {100, 120}, 46 {150, 170}, 47 {300, 320}, 48 } { 49 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 50 for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} { 51 n := test.n 52 lda := test.lda 53 if lda == 0 { 54 lda = n 55 } 56 a := randomGeneral(n, n, lda, rnd) 57 for i := 1; i < n; i++ { 58 for j := 0; j < i; j++ { 59 a.Data[i*a.Stride+j] = a.Data[j*a.Stride+i] 60 } 61 } 62 aCopy := cloneGeneral(a) 63 64 d := nanSlice(n) 65 e := nanSlice(n - 1) 66 tau := nanSlice(n - 1) 67 68 var lwork int 69 switch wl { 70 case minimumWork: 71 lwork = 1 72 case mediumWork: 73 work := make([]float64, 1) 74 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) 75 lwork = (int(work[0]) + 1) / 2 76 lwork = max(1, lwork) 77 case optimumWork: 78 work := make([]float64, 1) 79 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, -1) 80 lwork = int(work[0]) 81 } 82 work := make([]float64, lwork) 83 84 impl.Dsytrd(uplo, n, a.Data, a.Stride, d, e, tau, work, lwork) 85 86 prefix := fmt.Sprintf("Case #%v: uplo=%v,n=%v,lda=%v,work=%v", 87 tc, uplo, n, lda, wl) 88 89 if !generalOutsideAllNaN(a) { 90 t.Errorf("%v: out-of-range write to A", prefix) 91 } 92 93 // Extract Q by doing what Dorgtr does. 94 q := cloneGeneral(a) 95 if uplo == blas.Upper { 96 for j := 0; j < n-1; j++ { 97 for i := 0; i < j; i++ { 98 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j+1] 99 } 100 q.Data[(n-1)*q.Stride+j] = 0 101 } 102 for i := 0; i < n-1; i++ { 103 q.Data[i*q.Stride+n-1] = 0 104 } 105 q.Data[(n-1)*q.Stride+n-1] = 1 106 if n > 1 { 107 work = make([]float64, n-1) 108 impl.Dorgql(n-1, n-1, n-1, q.Data, q.Stride, tau, work, len(work)) 109 } 110 } else { 111 for j := n - 1; j > 0; j-- { 112 q.Data[j] = 0 113 for i := j + 1; i < n; i++ { 114 q.Data[i*q.Stride+j] = q.Data[i*q.Stride+j-1] 115 } 116 } 117 q.Data[0] = 1 118 for i := 1; i < n; i++ { 119 q.Data[i*q.Stride] = 0 120 } 121 if n > 1 { 122 work = make([]float64, n-1) 123 impl.Dorgqr(n-1, n-1, n-1, q.Data[q.Stride+1:], q.Stride, tau, work, len(work)) 124 } 125 } 126 if !isOrthonormal(q) { 127 t.Errorf("%v: Q not orthogonal", prefix) 128 } 129 130 // Contruct symmetric tridiagonal T from d and e. 131 tMat := zeros(n, n, n) 132 for i := 0; i < n; i++ { 133 tMat.Data[i*tMat.Stride+i] = d[i] 134 } 135 if uplo == blas.Upper { 136 for j := 1; j < n; j++ { 137 tMat.Data[(j-1)*tMat.Stride+j] = e[j-1] 138 tMat.Data[j*tMat.Stride+j-1] = e[j-1] 139 } 140 } else { 141 for j := 0; j < n-1; j++ { 142 tMat.Data[(j+1)*tMat.Stride+j] = e[j] 143 tMat.Data[j*tMat.Stride+j+1] = e[j] 144 } 145 } 146 147 // Compute Q^T * A * Q. 148 tmp := zeros(n, n, n) 149 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aCopy, 0, tmp) 150 got := zeros(n, n, n) 151 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, got) 152 153 // Compare with T. 154 if !equalApproxGeneral(got, tMat, tol) { 155 t.Errorf("%v: Q^T*A*Q != T", prefix) 156 } 157 } 158 } 159 } 160 }