github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgehd2.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "math/rand" 11 "testing" 12 13 "github.com/gonum/blas" 14 "github.com/gonum/blas/blas64" 15 ) 16 17 type Dgehd2er interface { 18 Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64) 19 } 20 21 func Dgehd2Test(t *testing.T, impl Dgehd2er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for _, n := range []int{1, 2, 3, 4, 5, 7, 10, 30} { 24 for _, extra := range []int{0, 1, 13} { 25 for cas := 0; cas < 100; cas++ { 26 testDgehd2(t, impl, n, extra, rnd) 27 } 28 } 29 } 30 } 31 32 func testDgehd2(t *testing.T, impl Dgehd2er, n, extra int, rnd *rand.Rand) { 33 ilo := rnd.Intn(n) 34 ihi := rnd.Intn(n) 35 if ilo > ihi { 36 ilo, ihi = ihi, ilo 37 } 38 39 tau := nanSlice(n - 1) 40 work := nanSlice(n) 41 42 a := randomGeneral(n, n, n+extra, rnd) 43 // NaN out elements under the diagonal except 44 // for the [ilo:ihi,ilo:ihi] block. 45 for i := 1; i <= ihi; i++ { 46 for j := 0; j < min(ilo, i); j++ { 47 a.Data[i*a.Stride+j] = math.NaN() 48 } 49 } 50 for i := ihi + 1; i < n; i++ { 51 for j := 0; j < i; j++ { 52 a.Data[i*a.Stride+j] = math.NaN() 53 } 54 } 55 aCopy := a 56 aCopy.Data = make([]float64, len(a.Data)) 57 copy(aCopy.Data, a.Data) 58 59 impl.Dgehd2(n, ilo, ihi, a.Data, a.Stride, tau, work) 60 61 prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v", n, ilo, ihi, extra) 62 63 // Check any invalid modifications of a. 64 if !generalOutsideAllNaN(a) { 65 t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data) 66 } 67 for i := ilo; i <= ihi; i++ { 68 for j := 0; j < min(ilo, i); j++ { 69 if !math.IsNaN(a.Data[i*a.Stride+j]) { 70 t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j) 71 } 72 } 73 } 74 for i := ihi + 1; i < n; i++ { 75 for j := 0; j < i; j++ { 76 if !math.IsNaN(a.Data[i*a.Stride+j]) { 77 t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j) 78 } 79 } 80 } 81 for i := 0; i <= ilo; i++ { 82 for j := i; j < ilo+1; j++ { 83 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { 84 t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j) 85 } 86 } 87 for j := ihi + 1; j < n; j++ { 88 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { 89 t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j) 90 } 91 } 92 } 93 for i := ihi + 1; i < n; i++ { 94 for j := i; j < n; j++ { 95 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { 96 t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j) 97 } 98 } 99 } 100 101 // Check that tau has been assigned properly. 102 for i, v := range tau { 103 if i < ilo || i >= ihi { 104 if !math.IsNaN(v) { 105 t.Errorf("%v: expected NaN at tau[%v]", prefix, i) 106 } 107 } else { 108 if math.IsNaN(v) { 109 t.Errorf("%v: unexpected NaN at tau[%v]", prefix, i) 110 } 111 } 112 } 113 114 // Extract Q and check that it is orthogonal. 115 q := blas64.General{ 116 Rows: n, 117 Cols: n, 118 Stride: n, 119 Data: make([]float64, n*n), 120 } 121 for i := 0; i < q.Rows; i++ { 122 q.Data[i*q.Stride+i] = 1 123 } 124 qCopy := q 125 qCopy.Data = make([]float64, len(q.Data)) 126 for j := ilo; j < ihi; j++ { 127 h := blas64.General{ 128 Rows: n, 129 Cols: n, 130 Stride: n, 131 Data: make([]float64, n*n), 132 } 133 for i := 0; i < h.Rows; i++ { 134 h.Data[i*h.Stride+i] = 1 135 } 136 v := blas64.Vector{ 137 Inc: 1, 138 Data: make([]float64, n), 139 } 140 v.Data[j+1] = 1 141 for i := j + 2; i < ihi+1; i++ { 142 v.Data[i] = a.Data[i*a.Stride+j] 143 } 144 blas64.Ger(-tau[j], v, v, h) 145 copy(qCopy.Data, q.Data) 146 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q) 147 } 148 if !isOrthonormal(q) { 149 t.Errorf("%v: Q is not orthogonal\nQ=%v", prefix, q) 150 } 151 152 // Overwrite NaN elements of aCopy with zeros 153 // (we will multiply with it below). 154 for i := 1; i <= ihi; i++ { 155 for j := 0; j < min(ilo, i); j++ { 156 aCopy.Data[i*aCopy.Stride+j] = 0 157 } 158 } 159 for i := ihi + 1; i < n; i++ { 160 for j := 0; j < i; j++ { 161 aCopy.Data[i*aCopy.Stride+j] = 0 162 } 163 } 164 165 // Construct Q^T * AOrig * Q and check that it is 166 // equal to A from Dgehd2. 167 aq := blas64.General{ 168 Rows: n, 169 Cols: n, 170 Stride: n, 171 Data: make([]float64, n*n), 172 } 173 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, q, 0, aq) 174 qaq := blas64.General{ 175 Rows: n, 176 Cols: n, 177 Stride: n, 178 Data: make([]float64, n*n), 179 } 180 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aq, 0, qaq) 181 for i := ilo; i <= ihi; i++ { 182 for j := ilo; j <= ihi; j++ { 183 qaqij := qaq.Data[i*qaq.Stride+j] 184 if j < i-1 { 185 if math.Abs(qaqij) > 1e-14 { 186 t.Errorf("%v: Q^T*A*Q is not upper Hessenberg, [%v,%v]=%v", prefix, i, j, qaqij) 187 } 188 continue 189 } 190 diff := qaqij - a.Data[i*a.Stride+j] 191 if math.Abs(diff) > 1e-14 { 192 t.Errorf("%v: Q^T*AOrig*Q and A are not equal, diff at [%v,%v]=%v", prefix, i, j, diff) 193 } 194 } 195 } 196 }