gonum.org/v1/gonum@v0.14.0/mat/svd_test.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/floats" 13 ) 14 15 func TestSVD(t *testing.T) { 16 t.Parallel() 17 rnd := rand.New(rand.NewSource(1)) 18 // Hand coded tests 19 for _, test := range []struct { 20 a *Dense 21 u *Dense 22 v *Dense 23 s []float64 24 }{ 25 { 26 a: NewDense(4, 2, []float64{2, 4, 1, 3, 0, 0, 0, 0}), 27 u: NewDense(4, 2, []float64{ 28 -0.8174155604703632, -0.5760484367663209, 29 -0.5760484367663209, 0.8174155604703633, 30 0, 0, 31 0, 0, 32 }), 33 v: NewDense(2, 2, []float64{ 34 -0.4045535848337571, -0.9145142956773044, 35 -0.9145142956773044, 0.4045535848337571, 36 }), 37 s: []float64{5.464985704219041, 0.365966190626258}, 38 }, 39 { 40 // Issue #5. 41 a: NewDense(3, 11, []float64{ 42 1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1, 43 1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2, 44 1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3, 45 }), 46 u: NewDense(3, 3, []float64{ 47 -0.5224167862273765, 0.7864430360363114, 0.3295270133658976, 48 -0.5739526766688285, -0.03852203026050301, -0.8179818935216693, 49 -0.6306021141833781, -0.6164603833618163, 0.4715056408282468, 50 }), 51 v: NewDense(11, 3, []float64{ 52 -0.08123293141915189, 0.08528085505260324, -0.013165501690885152, 53 -0.05423546426886932, 0.1102707844980355, 0.622210623111631, 54 0, 0, 0, 55 -0.0245733326078166, 0.510179651760153, 0.25596360803140994, 56 0, 0, 0, 57 0, 0, 0, 58 -0.026997467150282436, -0.024989929445430496, -0.6353761248025164, 59 0, 0, 0, 60 -0.029662131661052707, -0.3999088672621176, 0.3662470150802212, 61 -0.9798839760830571, 0.11328174160898856, -0.047702613241813366, 62 -0.16755466189153964, -0.7395268089170608, 0.08395240366704032, 63 }), 64 s: []float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628}, 65 }, 66 } { 67 var svd SVD 68 ok := svd.Factorize(test.a, SVDThin) 69 if !ok { 70 t.Errorf("SVD failed") 71 } 72 s, u, v := extractSVD(&svd) 73 if !floats.EqualApprox(s, test.s, 1e-10) { 74 t.Errorf("Singular value mismatch. Got %v, want %v.", s, test.s) 75 } 76 if !EqualApprox(u, test.u, 1e-10) { 77 t.Errorf("U mismatch.\nGot:\n%v\nWant:\n%v", Formatted(u), Formatted(test.u)) 78 } 79 if !EqualApprox(v, test.v, 1e-10) { 80 t.Errorf("V mismatch.\nGot:\n%v\nWant:\n%v", Formatted(v), Formatted(test.v)) 81 } 82 m, n := test.a.Dims() 83 sigma := NewDense(min(m, n), min(m, n), nil) 84 for i := 0; i < min(m, n); i++ { 85 sigma.Set(i, i, s[i]) 86 } 87 88 var ans Dense 89 ans.Product(u, sigma, v.T()) 90 if !EqualApprox(test.a, &ans, 1e-10) { 91 t.Errorf("A reconstruction mismatch.\nGot:\n%v\nWant:\n%v\n", Formatted(&ans), Formatted(test.a)) 92 } 93 94 for _, kind := range []SVDKind{ 95 SVDThinU, SVDFullU, SVDThinV, SVDFullV, 96 } { 97 var svd SVD 98 svd.Factorize(test.a, kind) 99 if kind&SVDThinU == 0 && kind&SVDFullU == 0 { 100 panicked, message := panics(func() { 101 var dst Dense 102 svd.UTo(&dst) 103 }) 104 if !panicked { 105 t.Error("expected panic with no U matrix requested") 106 continue 107 } 108 want := "svd: u not computed during factorization" 109 if message != want { 110 t.Errorf("unexpected message: got:%q want:%q", message, want) 111 } 112 } 113 if kind&SVDThinV == 0 && kind&SVDFullV == 0 { 114 panicked, message := panics(func() { 115 var dst Dense 116 svd.VTo(&dst) 117 }) 118 if !panicked { 119 t.Error("expected panic with no V matrix requested") 120 continue 121 } 122 want := "svd: v not computed during factorization" 123 if message != want { 124 t.Errorf("unexpected message: got:%q want:%q", message, want) 125 } 126 } 127 } 128 } 129 130 for _, test := range []struct { 131 m, n int 132 }{ 133 {5, 5}, 134 {5, 3}, 135 {3, 5}, 136 {150, 150}, 137 {200, 150}, 138 {150, 200}, 139 } { 140 m := test.m 141 n := test.n 142 for trial := 0; trial < 10; trial++ { 143 a := NewDense(m, n, nil) 144 for i := range a.mat.Data { 145 a.mat.Data[i] = rnd.NormFloat64() 146 } 147 aCopy := DenseCopyOf(a) 148 149 // Test Full decomposition. 150 var svd SVD 151 ok := svd.Factorize(a, SVDFull) 152 if !ok { 153 t.Errorf("SVD factorization failed") 154 } 155 if !Equal(a, aCopy) { 156 t.Errorf("A changed during call to SVD with full") 157 } 158 s, u, v := extractSVD(&svd) 159 sigma := NewDense(m, n, nil) 160 for i := 0; i < min(m, n); i++ { 161 sigma.Set(i, i, s[i]) 162 } 163 var ansFull Dense 164 ansFull.Product(u, sigma, v.T()) 165 if !EqualApprox(&ansFull, a, 1e-8) { 166 t.Errorf("Answer mismatch when SVDFull") 167 } 168 169 // Test Thin decomposition. 170 ok = svd.Factorize(a, SVDThin) 171 if !ok { 172 t.Errorf("SVD factorization failed") 173 } 174 if !Equal(a, aCopy) { 175 t.Errorf("A changed during call to SVD with Thin") 176 } 177 sThin, u, v := extractSVD(&svd) 178 if !floats.EqualApprox(s, sThin, 1e-8) { 179 t.Errorf("Singular value mismatch between Full and Thin decomposition") 180 } 181 sigma = NewDense(min(m, n), min(m, n), nil) 182 for i := 0; i < min(m, n); i++ { 183 sigma.Set(i, i, sThin[i]) 184 } 185 ansFull.Reset() 186 ansFull.Product(u, sigma, v.T()) 187 if !EqualApprox(&ansFull, a, 1e-8) { 188 t.Errorf("Answer mismatch when SVDFull") 189 } 190 191 // Test None decomposition. 192 ok = svd.Factorize(a, SVDNone) 193 if !ok { 194 t.Errorf("SVD factorization failed") 195 } 196 if !Equal(a, aCopy) { 197 t.Errorf("A changed during call to SVD with none") 198 } 199 sNone := make([]float64, min(m, n)) 200 svd.Values(sNone) 201 if !floats.EqualApprox(s, sNone, 1e-8) { 202 t.Errorf("Singular value mismatch between Full and None decomposition") 203 } 204 } 205 } 206 } 207 208 func extractSVD(svd *SVD) (s []float64, u, v *Dense) { 209 u = &Dense{} 210 svd.UTo(u) 211 v = &Dense{} 212 svd.VTo(v) 213 return svd.Values(nil), u, v 214 } 215 216 func TestSVDSolveTo(t *testing.T) { 217 t.Parallel() 218 rnd := rand.New(rand.NewSource(1)) 219 // Hand-coded cases. 220 for i, test := range []struct { 221 a []float64 222 m, n int 223 b []float64 224 bc int 225 rcond float64 226 want []float64 227 wm, wn int 228 }{ 229 { 230 a: []float64{6}, m: 1, n: 1, 231 b: []float64{3}, bc: 1, 232 want: []float64{0.5}, wm: 1, wn: 1, 233 }, 234 { 235 a: []float64{ 236 1, 0, 0, 237 0, 1, 0, 238 0, 0, 1, 239 }, m: 3, n: 3, 240 b: []float64{ 241 3, 242 2, 243 1, 244 }, bc: 1, 245 want: []float64{ 246 3, 247 2, 248 1, 249 }, wm: 3, wn: 1, 250 }, 251 { 252 a: []float64{ 253 0.8147, 0.9134, 0.5528, 254 0.9058, 0.6324, 0.8723, 255 0.1270, 0.0975, 0.7612, 256 }, m: 3, n: 3, 257 b: []float64{ 258 0.278, 259 0.547, 260 0.958, 261 }, bc: 1, 262 want: []float64{ 263 -0.932687281002860, 264 0.303963920182067, 265 1.375216503507109, 266 }, wm: 3, wn: 1, 267 }, 268 { 269 a: []float64{ 270 0.8147, 0.9134, 0.5528, 271 0.9058, 0.6324, 0.8723, 272 }, m: 2, n: 3, 273 b: []float64{ 274 0.278, 275 0.547, 276 }, bc: 1, 277 want: []float64{ 278 0.25919787248965376, 279 -0.25560256266441034, 280 0.5432324059702451, 281 }, wm: 3, wn: 1, 282 }, 283 { 284 a: []float64{ 285 0.8147, 0.9134, 0.9, 286 0.9058, 0.6324, 0.9, 287 0.1270, 0.0975, 0.1, 288 1.6, 2.8, -3.5, 289 }, m: 4, n: 3, 290 b: []float64{ 291 0.278, 292 0.547, 293 -0.958, 294 1.452, 295 }, bc: 1, 296 want: []float64{ 297 0.820970340787782, 298 -0.218604626527306, 299 -0.212938815234215, 300 }, wm: 3, wn: 1, 301 }, 302 { 303 a: []float64{ 304 0.8147, 0.9134, 0.231, -1.65, 305 0.9058, 0.6324, 0.9, 0.72, 306 0.1270, 0.0975, 0.1, 1.723, 307 1.6, 2.8, -3.5, 0.987, 308 7.231, 9.154, 1.823, 0.9, 309 }, m: 5, n: 4, 310 b: []float64{ 311 0.278, 8.635, 312 0.547, 9.125, 313 -0.958, -0.762, 314 1.452, 1.444, 315 1.999, -7.234, 316 }, bc: 2, 317 want: []float64{ 318 1.863006789511373, 44.467887791812750, 319 -1.127270935407224, -34.073794226035126, 320 -0.527926457947330, -8.032133759788573, 321 -0.248621916204897, -2.366366415805275, 322 }, wm: 4, wn: 2, 323 }, 324 { 325 // Test rank-deficient case compared with numpy. 326 // >>> import numpy as np 327 // >>> b = np.array([[-2.3181340317357653], 328 // ... [-0.7146777651358073], 329 // ... [1.8361340927945298], 330 // ... [-0.35699930593018775], 331 // ... [-1.6359508076249094]]) 332 // >>> A = np.array([[-1.7854591879711257, -0.42687285925779594, -0.12730256811265162], 333 // ... [-0.5728984211439724, -0.10093393134001777, -0.1181901192353067], 334 // ... [1.2484316018707418, 0.5646683943038734, -0.48229492403243485], 335 // ... [0.10174927665169475, -0.5805410929482445, 1.3054473231942054], 336 // ... [-1.134174808195733, -0.4732430202414438, 0.3528489486370508]]) 337 // >>> np.linalg.lstsq(A, b, rcond=None) 338 // (array([[ 1.21208422], 339 // [ 0.41541503], 340 // [-0.18320349]]), array([], dtype=float64), 2, array([2.68451480e+00, 1.52593185e+00, 6.82840229e-17])) 341 342 a: []float64{ 343 -1.7854591879711257, -0.42687285925779594, -0.12730256811265162, 344 -0.5728984211439724, -0.10093393134001777, -0.1181901192353067, 345 1.2484316018707418, 0.5646683943038734, -0.48229492403243485, 346 0.10174927665169475, -0.5805410929482445, 1.3054473231942054, 347 -1.134174808195733, -0.4732430202414438, 0.3528489486370508, 348 }, m: 5, n: 3, 349 b: []float64{ 350 -2.3181340317357653, 351 -0.7146777651358073, 352 1.8361340927945298, 353 -0.35699930593018775, 354 -1.6359508076249094, 355 }, bc: 1, 356 rcond: 1e-15, 357 want: []float64{ 358 1.2120842180372118, 359 0.4154150318658529, 360 -0.1832034870198265, 361 }, wm: 3, wn: 1, 362 }, 363 { 364 a: []float64{ 365 0, 0, 366 0, 0, 367 }, m: 2, n: 2, 368 b: []float64{ 369 3, 370 2, 371 }, bc: 1, 372 }, 373 { 374 a: []float64{ 375 0, 0, 376 0, 0, 377 0, 0, 378 }, m: 3, n: 2, 379 b: []float64{ 380 3, 381 2, 382 1, 383 }, bc: 1, 384 }, 385 { 386 a: []float64{ 387 0, 0, 0, 388 0, 0, 0, 389 }, m: 2, n: 3, 390 b: []float64{ 391 3, 392 2, 393 }, bc: 1, 394 }, 395 } { 396 a := NewDense(test.m, test.n, test.a) 397 b := NewDense(test.m, test.bc, test.b) 398 399 var want *Dense 400 if test.want != nil { 401 want = NewDense(test.wm, test.wn, test.want) 402 } 403 404 var svd SVD 405 ok := svd.Factorize(a, SVDFull) 406 if !ok { 407 t.Errorf("unexpected factorization failure for test %d", i) 408 continue 409 } 410 411 var x Dense 412 rank := svd.Rank(test.rcond) 413 if rank == 0 { 414 continue 415 } 416 svd.SolveTo(&x, b, rank) 417 if !EqualApprox(&x, want, 1e-12) { 418 t.Errorf("Solve answer mismatch. Want %v, got %v", want, x) 419 } 420 } 421 422 // Random Cases. 423 for i, test := range []struct { 424 m, n, bc int 425 rcond float64 426 }{ 427 {m: 5, n: 5, bc: 1}, 428 {m: 5, n: 10, bc: 1}, 429 {m: 10, n: 5, bc: 1}, 430 {m: 5, n: 5, bc: 7}, 431 {m: 5, n: 10, bc: 7}, 432 {m: 10, n: 5, bc: 7}, 433 {m: 5, n: 5, bc: 12}, 434 {m: 5, n: 10, bc: 12}, 435 {m: 10, n: 5, bc: 12}, 436 } { 437 m := test.m 438 n := test.n 439 bc := test.bc 440 a := NewDense(m, n, nil) 441 for i := 0; i < m; i++ { 442 for j := 0; j < n; j++ { 443 a.Set(i, j, rnd.Float64()) 444 } 445 } 446 br := m 447 b := NewDense(br, bc, nil) 448 for i := 0; i < br; i++ { 449 for j := 0; j < bc; j++ { 450 b.Set(i, j, rnd.Float64()) 451 } 452 } 453 454 var svd SVD 455 ok := svd.Factorize(a, SVDFull) 456 if !ok { 457 t.Errorf("unexpected factorization failure for test %d", i) 458 continue 459 } 460 461 var x Dense 462 rank := svd.Rank(test.rcond) 463 if rank == 0 { 464 continue 465 } 466 svd.SolveTo(&x, b, rank) 467 468 // Test that the normal equations hold. 469 // Aᵀ * A * x = Aᵀ * b 470 var tmp, lhs, rhs Dense 471 tmp.Mul(a.T(), a) 472 lhs.Mul(&tmp, &x) 473 rhs.Mul(a.T(), b) 474 if !EqualApprox(&lhs, &rhs, 1e-10) { 475 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 476 } 477 } 478 } 479 480 func TestSVDSolveVecTo(t *testing.T) { 481 t.Parallel() 482 rnd := rand.New(rand.NewSource(1)) 483 // Hand-coded cases. 484 for i, test := range []struct { 485 a []float64 486 m, n int 487 b []float64 488 rcond float64 489 want []float64 490 }{ 491 { 492 a: []float64{6}, m: 1, n: 1, 493 b: []float64{3}, 494 want: []float64{0.5}, 495 }, 496 { 497 a: []float64{ 498 1, 0, 0, 499 0, 1, 0, 500 0, 0, 1, 501 }, m: 3, n: 3, 502 b: []float64{3, 2, 1}, 503 want: []float64{3, 2, 1}, 504 }, 505 { 506 a: []float64{ 507 0.8147, 0.9134, 0.5528, 508 0.9058, 0.6324, 0.8723, 509 0.1270, 0.0975, 0.7612, 510 }, m: 3, n: 3, 511 b: []float64{0.278, 0.547, 0.958}, 512 want: []float64{-0.932687281002860, 0.303963920182067, 1.375216503507109}, 513 }, 514 { 515 a: []float64{ 516 0.8147, 0.9134, 0.5528, 517 0.9058, 0.6324, 0.8723, 518 }, m: 2, n: 3, 519 b: []float64{0.278, 0.547}, 520 want: []float64{0.25919787248965376, -0.25560256266441034, 0.5432324059702451}, 521 }, 522 { 523 a: []float64{ 524 0.8147, 0.9134, 0.9, 525 0.9058, 0.6324, 0.9, 526 0.1270, 0.0975, 0.1, 527 1.6, 2.8, -3.5, 528 }, m: 4, n: 3, 529 b: []float64{0.278, 0.547, -0.958, 1.452}, 530 want: []float64{0.820970340787782, -0.218604626527306, -0.212938815234215}, 531 }, 532 { 533 // Test rank-deficient case compared with numpy. 534 // >>> import numpy as np 535 // >>> b = np.array([[-2.3181340317357653], 536 // ... [-0.7146777651358073], 537 // ... [1.8361340927945298], 538 // ... [-0.35699930593018775], 539 // ... [-1.6359508076249094]]) 540 // >>> A = np.array([[-1.7854591879711257, -0.42687285925779594, -0.12730256811265162], 541 // ... [-0.5728984211439724, -0.10093393134001777, -0.1181901192353067], 542 // ... [1.2484316018707418, 0.5646683943038734, -0.48229492403243485], 543 // ... [0.10174927665169475, -0.5805410929482445, 1.3054473231942054], 544 // ... [-1.134174808195733, -0.4732430202414438, 0.3528489486370508]]) 545 // >>> np.linalg.lstsq(A, b, rcond=None) 546 // (array([[ 1.21208422], 547 // [ 0.41541503], 548 // [-0.18320349]]), array([], dtype=float64), 2, array([2.68451480e+00, 1.52593185e+00, 6.82840229e-17])) 549 550 a: []float64{ 551 -1.7854591879711257, -0.42687285925779594, -0.12730256811265162, 552 -0.5728984211439724, -0.10093393134001777, -0.1181901192353067, 553 1.2484316018707418, 0.5646683943038734, -0.48229492403243485, 554 0.10174927665169475, -0.5805410929482445, 1.3054473231942054, 555 -1.134174808195733, -0.4732430202414438, 0.3528489486370508, 556 }, m: 5, n: 3, 557 b: []float64{-2.3181340317357653, -0.7146777651358073, 1.8361340927945298, -0.35699930593018775, -1.6359508076249094}, 558 rcond: 1e-15, 559 want: []float64{1.2120842180372118, 0.4154150318658529, -0.1832034870198265}, 560 }, 561 { 562 a: []float64{ 563 0, 0, 564 0, 0, 565 }, m: 2, n: 2, 566 b: []float64{3, 2}, 567 }, 568 { 569 a: []float64{ 570 0, 0, 571 0, 0, 572 0, 0, 573 }, m: 3, n: 2, 574 b: []float64{3, 2, 1}, 575 }, 576 { 577 a: []float64{ 578 0, 0, 0, 579 0, 0, 0, 580 }, m: 2, n: 3, 581 b: []float64{3, 2}, 582 }, 583 } { 584 a := NewDense(test.m, test.n, test.a) 585 b := NewVecDense(len(test.b), test.b) 586 587 var want *VecDense 588 if test.want != nil { 589 want = NewVecDense(len(test.want), test.want) 590 } 591 592 var svd SVD 593 ok := svd.Factorize(a, SVDFull) 594 if !ok { 595 t.Errorf("unexpected factorization failure for test %d", i) 596 continue 597 } 598 599 var x VecDense 600 rank := svd.Rank(test.rcond) 601 if rank == 0 { 602 continue 603 } 604 svd.SolveVecTo(&x, b, rank) 605 if !EqualApprox(&x, want, 1e-12) { 606 t.Errorf("Solve answer mismatch. Want %v, got %v", want, x) 607 } 608 } 609 610 // Random Cases. 611 for i, test := range []struct { 612 m, n int 613 rcond float64 614 }{ 615 {m: 5, n: 5}, 616 {m: 5, n: 10}, 617 {m: 10, n: 5}, 618 {m: 5, n: 5}, 619 {m: 5, n: 10}, 620 {m: 10, n: 5}, 621 {m: 5, n: 5}, 622 {m: 5, n: 10}, 623 {m: 10, n: 5}, 624 } { 625 m := test.m 626 n := test.n 627 a := NewDense(m, n, nil) 628 for i := 0; i < m; i++ { 629 for j := 0; j < n; j++ { 630 a.Set(i, j, rnd.Float64()) 631 } 632 } 633 br := m 634 b := NewVecDense(br, nil) 635 for i := 0; i < br; i++ { 636 b.SetVec(i, rnd.Float64()) 637 } 638 639 var svd SVD 640 ok := svd.Factorize(a, SVDFull) 641 if !ok { 642 t.Errorf("unexpected factorization failure for test %d", i) 643 continue 644 } 645 646 var x VecDense 647 rank := svd.Rank(test.rcond) 648 if rank == 0 { 649 continue 650 } 651 svd.SolveVecTo(&x, b, rank) 652 653 // Test that the normal equations hold. 654 // Aᵀ * A * x = Aᵀ * b 655 var tmp, lhs, rhs Dense 656 tmp.Mul(a.T(), a) 657 lhs.Mul(&tmp, &x) 658 rhs.Mul(a.T(), b) 659 if !EqualApprox(&lhs, &rhs, 1e-10) { 660 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 661 } 662 } 663 }