github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgesvd.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 "github.com/gonum/lapack" 16 ) 17 18 type Dgesvder interface { 19 Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool) 20 } 21 22 func DgesvdTest(t *testing.T, impl Dgesvder) { 23 rnd := rand.New(rand.NewSource(1)) 24 // TODO(btracey): Add tests for all of the cases when the SVD implementation 25 // is finished. 26 // TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD 27 // conditions are implemented. Right now mnthr is 5,000,000 which is too 28 // large to create a square matrix of that size. 29 for _, test := range []struct { 30 m, n, lda, ldu, ldvt int 31 }{ 32 {5, 5, 0, 0, 0}, 33 {5, 6, 0, 0, 0}, 34 {6, 5, 0, 0, 0}, 35 {5, 9, 0, 0, 0}, 36 {9, 5, 0, 0, 0}, 37 38 {5, 5, 10, 11, 12}, 39 {5, 6, 10, 11, 12}, 40 {6, 5, 10, 11, 12}, 41 {5, 5, 10, 11, 12}, 42 {5, 9, 10, 11, 12}, 43 {9, 5, 10, 11, 12}, 44 45 {300, 300, 0, 0, 0}, 46 {300, 400, 0, 0, 0}, 47 {400, 300, 0, 0, 0}, 48 {300, 600, 0, 0, 0}, 49 {600, 300, 0, 0, 0}, 50 51 {300, 300, 400, 450, 460}, 52 {300, 400, 500, 550, 560}, 53 {400, 300, 550, 550, 560}, 54 {300, 600, 700, 750, 760}, 55 {600, 300, 700, 750, 760}, 56 } { 57 jobU := lapack.SVDAll 58 jobVT := lapack.SVDAll 59 60 m := test.m 61 n := test.n 62 lda := test.lda 63 if lda == 0 { 64 lda = n 65 } 66 ldu := test.ldu 67 if ldu == 0 { 68 ldu = m 69 } 70 ldvt := test.ldvt 71 if ldvt == 0 { 72 ldvt = n 73 } 74 75 a := make([]float64, m*lda) 76 for i := range a { 77 a[i] = rnd.NormFloat64() 78 } 79 80 u := make([]float64, m*ldu) 81 for i := range u { 82 u[i] = rnd.NormFloat64() 83 } 84 85 vt := make([]float64, n*ldvt) 86 for i := range vt { 87 vt[i] = rnd.NormFloat64() 88 } 89 90 uAllOrig := make([]float64, len(u)) 91 copy(uAllOrig, u) 92 vtAllOrig := make([]float64, len(vt)) 93 copy(vtAllOrig, vt) 94 aCopy := make([]float64, len(a)) 95 copy(aCopy, a) 96 97 s := make([]float64, min(m, n)) 98 99 work := make([]float64, 1) 100 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1) 101 102 if !floats.Equal(a, aCopy) { 103 t.Errorf("a changed during call to get work length") 104 } 105 106 work = make([]float64, int(work[0])) 107 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) 108 109 errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt) 110 svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) 111 svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) 112 113 // Test InPlace 114 jobU = lapack.SVDInPlace 115 jobVT = lapack.SVDInPlace 116 copy(a, aCopy) 117 copy(u, uAllOrig) 118 copy(vt, vtAllOrig) 119 120 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work)) 121 svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda) 122 svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false) 123 } 124 } 125 126 // svdCheckPartial checks that the singular values and vectors are computed when 127 // not all of them are computed. 128 func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) { 129 rnd := rand.New(rand.NewSource(1)) 130 jobU := job 131 jobVT := job 132 // Compare the singular values when computed with {SVDNone, SVDNone.} 133 sCopy := make([]float64, len(s)) 134 copy(sCopy, s) 135 copy(a, aCopy) 136 for i := range s { 137 s[i] = rnd.Float64() 138 } 139 tmp1 := make([]float64, 1) 140 tmp2 := make([]float64, 1) 141 jobU = lapack.SVDNone 142 jobVT = lapack.SVDNone 143 144 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1) 145 work = make([]float64, int(work[0])) 146 lwork := len(work) 147 if shortWork { 148 lwork-- 149 } 150 ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork) 151 if !ok { 152 t.Errorf("Dgesvd did not complete successfully") 153 } 154 if !floats.EqualApprox(s, sCopy, 1e-10) { 155 t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr) 156 } 157 // Check that the singular vectors are correctly computed when the other 158 // is none. 159 uAll := make([]float64, len(u)) 160 copy(uAll, u) 161 vtAll := make([]float64, len(vt)) 162 copy(vtAll, vt) 163 164 // Copy the original vectors so the data outside the matrix bounds is the same. 165 copy(u, uAllOrig) 166 copy(vt, vtAllOrig) 167 168 jobU = job 169 jobVT = lapack.SVDNone 170 copy(a, aCopy) 171 for i := range s { 172 s[i] = rnd.Float64() 173 } 174 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1) 175 work = make([]float64, int(work[0])) 176 lwork = len(work) 177 if shortWork { 178 lwork-- 179 } 180 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work)) 181 if !floats.EqualApprox(uAll, u, 1e-10) { 182 t.Errorf("U mismatch when VT is not computed: %s", errStr) 183 } 184 if !floats.EqualApprox(s, sCopy, 1e-10) { 185 t.Errorf("Singular value mismatch when U computed VT not") 186 } 187 jobU = lapack.SVDNone 188 jobVT = job 189 copy(a, aCopy) 190 for i := range s { 191 s[i] = rnd.Float64() 192 } 193 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1) 194 work = make([]float64, int(work[0])) 195 lwork = len(work) 196 if shortWork { 197 lwork-- 198 } 199 impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work)) 200 if !floats.EqualApprox(vtAll, vt, 1e-10) { 201 t.Errorf("VT mismatch when U is not computed: %s", errStr) 202 } 203 if !floats.EqualApprox(s, sCopy, 1e-10) { 204 t.Errorf("Singular value mismatch when VT computed U not") 205 } 206 } 207 208 // svdCheck checks that the singular value decomposition correctly multiplies back 209 // to the original matrix. 210 func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) { 211 sigma := blas64.General{ 212 Rows: m, 213 Cols: n, 214 Stride: n, 215 Data: make([]float64, m*n), 216 } 217 for i := 0; i < min(m, n); i++ { 218 sigma.Data[i*sigma.Stride+i] = s[i] 219 } 220 221 uMat := blas64.General{ 222 Rows: m, 223 Cols: m, 224 Stride: ldu, 225 Data: u, 226 } 227 vTMat := blas64.General{ 228 Rows: n, 229 Cols: n, 230 Stride: ldvt, 231 Data: vt, 232 } 233 if thin { 234 sigma.Rows = min(m, n) 235 sigma.Cols = min(m, n) 236 uMat.Cols = min(m, n) 237 vTMat.Rows = min(m, n) 238 } 239 240 tmp := blas64.General{ 241 Rows: m, 242 Cols: n, 243 Stride: n, 244 Data: make([]float64, m*n), 245 } 246 ans := blas64.General{ 247 Rows: m, 248 Cols: n, 249 Stride: lda, 250 Data: make([]float64, m*lda), 251 } 252 copy(ans.Data, a) 253 254 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp) 255 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans) 256 257 if !floats.EqualApprox(ans.Data, aCopy, 1e-8) { 258 t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr) 259 } 260 261 if !thin { 262 // Check that U and V are orthogonal. 263 for i := 0; i < uMat.Rows; i++ { 264 for j := i + 1; j < uMat.Rows; j++ { 265 dot := blas64.Dot(uMat.Cols, 266 blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]}, 267 blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]}, 268 ) 269 if dot > 1e-8 { 270 t.Errorf("U not orthogonal %s", errStr) 271 } 272 } 273 } 274 for i := 0; i < vTMat.Rows; i++ { 275 for j := i + 1; j < vTMat.Rows; j++ { 276 dot := blas64.Dot(vTMat.Cols, 277 blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]}, 278 blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]}, 279 ) 280 if dot > 1e-8 { 281 t.Errorf("V not orthogonal %s", errStr) 282 } 283 } 284 } 285 } 286 }