gonum.org/v1/gonum@v0.14.0/mat/solve_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 ) 12 13 func TestSolve(t *testing.T) { 14 t.Parallel() 15 rnd := rand.New(rand.NewSource(1)) 16 // Hand-coded cases. 17 for _, test := range []struct { 18 a [][]float64 19 b [][]float64 20 ans [][]float64 21 shouldErr bool 22 }{ 23 { 24 a: [][]float64{{6}}, 25 b: [][]float64{{3}}, 26 ans: [][]float64{{0.5}}, 27 shouldErr: false, 28 }, 29 { 30 a: [][]float64{ 31 {1, 0, 0}, 32 {0, 1, 0}, 33 {0, 0, 1}, 34 }, 35 b: [][]float64{ 36 {3}, 37 {2}, 38 {1}, 39 }, 40 ans: [][]float64{ 41 {3}, 42 {2}, 43 {1}, 44 }, 45 shouldErr: false, 46 }, 47 { 48 a: [][]float64{ 49 {0.8147, 0.9134, 0.5528}, 50 {0.9058, 0.6324, 0.8723}, 51 {0.1270, 0.0975, 0.7612}, 52 }, 53 b: [][]float64{ 54 {0.278}, 55 {0.547}, 56 {0.958}, 57 }, 58 ans: [][]float64{ 59 {-0.932687281002860}, 60 {0.303963920182067}, 61 {1.375216503507109}, 62 }, 63 shouldErr: false, 64 }, 65 { 66 a: [][]float64{ 67 {0.8147, 0.9134, 0.5528}, 68 {0.9058, 0.6324, 0.8723}, 69 }, 70 b: [][]float64{ 71 {0.278}, 72 {0.547}, 73 }, 74 ans: [][]float64{ 75 {0.25919787248965376}, 76 {-0.25560256266441034}, 77 {0.5432324059702451}, 78 }, 79 shouldErr: false, 80 }, 81 { 82 a: [][]float64{ 83 {0.8147, 0.9134, 0.9}, 84 {0.9058, 0.6324, 0.9}, 85 {0.1270, 0.0975, 0.1}, 86 {1.6, 2.8, -3.5}, 87 }, 88 b: [][]float64{ 89 {0.278}, 90 {0.547}, 91 {-0.958}, 92 {1.452}, 93 }, 94 ans: [][]float64{ 95 {0.820970340787782}, 96 {-0.218604626527306}, 97 {-0.212938815234215}, 98 }, 99 shouldErr: false, 100 }, 101 { 102 a: [][]float64{ 103 {0.8147, 0.9134, 0.231, -1.65}, 104 {0.9058, 0.6324, 0.9, 0.72}, 105 {0.1270, 0.0975, 0.1, 1.723}, 106 {1.6, 2.8, -3.5, 0.987}, 107 {7.231, 9.154, 1.823, 0.9}, 108 }, 109 b: [][]float64{ 110 {0.278, 8.635}, 111 {0.547, 9.125}, 112 {-0.958, -0.762}, 113 {1.452, 1.444}, 114 {1.999, -7.234}, 115 }, 116 ans: [][]float64{ 117 {1.863006789511373, 44.467887791812750}, 118 {-1.127270935407224, -34.073794226035126}, 119 {-0.527926457947330, -8.032133759788573}, 120 {-0.248621916204897, -2.366366415805275}, 121 }, 122 shouldErr: false, 123 }, 124 { 125 a: [][]float64{ 126 {0, 0}, 127 {0, 0}, 128 }, 129 b: [][]float64{ 130 {3}, 131 {2}, 132 }, 133 ans: nil, 134 shouldErr: true, 135 }, 136 { 137 a: [][]float64{ 138 {0, 0}, 139 {0, 0}, 140 {0, 0}, 141 }, 142 b: [][]float64{ 143 {3}, 144 {2}, 145 {1}, 146 }, 147 ans: nil, 148 shouldErr: true, 149 }, 150 { 151 a: [][]float64{ 152 {0, 0, 0}, 153 {0, 0, 0}, 154 }, 155 b: [][]float64{ 156 {3}, 157 {2}, 158 }, 159 ans: nil, 160 shouldErr: true, 161 }, 162 } { 163 a := NewDense(flatten(test.a)) 164 b := NewDense(flatten(test.b)) 165 166 var ans *Dense 167 if test.ans != nil { 168 ans = NewDense(flatten(test.ans)) 169 } 170 171 var x Dense 172 err := x.Solve(a, b) 173 if err != nil { 174 if !test.shouldErr { 175 t.Errorf("Unexpected solve error: %s", err) 176 } 177 continue 178 } 179 if err == nil && test.shouldErr { 180 t.Errorf("Did not error during solve.") 181 continue 182 } 183 if !EqualApprox(&x, ans, 1e-12) { 184 t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x) 185 } 186 } 187 188 // Random Cases. 189 for _, test := range []struct { 190 m, n, bc int 191 }{ 192 {5, 5, 1}, 193 {5, 10, 1}, 194 {10, 5, 1}, 195 {5, 5, 7}, 196 {5, 10, 7}, 197 {10, 5, 7}, 198 {5, 5, 12}, 199 {5, 10, 12}, 200 {10, 5, 12}, 201 } { 202 m := test.m 203 n := test.n 204 bc := test.bc 205 a := NewDense(m, n, nil) 206 for i := 0; i < m; i++ { 207 for j := 0; j < n; j++ { 208 a.Set(i, j, rnd.Float64()) 209 } 210 } 211 br := m 212 b := NewDense(br, bc, nil) 213 for i := 0; i < br; i++ { 214 for j := 0; j < bc; j++ { 215 b.Set(i, j, rnd.Float64()) 216 } 217 } 218 var x Dense 219 err := x.Solve(a, b) 220 if err != nil { 221 t.Errorf("unexpected error from dense solve: %v", err) 222 } 223 224 // Test that the normal equations hold. 225 // Aᵀ * A * x = Aᵀ * b 226 var tmp, lhs, rhs Dense 227 tmp.Mul(a.T(), a) 228 lhs.Mul(&tmp, &x) 229 rhs.Mul(a.T(), b) 230 if !EqualApprox(&lhs, &rhs, 1e-10) { 231 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 232 } 233 } 234 235 // Use testTwoInput. 236 method := func(receiver, a, b Matrix) { 237 type Solver interface { 238 Solve(a, b Matrix) error 239 } 240 rd := receiver.(Solver) 241 _ = rd.Solve(a, b) 242 } 243 denseComparison := func(receiver, a, b *Dense) { 244 _ = receiver.Solve(a, b) 245 } 246 testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7) 247 } 248 249 func TestSolveVec(t *testing.T) { 250 t.Parallel() 251 rnd := rand.New(rand.NewSource(1)) 252 for _, test := range []struct { 253 m, n int 254 }{ 255 {5, 5}, 256 {5, 10}, 257 {10, 5}, 258 {5, 5}, 259 {5, 10}, 260 {10, 5}, 261 {5, 5}, 262 {5, 10}, 263 {10, 5}, 264 } { 265 m := test.m 266 n := test.n 267 a := NewDense(m, n, nil) 268 for i := 0; i < m; i++ { 269 for j := 0; j < n; j++ { 270 a.Set(i, j, rnd.Float64()) 271 } 272 } 273 br := m 274 b := NewVecDense(br, nil) 275 for i := 0; i < br; i++ { 276 b.SetVec(i, rnd.Float64()) 277 } 278 var x VecDense 279 err := x.SolveVec(a, b) 280 if err != nil { 281 t.Errorf("unexpected error from dense vector solve: %v", err) 282 } 283 284 // Test that the normal equations hold. 285 // Aᵀ * A * x = Aᵀ * b 286 var tmp, lhs, rhs Dense 287 tmp.Mul(a.T(), a) 288 lhs.Mul(&tmp, &x) 289 rhs.Mul(a.T(), b) 290 if !EqualApprox(&lhs, &rhs, 1e-10) { 291 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 292 } 293 } 294 295 // Use testTwoInput 296 method := func(receiver, a, b Matrix) { 297 type SolveVecer interface { 298 SolveVec(a Matrix, b Vector) error 299 } 300 rd := receiver.(SolveVecer) 301 _ = rd.SolveVec(a, b.(Vector)) 302 } 303 denseComparison := func(receiver, a, b *Dense) { 304 _ = receiver.Solve(a, b) 305 } 306 testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12) 307 }