github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlahr2.go (about) 1 // Copyright ©2016 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "github.com/jingcheng-WU/gonum/blas" 15 "github.com/jingcheng-WU/gonum/blas/blas64" 16 ) 17 18 type Dlahr2er interface { 19 Dlahr2(n, k, nb int, a []float64, lda int, tau, t []float64, ldt int, y []float64, ldy int) 20 } 21 22 func Dlahr2Test(t *testing.T, impl Dlahr2er) { 23 const tol = 1e-14 24 25 rnd := rand.New(rand.NewSource(1)) 26 for _, test := range []struct { 27 n, k, nb int 28 }{ 29 {3, 0, 3}, 30 {3, 1, 2}, 31 {3, 1, 1}, 32 33 {5, 0, 5}, 34 {5, 1, 4}, 35 {5, 1, 3}, 36 {5, 1, 2}, 37 {5, 1, 1}, 38 {5, 2, 3}, 39 {5, 2, 2}, 40 {5, 2, 1}, 41 {5, 3, 2}, 42 {5, 3, 1}, 43 44 {7, 3, 4}, 45 {7, 3, 3}, 46 {7, 3, 2}, 47 {7, 3, 1}, 48 49 {10, 0, 10}, 50 {10, 1, 9}, 51 {10, 1, 5}, 52 {10, 1, 1}, 53 {10, 5, 5}, 54 {10, 5, 3}, 55 {10, 5, 1}, 56 } { 57 for cas := 0; cas < 100; cas++ { 58 for _, extraStride := range []int{0, 1, 10} { 59 n := test.n 60 k := test.k 61 nb := test.nb 62 63 a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd) 64 aCopy := a 65 aCopy.Data = make([]float64, len(a.Data)) 66 copy(aCopy.Data, a.Data) 67 tmat := nanTriangular(blas.Upper, nb, nb+extraStride) 68 y := nanGeneral(n, nb, nb+extraStride) 69 tau := nanSlice(nb) 70 71 impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride) 72 73 prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride) 74 75 if !generalOutsideAllNaN(a) { 76 t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data) 77 } 78 if !triangularOutsideAllNaN(tmat) { 79 t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data) 80 } 81 if !generalOutsideAllNaN(y) { 82 t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data) 83 } 84 85 // Check that A[:k,:] and A[:,nb:] blocks were not modified. 86 for i := 0; i < n; i++ { 87 for j := 0; j < n-k+1; j++ { 88 if i >= k && j < nb { 89 continue 90 } 91 if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] { 92 t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j) 93 } 94 } 95 } 96 97 // Check that all elements of tau were assigned. 98 for i, v := range tau { 99 if math.IsNaN(v) { 100 t.Errorf("%v: tau[%v] not assigned", prefix, i) 101 } 102 } 103 104 // Extract V from a. 105 v := blas64.General{ 106 Rows: n - k + 1, 107 Cols: nb, 108 Stride: nb, 109 Data: make([]float64, (n-k+1)*nb), 110 } 111 for j := 0; j < v.Cols; j++ { 112 v.Data[(j+1)*v.Stride+j] = 1 113 for i := j + 2; i < v.Rows; i++ { 114 v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j] 115 } 116 } 117 118 // VT = V. 119 vt := v 120 vt.Data = make([]float64, len(v.Data)) 121 copy(vt.Data, v.Data) 122 // VT = V * T. 123 blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt) 124 // YWant = A * V * T. 125 ywant := blas64.General{ 126 Rows: n, 127 Cols: nb, 128 Stride: nb, 129 Data: make([]float64, n*nb), 130 } 131 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant) 132 133 // Compare Y and YWant. 134 for i := 0; i < n; i++ { 135 for j := 0; j < nb; j++ { 136 diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j]) 137 if diff > tol { 138 t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff) 139 } 140 } 141 } 142 143 // Construct Q directly from the first nb columns of a. 144 q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau) 145 if resid := residualOrthogonal(q, false); resid > tol*float64(n) { 146 t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n)) 147 } 148 // Construct Q as the product Q = I - V*T*Vᵀ. 149 qwant := blas64.General{ 150 Rows: n - k + 1, 151 Cols: n - k + 1, 152 Stride: n - k + 1, 153 Data: make([]float64, (n-k+1)*(n-k+1)), 154 } 155 for i := 0; i < qwant.Rows; i++ { 156 qwant.Data[i*qwant.Stride+i] = 1 157 } 158 blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant) 159 if resid := residualOrthogonal(qwant, false); resid > tol*float64(n) { 160 t.Errorf("Case %v: Q = I - V*T*Vᵀ is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n)) 161 } 162 163 // Compare Q and QWant. Note that since Q is 164 // (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we 165 // ignore the first row and column of QWant. 166 for i := 0; i < n-k; i++ { 167 for j := 0; j < n-k; j++ { 168 diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1]) 169 if diff > tol { 170 t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff) 171 } 172 } 173 } 174 } 175 } 176 } 177 }