github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlatrd.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "math/rand" 11 "testing" 12 13 "github.com/gonum/blas" 14 "github.com/gonum/blas/blas64" 15 ) 16 17 type Dlatrder interface { 18 Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int) 19 } 20 21 func DlatrdTest(t *testing.T, impl Dlatrder) { 22 rnd := rand.New(rand.NewSource(1)) 23 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 24 for _, test := range []struct { 25 n, nb, lda, ldw int 26 }{ 27 {5, 2, 0, 0}, 28 {5, 5, 0, 0}, 29 30 {5, 3, 10, 11}, 31 {5, 5, 10, 11}, 32 } { 33 n := test.n 34 nb := test.nb 35 lda := test.lda 36 if lda == 0 { 37 lda = n 38 } 39 ldw := test.ldw 40 if ldw == 0 { 41 ldw = nb 42 } 43 44 a := make([]float64, n*lda) 45 for i := range a { 46 a[i] = rnd.NormFloat64() 47 } 48 49 e := make([]float64, n-1) 50 for i := range e { 51 e[i] = math.NaN() 52 } 53 tau := make([]float64, n-1) 54 for i := range tau { 55 tau[i] = math.NaN() 56 } 57 w := make([]float64, n*ldw) 58 for i := range w { 59 w[i] = math.NaN() 60 } 61 62 aCopy := make([]float64, len(a)) 63 copy(aCopy, a) 64 65 impl.Dlatrd(uplo, n, nb, a, lda, e, tau, w, ldw) 66 67 // Construct Q. 68 ldq := n 69 q := blas64.General{ 70 Rows: n, 71 Cols: n, 72 Stride: ldq, 73 Data: make([]float64, n*ldq), 74 } 75 for i := 0; i < n; i++ { 76 q.Data[i*ldq+i] = 1 77 } 78 if uplo == blas.Upper { 79 for i := n - 1; i >= n-nb; i-- { 80 if i == 0 { 81 continue 82 } 83 h := blas64.General{ 84 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n), 85 } 86 for j := 0; j < n; j++ { 87 h.Data[j*n+j] = 1 88 } 89 v := blas64.Vector{ 90 Inc: 1, 91 Data: make([]float64, n), 92 } 93 for j := 0; j < i-1; j++ { 94 v.Data[j] = a[j*lda+i] 95 } 96 v.Data[i-1] = 1 97 98 blas64.Ger(-tau[i-1], v, v, h) 99 100 qTmp := blas64.General{ 101 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n), 102 } 103 copy(qTmp.Data, q.Data) 104 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q) 105 } 106 } else { 107 for i := 0; i < nb; i++ { 108 if i == n-1 { 109 continue 110 } 111 h := blas64.General{ 112 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n), 113 } 114 for j := 0; j < n; j++ { 115 h.Data[j*n+j] = 1 116 } 117 v := blas64.Vector{ 118 Inc: 1, 119 Data: make([]float64, n), 120 } 121 v.Data[i+1] = 1 122 for j := i + 2; j < n; j++ { 123 v.Data[j] = a[j*lda+i] 124 } 125 blas64.Ger(-tau[i], v, v, h) 126 127 qTmp := blas64.General{ 128 Rows: n, Cols: n, Stride: n, Data: make([]float64, n*n), 129 } 130 copy(qTmp.Data, q.Data) 131 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qTmp, h, 0, q) 132 } 133 } 134 errStr := fmt.Sprintf("isUpper = %v, n = %v, nb = %v", uplo == blas.Upper, n, nb) 135 if !isOrthonormal(q) { 136 t.Errorf("Q not orthonormal. %s", errStr) 137 } 138 aGen := genFromSym(blas64.Symmetric{N: n, Stride: lda, Uplo: uplo, Data: aCopy}) 139 if !dlatrdCheckDecomposition(t, uplo, n, nb, e, tau, a, lda, aGen, q) { 140 t.Errorf("Decomposition mismatch. %s", errStr) 141 } 142 } 143 } 144 } 145 146 // dlatrdCheckDecomposition checks that the first nb rows have been successfully 147 // reduced. 148 func dlatrdCheckDecomposition(t *testing.T, uplo blas.Uplo, n, nb int, e, tau, a []float64, lda int, aGen, q blas64.General) bool { 149 // Compute Q^T * A * Q. 150 tmp := blas64.General{ 151 Rows: n, 152 Cols: n, 153 Stride: n, 154 Data: make([]float64, n*n), 155 } 156 157 ans := blas64.General{ 158 Rows: n, 159 Cols: n, 160 Stride: n, 161 Data: make([]float64, n*n), 162 } 163 164 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aGen, 0, tmp) 165 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, q, 0, ans) 166 167 // Compare with T. 168 if uplo == blas.Upper { 169 for i := n - 1; i >= n-nb; i-- { 170 for j := 0; j < n; j++ { 171 v := ans.Data[i*ans.Stride+j] 172 switch { 173 case i == j: 174 if math.Abs(v-a[i*lda+j]) > 1e-10 { 175 return false 176 } 177 case i == j-1: 178 if math.Abs(a[i*lda+j]-1) > 1e-10 { 179 return false 180 } 181 if math.Abs(v-e[i]) > 1e-10 { 182 return false 183 } 184 case i == j+1: 185 default: 186 if math.Abs(v) > 1e-10 { 187 return false 188 } 189 } 190 } 191 } 192 } else { 193 for i := 0; i < nb; i++ { 194 for j := 0; j < n; j++ { 195 v := ans.Data[i*ans.Stride+j] 196 switch { 197 case i == j: 198 if math.Abs(v-a[i*lda+j]) > 1e-10 { 199 return false 200 } 201 case i == j-1: 202 case i == j+1: 203 if math.Abs(a[i*lda+j]-1) > 1e-10 { 204 return false 205 } 206 if math.Abs(v-e[i-1]) > 1e-10 { 207 return false 208 } 209 default: 210 if math.Abs(v) > 1e-10 { 211 return false 212 } 213 } 214 } 215 } 216 } 217 return true 218 } 219 220 // genFromSym constructs a (symmetric) general matrix from the data in the 221 // symmetric. 222 // TODO(btracey): Replace other constructions of this with a call to this function. 223 func genFromSym(a blas64.Symmetric) blas64.General { 224 n := a.N 225 lda := a.Stride 226 uplo := a.Uplo 227 b := blas64.General{ 228 Rows: n, 229 Cols: n, 230 Stride: n, 231 Data: make([]float64, n*n), 232 } 233 234 for i := 0; i < n; i++ { 235 for j := i; j < n; j++ { 236 v := a.Data[i*lda+j] 237 if uplo == blas.Lower { 238 v = a.Data[j*lda+i] 239 } 240 b.Data[i*n+j] = v 241 b.Data[j*n+i] = v 242 } 243 } 244 return b 245 }