gonum.org/v1/gonum@v0.14.0/mat/gsvd_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/floats" 14 "gonum.org/v1/gonum/floats/scalar" 15 ) 16 17 func TestGSVD(t *testing.T) { 18 t.Parallel() 19 20 const tol = 1e-10 21 for _, test := range []struct { 22 m, p, n int 23 }{ 24 {5, 3, 5}, 25 {5, 3, 3}, 26 {3, 3, 5}, 27 {5, 5, 5}, 28 {5, 5, 3}, 29 {3, 5, 5}, 30 {150, 150, 150}, 31 {200, 150, 150}, 32 {150, 150, 200}, 33 {150, 200, 150}, 34 {200, 200, 150}, 35 {150, 200, 200}, 36 } { 37 m := test.m 38 p := test.p 39 n := test.n 40 t.Run(fmt.Sprintf("%v", test), func(t *testing.T) { 41 t.Parallel() 42 43 rnd := rand.New(rand.NewSource(1)) 44 for trial := 0; trial < 10; trial++ { 45 a := NewDense(m, n, nil) 46 for i := range a.mat.Data { 47 a.mat.Data[i] = rnd.NormFloat64() 48 } 49 aCopy := DenseCopyOf(a) 50 51 b := NewDense(p, n, nil) 52 for i := range b.mat.Data { 53 b.mat.Data[i] = rnd.NormFloat64() 54 } 55 bCopy := DenseCopyOf(b) 56 57 // Test Full decomposition. 58 var gsvd GSVD 59 ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ) 60 if !ok { 61 t.Errorf("GSVD factorization failed") 62 } 63 if !Equal(a, aCopy) { 64 t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ") 65 } 66 if !Equal(b, bCopy) { 67 t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ") 68 } 69 c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd) 70 var ansU, ansV, d1R, d2R Dense 71 ansU.Product(u.T(), a, q) 72 ansV.Product(v.T(), b, q) 73 d1R.Mul(sigma1, zeroR) 74 d2R.Mul(sigma2, zeroR) 75 if !EqualApprox(&ansU, &d1R, tol) { 76 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nUᵀ * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f", 77 Formatted(&ansU), Formatted(&d1R)) 78 } 79 if !EqualApprox(&ansV, &d2R, tol) { 80 t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nVᵀ * B *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f", 81 Formatted(&d2R), Formatted(&ansV)) 82 } 83 84 // Check C^2 + S^2 = I. 85 for i := range c { 86 d := c[i]*c[i] + s[i]*s[i] 87 if !scalar.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) { 88 t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d) 89 } 90 } 91 92 // Test None decomposition. 93 ok = gsvd.Factorize(a, b, GSVDNone) 94 if !ok { 95 t.Errorf("GSVD factorization failed") 96 } 97 if !Equal(a, aCopy) { 98 t.Errorf("A changed during call to GSVD with GSVDNone") 99 } 100 if !Equal(b, bCopy) { 101 t.Errorf("B changed during call to GSVD with GSVDNone") 102 } 103 cNone := gsvd.ValuesA(nil) 104 if !floats.EqualApprox(c, cNone, tol) { 105 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition") 106 } 107 sNone := gsvd.ValuesB(nil) 108 if !floats.EqualApprox(s, sNone, tol) { 109 t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition") 110 } 111 } 112 }) 113 114 } 115 } 116 117 func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) { 118 s1 = &Dense{} 119 s2 = &Dense{} 120 zR = &Dense{} 121 u = &Dense{} 122 v = &Dense{} 123 q = &Dense{} 124 gsvd.SigmaATo(s1) 125 gsvd.SigmaBTo(s2) 126 gsvd.ZeroRTo(zR) 127 gsvd.UTo(u) 128 gsvd.VTo(v) 129 gsvd.QTo(q) 130 c = gsvd.ValuesA(nil) 131 s = gsvd.ValuesB(nil) 132 return c, s, s1, s2, zR, u, v, q 133 }