github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dbdsqr.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math/rand" 10 "sort" 11 "testing" 12 13 "github.com/gonum/blas" 14 "github.com/gonum/blas/blas64" 15 "github.com/gonum/floats" 16 ) 17 18 type Dbdsqrer interface { 19 Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) 20 } 21 22 func DbdsqrTest(t *testing.T, impl Dbdsqrer) { 23 rnd := rand.New(rand.NewSource(1)) 24 bi := blas64.Implementation() 25 _ = bi 26 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 27 for _, test := range []struct { 28 n, ncvt, nru, ncc, ldvt, ldu, ldc int 29 }{ 30 {5, 5, 5, 5, 0, 0, 0}, 31 {10, 10, 10, 10, 0, 0, 0}, 32 {10, 11, 12, 13, 0, 0, 0}, 33 {20, 13, 12, 11, 0, 0, 0}, 34 35 {5, 5, 5, 5, 6, 7, 8}, 36 {10, 10, 10, 10, 30, 40, 50}, 37 {10, 12, 11, 13, 30, 40, 50}, 38 {20, 12, 13, 11, 30, 40, 50}, 39 40 {130, 130, 130, 500, 900, 900, 500}, 41 } { 42 for cas := 0; cas < 100; cas++ { 43 n := test.n 44 ncvt := test.ncvt 45 nru := test.nru 46 ncc := test.ncc 47 ldvt := test.ldvt 48 ldu := test.ldu 49 ldc := test.ldc 50 if ldvt == 0 { 51 ldvt = ncvt 52 } 53 if ldu == 0 { 54 ldu = n 55 } 56 if ldc == 0 { 57 ldc = ncc 58 } 59 60 d := make([]float64, n) 61 for i := range d { 62 d[i] = rnd.NormFloat64() 63 } 64 e := make([]float64, n-1) 65 for i := range e { 66 e[i] = rnd.NormFloat64() 67 } 68 dCopy := make([]float64, len(d)) 69 copy(dCopy, d) 70 eCopy := make([]float64, len(e)) 71 copy(eCopy, e) 72 work := make([]float64, 4*n) 73 for i := range work { 74 work[i] = rnd.NormFloat64() 75 } 76 77 // First test the decomposition of the bidiagonal matrix. Set 78 // pt and u equal to I with the correct size. At the result 79 // of Dbdsqr, p and u will contain the data of P^T and Q, which 80 // will be used in the next step to test the multiplication 81 // with Q and VT. 82 83 q := make([]float64, n*n) 84 ldq := n 85 pt := make([]float64, n*n) 86 ldpt := n 87 for i := 0; i < n; i++ { 88 q[i*ldq+i] = 1 89 } 90 for i := 0; i < n; i++ { 91 pt[i*ldpt+i] = 1 92 } 93 94 ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 0, work) 95 96 isUpper := uplo == blas.Upper 97 errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc) 98 if !ok { 99 t.Errorf("Unexpected Dbdsqr failure: %s", errStr) 100 } 101 102 bMat := constructBidiagonal(uplo, n, dCopy, eCopy) 103 sMat := constructBidiagonal(uplo, n, d, e) 104 105 tmp := blas64.General{ 106 Rows: n, 107 Cols: n, 108 Stride: n, 109 Data: make([]float64, n*n), 110 } 111 ansMat := blas64.General{ 112 Rows: n, 113 Cols: n, 114 Stride: n, 115 Data: make([]float64, n*n), 116 } 117 118 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride) 119 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride) 120 121 same := true 122 for i := 0; i < n; i++ { 123 for j := 0; j < n; j++ { 124 if !floats.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) { 125 same = false 126 } 127 } 128 } 129 if !same { 130 t.Errorf("Bidiagonal mismatch. %s", errStr) 131 } 132 if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) { 133 t.Errorf("D is not sorted. %s", errStr) 134 } 135 136 // The above computed the real P and Q. Now input data for V^T, 137 // U, and C to check that the multiplications happen properly. 138 dAns := make([]float64, len(d)) 139 copy(dAns, d) 140 eAns := make([]float64, len(e)) 141 copy(eAns, e) 142 143 u := make([]float64, nru*ldu) 144 for i := range u { 145 u[i] = rnd.NormFloat64() 146 } 147 uCopy := make([]float64, len(u)) 148 copy(uCopy, u) 149 vt := make([]float64, n*ldvt) 150 for i := range vt { 151 vt[i] = rnd.NormFloat64() 152 } 153 vtCopy := make([]float64, len(vt)) 154 copy(vtCopy, vt) 155 c := make([]float64, n*ldc) 156 for i := range c { 157 c[i] = rnd.NormFloat64() 158 } 159 cCopy := make([]float64, len(c)) 160 copy(cCopy, c) 161 162 // Reset input data 163 copy(d, dCopy) 164 copy(e, eCopy) 165 impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work) 166 167 // Check result. 168 if !floats.EqualApprox(d, dAns, 1e-14) { 169 t.Errorf("D mismatch second time. %s", errStr) 170 } 171 if !floats.EqualApprox(e, eAns, 1e-14) { 172 t.Errorf("E mismatch second time. %s", errStr) 173 } 174 ans := make([]float64, len(vtCopy)) 175 copy(ans, vtCopy) 176 ldans := ldvt 177 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans) 178 if !floats.EqualApprox(ans, vt, 1e-10) { 179 t.Errorf("Vt result mismatch. %s", errStr) 180 } 181 ans = make([]float64, len(uCopy)) 182 copy(ans, uCopy) 183 ldans = ldu 184 bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans) 185 if !floats.EqualApprox(ans, u, 1e-10) { 186 t.Errorf("U result mismatch. %s", errStr) 187 } 188 ans = make([]float64, len(cCopy)) 189 copy(ans, cCopy) 190 ldans = ldc 191 bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans) 192 if !floats.EqualApprox(ans, c, 1e-10) { 193 t.Errorf("C result mismatch. %s", errStr) 194 } 195 } 196 } 197 } 198 }