github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dbdsqr.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "sort" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "github.com/jingcheng-WU/gonum/blas" 15 "github.com/jingcheng-WU/gonum/blas/blas64" 16 "github.com/jingcheng-WU/gonum/floats" 17 "github.com/jingcheng-WU/gonum/floats/scalar" 18 ) 19 20 type Dbdsqrer interface { 21 Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) 22 } 23 24 func DbdsqrTest(t *testing.T, impl Dbdsqrer) { 25 rnd := rand.New(rand.NewSource(1)) 26 bi := blas64.Implementation() 27 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 28 for _, test := range []struct { 29 n, ncvt, nru, ncc, ldvt, ldu, ldc int 30 }{ 31 {5, 5, 5, 5, 0, 0, 0}, 32 {10, 10, 10, 10, 0, 0, 0}, 33 {10, 11, 12, 13, 0, 0, 0}, 34 {20, 13, 12, 11, 0, 0, 0}, 35 36 {5, 5, 5, 5, 6, 7, 8}, 37 {10, 10, 10, 10, 30, 40, 50}, 38 {10, 12, 11, 13, 30, 40, 50}, 39 {20, 12, 13, 11, 30, 40, 50}, 40 41 {130, 130, 130, 500, 900, 900, 500}, 42 } { 43 for cas := 0; cas < 10; cas++ { 44 n := test.n 45 ncvt := test.ncvt 46 nru := test.nru 47 ncc := test.ncc 48 ldvt := test.ldvt 49 ldu := test.ldu 50 ldc := test.ldc 51 if ldvt == 0 { 52 ldvt = max(1, ncvt) 53 } 54 if ldu == 0 { 55 ldu = max(1, n) 56 } 57 if ldc == 0 { 58 ldc = max(1, ncc) 59 } 60 61 d := make([]float64, n) 62 for i := range d { 63 d[i] = rnd.NormFloat64() 64 } 65 e := make([]float64, n-1) 66 for i := range e { 67 e[i] = rnd.NormFloat64() 68 } 69 dCopy := make([]float64, len(d)) 70 copy(dCopy, d) 71 eCopy := make([]float64, len(e)) 72 copy(eCopy, e) 73 work := make([]float64, 4*(n-1)) 74 for i := range work { 75 work[i] = rnd.NormFloat64() 76 } 77 78 // First test the decomposition of the bidiagonal matrix. Set 79 // pt and u equal to I with the correct size. At the result 80 // of Dbdsqr, p and u will contain the data of Pᵀ and Q, which 81 // will be used in the next step to test the multiplication 82 // with Q and VT. 83 84 q := make([]float64, n*n) 85 ldq := n 86 pt := make([]float64, n*n) 87 ldpt := n 88 for i := 0; i < n; i++ { 89 q[i*ldq+i] = 1 90 } 91 for i := 0; i < n; i++ { 92 pt[i*ldpt+i] = 1 93 } 94 95 ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 1, work) 96 97 isUpper := uplo == blas.Upper 98 errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc) 99 if !ok { 100 t.Errorf("Unexpected Dbdsqr failure: %s", errStr) 101 } 102 103 bMat := constructBidiagonal(uplo, n, dCopy, eCopy) 104 sMat := constructBidiagonal(uplo, n, d, e) 105 106 tmp := blas64.General{ 107 Rows: n, 108 Cols: n, 109 Stride: n, 110 Data: make([]float64, n*n), 111 } 112 ansMat := blas64.General{ 113 Rows: n, 114 Cols: n, 115 Stride: n, 116 Data: make([]float64, n*n), 117 } 118 119 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride) 120 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride) 121 122 same := true 123 for i := 0; i < n; i++ { 124 for j := 0; j < n; j++ { 125 if !scalar.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) { 126 same = false 127 } 128 } 129 } 130 if !same { 131 t.Errorf("Bidiagonal mismatch. %s", errStr) 132 } 133 if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) { 134 t.Errorf("D is not sorted. %s", errStr) 135 } 136 137 // The above computed the real P and Q. Now input data for Vᵀ, 138 // U, and C to check that the multiplications happen properly. 139 dAns := make([]float64, len(d)) 140 copy(dAns, d) 141 eAns := make([]float64, len(e)) 142 copy(eAns, e) 143 144 u := make([]float64, nru*ldu) 145 for i := range u { 146 u[i] = rnd.NormFloat64() 147 } 148 uCopy := make([]float64, len(u)) 149 copy(uCopy, u) 150 vt := make([]float64, n*ldvt) 151 for i := range vt { 152 vt[i] = rnd.NormFloat64() 153 } 154 vtCopy := make([]float64, len(vt)) 155 copy(vtCopy, vt) 156 c := make([]float64, n*ldc) 157 for i := range c { 158 c[i] = rnd.NormFloat64() 159 } 160 cCopy := make([]float64, len(c)) 161 copy(cCopy, c) 162 163 // Reset input data 164 copy(d, dCopy) 165 copy(e, eCopy) 166 impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work) 167 168 // Check result. 169 if !floats.EqualApprox(d, dAns, 1e-14) { 170 t.Errorf("D mismatch second time. %s", errStr) 171 } 172 if !floats.EqualApprox(e, eAns, 1e-14) { 173 t.Errorf("E mismatch second time. %s", errStr) 174 } 175 ans := make([]float64, len(vtCopy)) 176 copy(ans, vtCopy) 177 ldans := ldvt 178 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans) 179 if !floats.EqualApprox(ans, vt, 1e-10) { 180 t.Errorf("Vt result mismatch. %s", errStr) 181 } 182 ans = make([]float64, len(uCopy)) 183 copy(ans, uCopy) 184 ldans = ldu 185 bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans) 186 if !floats.EqualApprox(ans, u, 1e-10) { 187 t.Errorf("U result mismatch. %s", errStr) 188 } 189 ans = make([]float64, len(cCopy)) 190 copy(ans, cCopy) 191 ldans = ldc 192 bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans) 193 if !floats.EqualApprox(ans, c, 1e-10) { 194 t.Errorf("C result mismatch. %s", errStr) 195 } 196 } 197 } 198 } 199 }