github.com/wzzhu/tensor@v0.9.24/dense_svd_test.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/pkg/errors" 8 "gonum.org/v1/gonum/mat" 9 ) 10 11 // tests for SVD adapted from Gonum's SVD tests. 12 // Gonum's licence is listed at https://gonum.org/v1/gonum/license 13 14 var svdtestsThin = []struct { 15 data []float64 16 shape Shape 17 18 correctSData []float64 19 correctSShape Shape 20 21 correctUData []float64 22 correctUShape Shape 23 24 correctVData []float64 25 correctVShape Shape 26 }{ 27 { 28 []float64{2, 4, 1, 3, 0, 0, 0, 0}, Shape{4, 2}, 29 []float64{5.464985704219041, 0.365966190626258}, Shape{2}, 30 []float64{-0.8174155604703632, -0.5760484367663209, -0.5760484367663209, 0.8174155604703633, 0, 0, 0, 0}, Shape{4, 2}, 31 []float64{-0.4045535848337571, -0.9145142956773044, -0.9145142956773044, 0.4045535848337571}, Shape{2, 2}, 32 }, 33 34 { 35 []float64{1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3}, Shape{3, 11}, 36 []float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628}, Shape{3}, 37 []float64{-0.5224167862273765, 0.7864430360363114, 0.3295270133658976, -0.5739526766688285, -0.03852203026050301, -0.8179818935216693, -0.6306021141833781, -0.6164603833618163, 0.4715056408282468}, Shape{3, 3}, 38 []float64{ 39 -0.08123293141915189, 0.08528085505260324, -0.013165501690885152, 40 -0.05423546426886932, 0.1102707844980355, 0.622210623111631, 41 0, 0, 0, 42 -0.0245733326078166, 0.510179651760153, 0.25596360803140994, 43 0, 0, 0, 44 0, 0, 0, 45 -0.026997467150282436, -0.024989929445430496, -0.6353761248025164, 46 0, 0, 0, 47 -0.029662131661052707, -0.3999088672621176, 0.3662470150802212, 48 -0.9798839760830571, 0.11328174160898856, -0.047702613241813366, 49 -0.16755466189153964, -0.7395268089170608, 0.08395240366704032}, Shape{11, 3}, 50 }, 51 } 52 53 var svdtestsFull = []Shape{ 54 {5, 5}, 55 {5, 3}, 56 {3, 5}, 57 {150, 150}, 58 {200, 150}, 59 {150, 200}, 60 } 61 62 // calculate corrects 63 func calcSigma(s, T *Dense, shape Shape) (sigma *Dense, err error) { 64 sigma = New(Of(Float64), WithShape(shape...)) 65 for i := 0; i < MinInt(shape[0], shape[1]); i++ { 66 var idx int 67 if idx, err = Ltoi(sigma.Shape(), sigma.Strides(), i, i); err != nil { 68 return 69 } 70 sigma.Float64s()[idx] = s.Float64s()[i] 71 } 72 73 return 74 } 75 76 // test svd by doing the SVD, then calculating the corrects 77 func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) { 78 var sigma, reconstructed *Dense 79 80 if !allClose(T2.Data(), T.Data(), closeenoughf64) { 81 return errors.Errorf("A call to SVD modified the underlying data! %s Test %d", t, i) 82 } 83 84 shape := T2.Shape() 85 if t == "thin" { 86 shape = Shape{MinInt(shape[0], shape[1]), MinInt(shape[0], shape[1])} 87 } 88 89 if sigma, err = calcSigma(s, T, shape); err != nil { 90 return 91 } 92 v.T() 93 94 if reconstructed, err = u.MatMul(sigma, UseSafe()); err != nil { 95 return 96 } 97 if reconstructed, err = reconstructed.MatMul(v, UseSafe()); err != nil { 98 return 99 } 100 101 if !allClose(T2.Data(), reconstructed.Data(), closeenoughf64) { 102 return errors.Errorf("Expected reconstructed to be %v. Got %v instead", T2.Data(), reconstructed.Data()) 103 } 104 return nil 105 } 106 107 func ExampleDense_SVD() { 108 T := New( 109 WithShape(4, 5), 110 WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}), 111 ) 112 _, u, _, _ := T.SVD(true, true) 113 uT := u.Clone().(*Dense) 114 uT.T() 115 eye, err := u.MatMul(uT) 116 fmt.Println(eye) 117 fmt.Println(err) 118 119 // Output: 120 // ⎡1 0 0 0⎤ 121 // ⎢0 1 0 0⎥ 122 // ⎢0 0 1 0⎥ 123 // ⎣0 0 0 1⎦ 124 // 125 // <nil> 126 } 127 128 func TestDense_SVD(t *testing.T) { 129 var T, T2, s, u, v *Dense 130 var err error 131 132 // gonum specific thin special cases 133 for i, stts := range svdtestsThin { 134 T = New(WithShape(stts.shape...), WithBacking(stts.data)) 135 T2 = T.Clone().(*Dense) 136 137 if s, u, v, err = T.SVD(true, false); err != nil { 138 t.Error(err) 139 continue 140 } 141 142 if !allClose(T2.Data(), T.Data(), closeenoughf64) { 143 t.Errorf("A call to SVD modified the underlying data! Thin Test %d", i) 144 continue 145 } 146 147 if !allClose(stts.correctSData, s.Data(), closeenoughf64) { 148 t.Errorf("Expected s = %v. Got %v instead", stts.correctSData, s.Data()) 149 } 150 151 if !allClose(stts.correctUData, u.Data(), closeenoughf64) { 152 t.Errorf("Expected u = %v. Got %v instead", stts.correctUData, u.Data()) 153 } 154 155 if !allClose(stts.correctVData, v.Data(), closeenoughf64) { 156 t.Errorf("Expected v = %v. Got %v instead", stts.correctVData, v.Data()) 157 } 158 } 159 // standard tests 160 for i, stfs := range svdtestsFull { 161 T = New(WithShape(stfs...), WithBacking(Random(Float64, stfs.TotalSize()))) 162 T2 = T.Clone().(*Dense) 163 164 // full 165 if s, u, v, err = T.SVD(true, true); err != nil { 166 t.Error(err) 167 fmt.Println(err) 168 continue 169 } 170 if err = testSVD(T, T2, s, u, v, "full", i); err != nil { 171 t.Error(err) 172 fmt.Println(err) 173 continue 174 } 175 // thin 176 if s, u, v, err = T.SVD(true, false); err != nil { 177 t.Error(err) 178 continue 179 } 180 181 if err = testSVD(T, T2, s, u, v, "thin", i); err != nil { 182 t.Error(err) 183 continue 184 } 185 186 // none 187 if s, u, v, err = T.SVD(false, false); err != nil { 188 t.Error(err) 189 continue 190 } 191 192 var svd mat.SVD 193 var m *mat.Dense 194 if m, err = ToMat64(T); err != nil { 195 t.Error(err) 196 continue 197 } 198 199 if !svd.Factorize(m, mat.SVDFull) { 200 t.Errorf("Unable to factorise %v", m) 201 continue 202 } 203 204 if !allClose(s.Data(), svd.Values(nil), closeenoughf64) { 205 t.Errorf("Singular value mismatch between Full and None decomposition. Expected %v. Got %v instead", svd.Values(nil), s.Data()) 206 } 207 208 } 209 // this is illogical 210 T = New(Of(Float64), WithShape(2, 2)) 211 if _, _, _, err = T.SVD(false, true); err == nil { 212 t.Errorf("Expected an error!") 213 } 214 215 // if you do this, it is bad and you should feel bad 216 T = New(Of(Float64), WithShape(2, 3, 4)) 217 if _, _, _, err = T.SVD(true, true); err == nil { 218 t.Errorf("Expecetd an error: cannot SVD() a Tensor > 2 dimensions") 219 } 220 221 T = New(Of(Float64), WithShape(2)) 222 if _, _, _, err = T.SVD(true, true); err == nil { 223 t.Errorf("Expecetd an error: cannot SVD() a Tensor < 2 dimensions") 224 } 225 }