github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlahr2.go (about) 1 // Copyright ©2016 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "compress/gzip" 9 "encoding/json" 10 "fmt" 11 "log" 12 "math" 13 "math/rand" 14 "os" 15 "path/filepath" 16 "testing" 17 18 "github.com/gonum/blas" 19 "github.com/gonum/blas/blas64" 20 "github.com/gonum/floats" 21 ) 22 23 type Dlahr2er interface { 24 Dlahr2(n, k, nb int, a []float64, lda int, tau, t []float64, ldt int, y []float64, ldy int) 25 } 26 27 type Dlahr2test struct { 28 N, K, NB int 29 A []float64 30 31 AWant []float64 32 TWant []float64 33 YWant []float64 34 TauWant []float64 35 } 36 37 func Dlahr2Test(t *testing.T, impl Dlahr2er) { 38 rnd := rand.New(rand.NewSource(1)) 39 for _, test := range []struct { 40 n, k, nb int 41 }{ 42 {3, 0, 3}, 43 {3, 1, 2}, 44 {3, 1, 1}, 45 46 {5, 0, 5}, 47 {5, 1, 4}, 48 {5, 1, 3}, 49 {5, 1, 2}, 50 {5, 1, 1}, 51 {5, 2, 3}, 52 {5, 2, 2}, 53 {5, 2, 1}, 54 {5, 3, 2}, 55 {5, 3, 1}, 56 57 {7, 3, 4}, 58 {7, 3, 3}, 59 {7, 3, 2}, 60 {7, 3, 1}, 61 62 {10, 0, 10}, 63 {10, 1, 9}, 64 {10, 1, 5}, 65 {10, 1, 1}, 66 {10, 5, 5}, 67 {10, 5, 3}, 68 {10, 5, 1}, 69 } { 70 for cas := 0; cas < 100; cas++ { 71 for _, extraStride := range []int{0, 1, 10} { 72 n := test.n 73 k := test.k 74 nb := test.nb 75 76 a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd) 77 aCopy := a 78 aCopy.Data = make([]float64, len(a.Data)) 79 copy(aCopy.Data, a.Data) 80 tmat := nanTriangular(blas.Upper, nb, nb+extraStride) 81 y := nanGeneral(n, nb, nb+extraStride) 82 tau := nanSlice(nb) 83 84 impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride) 85 86 prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride) 87 88 if !generalOutsideAllNaN(a) { 89 t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data) 90 } 91 if !triangularOutsideAllNaN(tmat) { 92 t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data) 93 } 94 if !generalOutsideAllNaN(y) { 95 t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data) 96 } 97 98 // Check that A[:k,:] and A[:,nb:] blocks were not modified. 99 for i := 0; i < n; i++ { 100 for j := 0; j < n-k+1; j++ { 101 if i >= k && j < nb { 102 continue 103 } 104 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { 105 t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j) 106 } 107 } 108 } 109 110 // Check that all elements of tau were assigned. 111 for i, v := range tau { 112 if math.IsNaN(v) { 113 t.Errorf("%v: tau[%v] not assigned", prefix, i) 114 } 115 } 116 117 // Extract V from a. 118 v := blas64.General{ 119 Rows: n - k + 1, 120 Cols: nb, 121 Stride: nb, 122 Data: make([]float64, (n-k+1)*nb), 123 } 124 for j := 0; j < v.Cols; j++ { 125 v.Data[(j+1)*v.Stride+j] = 1 126 for i := j + 2; i < v.Rows; i++ { 127 v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j] 128 } 129 } 130 131 // VT = V. 132 vt := v 133 vt.Data = make([]float64, len(v.Data)) 134 copy(vt.Data, v.Data) 135 // VT = V * T. 136 blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt) 137 // YWant = A * V * T. 138 ywant := blas64.General{ 139 Rows: n, 140 Cols: nb, 141 Stride: nb, 142 Data: make([]float64, n*nb), 143 } 144 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant) 145 146 // Compare Y and YWant. 147 for i := 0; i < n; i++ { 148 for j := 0; j < nb; j++ { 149 diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j]) 150 if diff > 1e-14 { 151 t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff) 152 } 153 } 154 } 155 156 // Construct Q directly from the first nb columns of a. 157 q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau) 158 if !isOrthonormal(q) { 159 t.Errorf("%v: Q is not orthogonal", prefix) 160 } 161 // Construct Q as the product Q = I - V*T*V^T. 162 qwant := blas64.General{ 163 Rows: n - k + 1, 164 Cols: n - k + 1, 165 Stride: n - k + 1, 166 Data: make([]float64, (n-k+1)*(n-k+1)), 167 } 168 for i := 0; i < qwant.Rows; i++ { 169 qwant.Data[i*qwant.Stride+i] = 1 170 } 171 blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant) 172 if !isOrthonormal(qwant) { 173 t.Errorf("%v: Q = I - V*T*V^T is not orthogonal", prefix) 174 } 175 176 // Compare Q and QWant. Note that since Q is 177 // (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we 178 // ignore the first row and column of QWant. 179 for i := 0; i < n-k; i++ { 180 for j := 0; j < n-k; j++ { 181 diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1]) 182 if diff > 1e-14 { 183 t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff) 184 } 185 } 186 } 187 } 188 } 189 } 190 191 // Go runs tests from the source directory, so unfortunately we need to 192 // include the "../testlapack" part. 193 file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlahr2data.json.gz")) 194 if err != nil { 195 log.Fatal(err) 196 } 197 defer file.Close() 198 r, err := gzip.NewReader(file) 199 if err != nil { 200 log.Fatal(err) 201 } 202 defer r.Close() 203 204 var tests []Dlahr2test 205 json.NewDecoder(r).Decode(&tests) 206 for _, test := range tests { 207 tau := make([]float64, len(test.TauWant)) 208 for _, ldex := range []int{0, 1, 20} { 209 n := test.N 210 k := test.K 211 nb := test.NB 212 213 lda := n - k + 1 + ldex 214 a := make([]float64, (n-1)*lda+n-k+1) 215 copyMatrix(n, n-k+1, a, lda, test.A) 216 217 ldt := nb + ldex 218 tmat := make([]float64, (nb-1)*ldt+nb) 219 220 ldy := nb + ldex 221 y := make([]float64, (n-1)*ldy+nb) 222 223 impl.Dlahr2(n, k, nb, a, lda, tau, tmat, ldt, y, ldy) 224 225 prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, ldex) 226 if !equalApprox(n, n-k+1, a, lda, test.AWant, 1e-14) { 227 t.Errorf("%v: unexpected matrix A\n got=%v\nwant=%v", prefix, a, test.AWant) 228 } 229 if !equalApproxTriangular(true, nb, tmat, ldt, test.TWant, 1e-14) { 230 t.Errorf("%v: unexpected matrix T\n got=%v\nwant=%v", prefix, tmat, test.TWant) 231 } 232 if !equalApprox(n, nb, y, ldy, test.YWant, 1e-14) { 233 t.Errorf("%v: unexpected matrix Y\n got=%v\nwant=%v", prefix, y, test.YWant) 234 } 235 if !floats.EqualApprox(tau, test.TauWant, 1e-14) { 236 t.Errorf("%v: unexpected slice tau\n got=%v\nwant=%v", prefix, tau, test.TauWant) 237 } 238 } 239 } 240 }