gonum.org/v1/gonum@v0.14.0/mat/tridiag_test.go (about) 1 // Copyright ©2021 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "fmt" 9 "reflect" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/lapack/lapack64" 15 ) 16 17 func TestNewTridiag(t *testing.T) { 18 for i, test := range []struct { 19 n int 20 dl, d, du []float64 21 panics bool 22 want *Tridiag 23 dense *Dense 24 }{ 25 { 26 n: 1, 27 dl: nil, 28 d: []float64{1.2}, 29 du: nil, 30 panics: false, 31 want: &Tridiag{ 32 mat: lapack64.Tridiagonal{ 33 N: 1, 34 DL: nil, 35 D: []float64{1.2}, 36 DU: nil, 37 }, 38 }, 39 dense: NewDense(1, 1, []float64{1.2}), 40 }, 41 { 42 n: 1, 43 dl: []float64{}, 44 d: []float64{1.2}, 45 du: []float64{}, 46 panics: false, 47 want: &Tridiag{ 48 mat: lapack64.Tridiagonal{ 49 N: 1, 50 DL: []float64{}, 51 D: []float64{1.2}, 52 DU: []float64{}, 53 }, 54 }, 55 dense: NewDense(1, 1, []float64{1.2}), 56 }, 57 { 58 n: 4, 59 dl: []float64{1.2, 2.3, 3.4}, 60 d: []float64{4.5, 5.6, 6.7, 7.8}, 61 du: []float64{8.9, 9.0, 0.1}, 62 panics: false, 63 want: &Tridiag{ 64 mat: lapack64.Tridiagonal{ 65 N: 4, 66 DL: []float64{1.2, 2.3, 3.4}, 67 D: []float64{4.5, 5.6, 6.7, 7.8}, 68 DU: []float64{8.9, 9.0, 0.1}, 69 }, 70 }, 71 dense: NewDense(4, 4, []float64{ 72 4.5, 8.9, 0, 0, 73 1.2, 5.6, 9.0, 0, 74 0, 2.3, 6.7, 0.1, 75 0, 0, 3.4, 7.8, 76 }), 77 }, 78 { 79 n: 4, 80 dl: nil, 81 d: nil, 82 du: nil, 83 panics: false, 84 want: &Tridiag{ 85 mat: lapack64.Tridiagonal{ 86 N: 4, 87 DL: []float64{0, 0, 0}, 88 D: []float64{0, 0, 0, 0}, 89 DU: []float64{0, 0, 0}, 90 }, 91 }, 92 dense: NewDense(4, 4, nil), 93 }, 94 { 95 n: -1, 96 panics: true, 97 }, 98 { 99 n: 0, 100 panics: true, 101 }, 102 { 103 n: 1, 104 dl: []float64{1.2}, 105 d: nil, 106 du: nil, 107 panics: true, 108 }, 109 { 110 n: 1, 111 dl: nil, 112 d: []float64{1.2, 2.3}, 113 du: nil, 114 panics: true, 115 }, 116 { 117 n: 1, 118 dl: []float64{}, 119 d: nil, 120 du: []float64{}, 121 panics: true, 122 }, 123 { 124 n: 4, 125 dl: []float64{1.2}, 126 d: nil, 127 du: nil, 128 panics: true, 129 }, 130 { 131 n: 4, 132 dl: []float64{1.2, 2.3, 3.4}, 133 d: []float64{4.5, 5.6, 6.7, 7.8, 1.2}, 134 du: []float64{8.9, 9.0, 0.1}, 135 panics: true, 136 }, 137 } { 138 var a *Tridiag 139 panicked, msg := panics(func() { 140 a = NewTridiag(test.n, test.dl, test.d, test.du) 141 }) 142 if panicked { 143 if !test.panics { 144 t.Errorf("Case %d: unexpected panic: %s", i, msg) 145 } 146 continue 147 } 148 if test.panics { 149 t.Errorf("Case %d: expected panic", i) 150 continue 151 } 152 153 r, c := a.Dims() 154 if r != test.n { 155 t.Errorf("Case %d: unexpected number of rows: got=%d want=%d", i, r, test.n) 156 } 157 if c != test.n { 158 t.Errorf("Case %d: unexpected number of columns: got=%d want=%d", i, c, test.n) 159 } 160 161 kl, ku := a.Bandwidth() 162 if kl != 1 || ku != 1 { 163 t.Errorf("Case %d: unexpected bandwidth: got=%d,%d want=1,1", i, kl, ku) 164 } 165 166 if !reflect.DeepEqual(a, test.want) { 167 t.Errorf("Case %d: unexpected value via reflect: got=%v, want=%v", i, a, test.want) 168 } 169 if !Equal(a, test.want) { 170 t.Errorf("Case %d: unexpected value via mat.Equal: got=%v, want=%v", i, a, test.want) 171 } 172 if !Equal(a, test.dense) { 173 t.Errorf("Case %d: unexpected value via mat.Equal(Tridiag,Dense):\ngot:\n% v\nwant:\n% v", i, Formatted(a), Formatted(test.dense)) 174 } 175 } 176 } 177 178 func TestTridiagAtSet(t *testing.T) { 179 t.Parallel() 180 for _, n := range []int{1, 2, 3, 4, 7, 10} { 181 tri, ref := newTestTridiag(n) 182 183 name := fmt.Sprintf("Case n=%v", n) 184 185 // Check At explicitly with all valid indices. 186 for i := 0; i < n; i++ { 187 for j := 0; j < n; j++ { 188 if tri.At(i, j) != ref.At(i, j) { 189 t.Errorf("%v: unexpected value for At(%d,%d): got %v, want %v", 190 name, i, j, tri.At(i, j), ref.At(i, j)) 191 } 192 } 193 } 194 // Check At via a call to Equal. 195 if !Equal(tri, ref) { 196 t.Errorf("%v: unexpected value:\ngot: % v\nwant:% v", 197 name, Formatted(tri, Prefix(" ")), Formatted(ref, Prefix(" "))) 198 } 199 200 // Check At out of bounds. 201 for _, i := range []int{-1, n, n + 1} { 202 for j := 0; j < n; j++ { 203 panicked, message := panics(func() { tri.At(i, j) }) 204 if !panicked || message != ErrRowAccess.Error() { 205 t.Errorf("%v: expected panic for invalid row access at (%d,%d)", name, i, j) 206 } 207 } 208 } 209 for _, j := range []int{-1, n, n + 1} { 210 for i := 0; i < n; i++ { 211 panicked, message := panics(func() { tri.At(i, j) }) 212 if !panicked || message != ErrColAccess.Error() { 213 t.Errorf("%v: expected panic for invalid column access at (%d,%d)", name, i, j) 214 } 215 } 216 } 217 218 // Check SetBand out of bounds. 219 for _, i := range []int{-1, n, n + 1} { 220 for j := 0; j < n; j++ { 221 panicked, message := panics(func() { tri.SetBand(i, j, 1.2) }) 222 if !panicked || message != ErrRowAccess.Error() { 223 t.Errorf("%v: expected panic for invalid row access at (%d,%d)", name, i, j) 224 } 225 } 226 } 227 for _, j := range []int{-1, n, n + 1} { 228 for i := 0; i < n; i++ { 229 panicked, message := panics(func() { tri.SetBand(i, j, 1.2) }) 230 if !panicked || message != ErrColAccess.Error() { 231 t.Errorf("%v: expected panic for invalid column access at (%d,%d)", name, i, j) 232 } 233 } 234 } 235 for i := 0; i < n; i++ { 236 for j := 0; j <= i-2; j++ { 237 panicked, message := panics(func() { tri.SetBand(i, j, 1.2) }) 238 if !panicked || message != ErrBandSet.Error() { 239 t.Errorf("%v: expected panic for invalid access at (%d,%d)", name, i, j) 240 } 241 } 242 for j := i + 2; j < n; j++ { 243 panicked, message := panics(func() { tri.SetBand(i, j, 1.2) }) 244 if !panicked || message != ErrBandSet.Error() { 245 t.Errorf("%v: expected panic for invalid access at (%d,%d)", name, i, j) 246 } 247 } 248 } 249 250 // Check SetBand within bandwidth. 251 for i := 0; i < n; i++ { 252 for j := max(0, i-1); j <= min(i+1, n-1); j++ { 253 want := float64(i*n + j + 100) 254 tri.SetBand(i, j, want) 255 if got := tri.At(i, j); got != want { 256 t.Errorf("%v: unexpected value at (%d,%d) after SetBand: got %v, want %v", name, i, j, got, want) 257 } 258 } 259 } 260 } 261 } 262 263 func newTestTridiag(n int) (*Tridiag, *Dense) { 264 var dl, d, du []float64 265 d = make([]float64, n) 266 if n > 1 { 267 dl = make([]float64, n-1) 268 du = make([]float64, n-1) 269 } 270 for i := range d { 271 d[i] = float64(i*n + i + 1) 272 } 273 for j := range dl { 274 i := j + 1 275 dl[j] = float64(i*n + j + 1) 276 } 277 for i := range du { 278 j := i + 1 279 du[i] = float64(i*n + j + 1) 280 } 281 dense := make([]float64, n*n) 282 for i := 0; i < n; i++ { 283 for j := max(0, i-1); j <= min(i+1, n-1); j++ { 284 dense[i*n+j] = float64(i*n + j + 1) 285 } 286 } 287 return NewTridiag(n, dl, d, du), NewDense(n, n, dense) 288 } 289 290 func TestTridiagReset(t *testing.T) { 291 t.Parallel() 292 for _, n := range []int{1, 2, 3, 4, 7, 10} { 293 a, _ := newTestTridiag(n) 294 if a.IsEmpty() { 295 t.Errorf("Case n=%d: matrix is empty", n) 296 } 297 a.Reset() 298 if !a.IsEmpty() { 299 t.Errorf("Case n=%d: matrix is not empty after Reset", n) 300 } 301 } 302 } 303 304 func TestTridiagDiagView(t *testing.T) { 305 t.Parallel() 306 for _, n := range []int{1, 2, 3, 4, 7, 10} { 307 a, _ := newTestTridiag(n) 308 testDiagView(t, n, a) 309 } 310 } 311 312 func TestTridiagZero(t *testing.T) { 313 t.Parallel() 314 for _, n := range []int{1, 2, 3, 4, 7, 10} { 315 a, _ := newTestTridiag(n) 316 a.Zero() 317 for i := 0; i < n; i++ { 318 for j := 0; j < n; j++ { 319 if a.At(i, j) != 0 { 320 t.Errorf("Case n=%d: unexpected non-zero at (%d,%d): got %f", n, i, j, a.At(i, j)) 321 } 322 } 323 } 324 } 325 } 326 327 func TestTridiagSolveTo(t *testing.T) { 328 t.Parallel() 329 330 const tol = 1e-13 331 332 rnd := rand.New(rand.NewSource(1)) 333 random := func(n int) []float64 { 334 d := make([]float64, n) 335 for i := range d { 336 d[i] = rnd.NormFloat64() 337 } 338 return d 339 } 340 341 for _, n := range []int{1, 2, 3, 4, 7, 10} { 342 a := NewTridiag(n, random(n-1), random(n), random(n-1)) 343 var aDense Dense 344 aDense.CloneFrom(a) 345 for _, trans := range []bool{false, true} { 346 for _, nrhs := range []int{1, 2, 5} { 347 const ( 348 denseB = iota 349 rawB 350 basicB 351 ) 352 for _, bType := range []int{denseB, rawB, basicB} { 353 const ( 354 emptyDst = iota 355 shapedDst 356 bIsDst 357 ) 358 for _, dstType := range []int{emptyDst, shapedDst, bIsDst} { 359 if dstType == bIsDst && bType != denseB { 360 continue 361 } 362 363 var b Matrix 364 switch bType { 365 case denseB: 366 b = NewDense(n, nrhs, random(n*nrhs)) 367 case rawB: 368 b = &rawMatrix{asBasicMatrix(NewDense(n, nrhs, random(n*nrhs)))} 369 case basicB: 370 b = asBasicMatrix(NewDense(n, nrhs, random(n*nrhs))) 371 default: 372 panic("bad bType") 373 } 374 375 var dst *Dense 376 switch dstType { 377 case emptyDst: 378 dst = new(Dense) 379 case shapedDst: 380 dst = NewDense(n, nrhs, random(n*nrhs)) 381 case bIsDst: 382 dst = b.(*Dense) 383 default: 384 panic("bad dstType") 385 } 386 387 name := fmt.Sprintf("n=%d,nrhs=%d,trans=%t,dstType=%d,bType=%d", n, nrhs, trans, dstType, bType) 388 389 var want Dense 390 var err error 391 if !trans { 392 err = want.Solve(&aDense, b) 393 } else { 394 err = want.Solve(aDense.T(), b) 395 } 396 if err != nil { 397 t.Fatalf("%v: unexpected failure when computing reference solution: %v", name, err) 398 } 399 400 err = a.SolveTo(dst, trans, b) 401 if err != nil { 402 t.Fatalf("%v: unexpected failure from Tridiag.SolveTo: %v", name, err) 403 } 404 405 var diff Dense 406 diff.Sub(dst, &want) 407 if resid := Norm(&diff, 1); resid > tol*float64(n) { 408 t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(n)) 409 } 410 } 411 } 412 } 413 } 414 } 415 } 416 417 func TestTridiagSolveVecTo(t *testing.T) { 418 t.Parallel() 419 420 const tol = 1e-13 421 422 rnd := rand.New(rand.NewSource(1)) 423 random := func(n int) []float64 { 424 d := make([]float64, n) 425 for i := range d { 426 d[i] = rnd.NormFloat64() 427 } 428 return d 429 } 430 431 for _, n := range []int{1, 2, 3, 4, 7, 10} { 432 a := NewTridiag(n, random(n-1), random(n), random(n-1)) 433 var aDense Dense 434 aDense.CloneFrom(a) 435 for _, trans := range []bool{false, true} { 436 const ( 437 denseB = iota 438 rawB 439 basicB 440 ) 441 for _, bType := range []int{denseB, rawB, basicB} { 442 const ( 443 emptyDst = iota 444 shapedDst 445 bIsDst 446 ) 447 for _, dstType := range []int{emptyDst, shapedDst, bIsDst} { 448 if dstType == bIsDst && bType != denseB { 449 continue 450 } 451 452 var b Vector 453 switch bType { 454 case denseB: 455 b = NewVecDense(n, random(n)) 456 case rawB: 457 b = &rawVector{asBasicVector(NewVecDense(n, random(n)))} 458 case basicB: 459 b = asBasicVector(NewVecDense(n, random(n))) 460 default: 461 panic("bad bType") 462 } 463 464 var dst *VecDense 465 switch dstType { 466 case emptyDst: 467 dst = new(VecDense) 468 case shapedDst: 469 dst = NewVecDense(n, random(n)) 470 case bIsDst: 471 dst = b.(*VecDense) 472 default: 473 panic("bad dstType") 474 } 475 476 name := fmt.Sprintf("n=%d,trans=%t,dstType=%d,bType=%d", n, trans, dstType, bType) 477 478 var want VecDense 479 var err error 480 if !trans { 481 err = want.SolveVec(&aDense, b) 482 } else { 483 err = want.SolveVec(aDense.T(), b) 484 } 485 if err != nil { 486 t.Fatalf("%v: unexpected failure when computing reference solution: %v", name, err) 487 } 488 489 err = a.SolveVecTo(dst, trans, b) 490 if err != nil { 491 t.Fatalf("%v: unexpected failure from Tridiag.SolveTo: %v", name, err) 492 } 493 494 var diff Dense 495 diff.Sub(dst, &want) 496 if resid := Norm(&diff, 1); resid > tol*float64(n) { 497 t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(n)) 498 } 499 } 500 } 501 } 502 } 503 }