github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/triband_test.go (about) 1 // Copyright ©2018 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "fmt" 9 "reflect" 10 "testing" 11 12 "github.com/jingcheng-WU/gonum/blas" 13 "github.com/jingcheng-WU/gonum/blas/blas64" 14 ) 15 16 func TestNewTriBand(t *testing.T) { 17 t.Parallel() 18 for cas, test := range []struct { 19 data []float64 20 n, k int 21 kind TriKind 22 mat *TriBandDense 23 dense *Dense 24 }{ 25 { 26 data: []float64{1, 2, 3}, 27 n: 3, k: 0, 28 kind: Upper, 29 mat: &TriBandDense{ 30 mat: blas64.TriangularBand{ 31 Diag: blas.NonUnit, 32 Uplo: blas.Upper, 33 N: 3, K: 0, 34 Data: []float64{1, 2, 3}, 35 Stride: 1, 36 }, 37 }, 38 dense: NewDense(3, 3, []float64{ 39 1, 0, 0, 40 0, 2, 0, 41 0, 0, 3, 42 }), 43 }, 44 { 45 data: []float64{ 46 1, 2, 47 3, 4, 48 5, 6, 49 7, 8, 50 9, 10, 51 11, -1, 52 }, 53 n: 6, k: 1, 54 kind: Upper, 55 mat: &TriBandDense{ 56 mat: blas64.TriangularBand{ 57 Diag: blas.NonUnit, 58 Uplo: blas.Upper, 59 N: 6, K: 1, 60 Data: []float64{ 61 1, 2, 62 3, 4, 63 5, 6, 64 7, 8, 65 9, 10, 66 11, -1, 67 }, 68 Stride: 2, 69 }, 70 }, 71 dense: NewDense(6, 6, []float64{ 72 1, 2, 0, 0, 0, 0, 73 0, 3, 4, 0, 0, 0, 74 0, 0, 5, 6, 0, 0, 75 0, 0, 0, 7, 8, 0, 76 0, 0, 0, 0, 9, 10, 77 0, 0, 0, 0, 0, 11, 78 }), 79 }, 80 { 81 data: []float64{ 82 1, 2, 3, 83 4, 5, 6, 84 7, 8, 9, 85 10, 11, 12, 86 13, 14, -1, 87 15, -1, -1, 88 }, 89 n: 6, k: 2, 90 kind: Upper, 91 mat: &TriBandDense{ 92 mat: blas64.TriangularBand{ 93 Diag: blas.NonUnit, 94 Uplo: blas.Upper, 95 N: 6, K: 2, 96 Data: []float64{ 97 1, 2, 3, 98 4, 5, 6, 99 7, 8, 9, 100 10, 11, 12, 101 13, 14, -1, 102 15, -1, -1, 103 }, 104 Stride: 3, 105 }, 106 }, 107 dense: NewDense(6, 6, []float64{ 108 1, 2, 3, 0, 0, 0, 109 0, 4, 5, 6, 0, 0, 110 0, 0, 7, 8, 9, 0, 111 0, 0, 0, 10, 11, 12, 112 0, 0, 0, 0, 13, 14, 113 0, 0, 0, 0, 0, 15, 114 }), 115 }, 116 { 117 data: []float64{ 118 -1, 1, 119 2, 3, 120 4, 5, 121 6, 7, 122 8, 9, 123 10, 11, 124 }, 125 n: 6, k: 1, 126 kind: Lower, 127 mat: &TriBandDense{ 128 mat: blas64.TriangularBand{ 129 Diag: blas.NonUnit, 130 Uplo: blas.Lower, 131 N: 6, K: 1, 132 Data: []float64{ 133 -1, 1, 134 2, 3, 135 4, 5, 136 6, 7, 137 8, 9, 138 10, 11, 139 }, 140 Stride: 2, 141 }, 142 }, 143 dense: NewDense(6, 6, []float64{ 144 1, 0, 0, 0, 0, 0, 145 2, 3, 0, 0, 0, 0, 146 0, 4, 5, 0, 0, 0, 147 0, 0, 6, 7, 0, 0, 148 0, 0, 0, 8, 9, 0, 149 0, 0, 0, 0, 10, 11, 150 }), 151 }, 152 { 153 data: []float64{ 154 -1, -1, 1, 155 -1, 2, 3, 156 4, 5, 6, 157 7, 8, 9, 158 10, 11, 12, 159 13, 14, 15, 160 }, 161 n: 6, k: 2, 162 kind: Lower, 163 mat: &TriBandDense{ 164 mat: blas64.TriangularBand{ 165 Diag: blas.NonUnit, 166 Uplo: blas.Lower, 167 N: 6, K: 2, 168 Data: []float64{ 169 -1, -1, 1, 170 -1, 2, 3, 171 4, 5, 6, 172 7, 8, 9, 173 10, 11, 12, 174 13, 14, 15, 175 }, 176 Stride: 3, 177 }, 178 }, 179 dense: NewDense(6, 6, []float64{ 180 1, 0, 0, 0, 0, 0, 181 2, 3, 0, 0, 0, 0, 182 4, 5, 6, 0, 0, 0, 183 0, 7, 8, 9, 0, 0, 184 0, 0, 10, 11, 12, 0, 185 0, 0, 0, 13, 14, 15, 186 }), 187 }, 188 } { 189 triBand := NewTriBandDense(test.n, test.k, test.kind, test.data) 190 r, c := triBand.Dims() 191 n, k, kind := triBand.TriBand() 192 if n != test.n { 193 t.Errorf("unexpected triband size for test %d: got: %d want: %d", cas, n, test.n) 194 } 195 if k != test.k { 196 t.Errorf("unexpected triband bandwidth for test %d: got: %d want: %d", cas, k, test.k) 197 } 198 if kind != test.kind { 199 t.Errorf("unexpected triband bandwidth for test %v: got: %v want: %v", cas, kind, test.kind) 200 } 201 if r != n { 202 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", cas, r, n) 203 } 204 if c != n { 205 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", cas, c, n) 206 } 207 if !reflect.DeepEqual(triBand, test.mat) { 208 t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", cas, triBand, test.mat) 209 } 210 if !Equal(triBand, test.mat) { 211 t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", cas, triBand, test.mat) 212 } 213 if !Equal(triBand, test.dense) { 214 t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", cas, Formatted(triBand), Formatted(test.dense)) 215 } 216 } 217 } 218 219 func TestTriBandAtSetUpper(t *testing.T) { 220 t.Parallel() 221 for _, kind := range []TriKind{Upper, Lower} { 222 var band *TriBandDense 223 var data []float64 224 if kind { 225 // 1 2 3 0 0 0 226 // 0 4 5 6 0 0 227 // 0 0 7 8 9 0 228 // 0 0 0 10 11 12 229 // 0 0 0 0 13 14 230 // 0 0 0 0 0 15 231 data = []float64{ 232 1, 2, 3, 233 4, 5, 6, 234 7, 8, 9, 235 10, 11, 12, 236 13, 14, -1, 237 15, -1, -1, 238 } 239 band = NewTriBandDense(6, 2, kind, data) 240 } else { 241 // 1 0 0 0 0 0 242 // 2 3 0 0 0 0 243 // 4 5 6 0 0 0 244 // 0 7 8 9 0 0 245 // 0 0 10 11 12 0 246 // 0 0 0 13 14 15 247 data = []float64{ 248 -1, -1, 1, 249 -1, 2, 3, 250 4, 5, 6, 251 7, 8, 9, 252 10, 11, 12, 253 13, 14, 15, 254 } 255 band = NewTriBandDense(6, 2, kind, data) 256 } 257 258 rows, cols := band.Dims() 259 260 // Check At out of bounds. 261 for _, row := range []int{-1, rows, rows + 1} { 262 panicked, message := panics(func() { band.At(row, 0) }) 263 if !panicked || message != ErrRowAccess.Error() { 264 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 265 } 266 } 267 for _, col := range []int{-1, cols, cols + 1} { 268 panicked, message := panics(func() { band.At(0, col) }) 269 if !panicked || message != ErrColAccess.Error() { 270 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 271 } 272 } 273 274 // Check Set out of bounds 275 // First, check outside the matrix bounds. 276 for _, row := range []int{-1, rows, rows + 1} { 277 panicked, message := panics(func() { band.SetTriBand(row, 0, 1.2) }) 278 if !panicked || message != ErrRowAccess.Error() { 279 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 280 } 281 } 282 for _, col := range []int{-1, cols, cols + 1} { 283 panicked, message := panics(func() { band.SetTriBand(0, col, 1.2) }) 284 if !panicked || message != ErrColAccess.Error() { 285 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 286 } 287 } 288 // Next, check outside the Triangular bounds. 289 for _, s := range []struct{ r, c int }{ 290 {3, 2}, 291 } { 292 if kind == Lower { 293 s.r, s.c = s.c, s.r 294 } 295 panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) }) 296 if !panicked || message != ErrTriangleSet.Error() { 297 t.Errorf("expected panic for invalid triangular access N=%d, r=%d c=%d", cols, s.r, s.c) 298 } 299 } 300 // Finally, check inside the triangle, but outside the band. 301 for _, s := range []struct{ r, c int }{ 302 {1, 5}, 303 } { 304 if kind == Lower { 305 s.r, s.c = s.c, s.r 306 } 307 panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) }) 308 if !panicked || message != ErrBandSet.Error() { 309 t.Errorf("expected panic for invalid triangular access N=%d, r=%d c=%d", cols, s.r, s.c) 310 } 311 } 312 313 // Test that At and Set work correctly. 314 offset := 100.0 315 dataCopy := make([]float64, len(data)) 316 copy(dataCopy, data) 317 for i := 0; i < rows; i++ { 318 for j := 0; j < rows; j++ { 319 v := band.At(i, j) 320 if v != 0 { 321 band.SetTriBand(i, j, v+offset) 322 } 323 } 324 } 325 for i, v := range dataCopy { 326 if v == -1 { 327 if data[i] != -1 { 328 t.Errorf("Set changed unexpected entry. Want %v, got %v", -1, data[i]) 329 } 330 } else { 331 if v != data[i]-offset { 332 t.Errorf("Set incorrectly changed for %v. got %v, want %v", v, data[i], v+offset) 333 } 334 } 335 } 336 } 337 } 338 339 func TestTriBandDenseZero(t *testing.T) { 340 t.Parallel() 341 // Elements that equal 1 should be set to zero, elements that equal -1 342 // should remain unchanged. 343 for _, test := range []*TriBandDense{ 344 { 345 mat: blas64.TriangularBand{ 346 Uplo: blas.Upper, 347 N: 6, 348 K: 2, 349 Stride: 5, 350 Data: []float64{ 351 1, 1, 1, -1, -1, 352 1, 1, 1, -1, -1, 353 1, 1, 1, -1, -1, 354 1, 1, 1, -1, -1, 355 1, 1, -1, -1, -1, 356 1, -1, -1, -1, -1, 357 }, 358 }, 359 }, 360 { 361 mat: blas64.TriangularBand{ 362 Uplo: blas.Lower, 363 N: 6, 364 K: 2, 365 Stride: 5, 366 Data: []float64{ 367 -1, -1, 1, -1, -1, 368 -1, 1, 1, -1, -1, 369 1, 1, 1, -1, -1, 370 1, 1, 1, -1, -1, 371 1, 1, 1, -1, -1, 372 1, 1, 1, -1, -1, 373 }, 374 }, 375 }, 376 } { 377 dataCopy := make([]float64, len(test.mat.Data)) 378 copy(dataCopy, test.mat.Data) 379 test.Zero() 380 for i, v := range test.mat.Data { 381 if dataCopy[i] != -1 && v != 0 { 382 t.Errorf("Matrix not zeroed in bounds") 383 } 384 if dataCopy[i] == -1 && v != -1 { 385 t.Errorf("Matrix zeroed out of bounds") 386 } 387 } 388 } 389 } 390 391 func TestTriBandDiagView(t *testing.T) { 392 t.Parallel() 393 for cas, test := range []*TriBandDense{ 394 NewTriBandDense(1, 0, Upper, []float64{1}), 395 NewTriBandDense(4, 0, Upper, []float64{1, 2, 3, 4}), 396 NewTriBandDense(6, 2, Upper, []float64{ 397 1, 2, 3, 398 4, 5, 6, 399 7, 8, 9, 400 10, 11, 12, 401 13, 14, -1, 402 15, -1, -1, 403 }), 404 NewTriBandDense(1, 0, Lower, []float64{1}), 405 NewTriBandDense(4, 0, Lower, []float64{1, 2, 3, 4}), 406 NewTriBandDense(6, 2, Lower, []float64{ 407 -1, -1, 1, 408 -1, 2, 3, 409 4, 5, 6, 410 7, 8, 9, 411 10, 11, 12, 412 13, 14, 15, 413 }), 414 } { 415 testDiagView(t, cas, test) 416 } 417 } 418 419 func TestTriBandDenseSolveTo(t *testing.T) { 420 t.Parallel() 421 422 const tol = 1e-15 423 424 for tc, test := range []struct { 425 a *TriBandDense 426 b *Dense 427 }{ 428 { 429 a: NewTriBandDense(5, 2, Upper, []float64{ 430 -0.34, -0.49, -0.51, 431 -0.25, -0.5, 1.03, 432 -1.1, 0.3, -0.82, 433 1.69, 0.69, -2.22, 434 -0.62, 1.22, -0.85, 435 }), 436 b: NewDense(5, 2, []float64{ 437 0.44, 1.34, 438 0.07, -1.45, 439 -0.32, -0.88, 440 -0.09, -0.15, 441 -1.17, -0.19, 442 }), 443 }, 444 { 445 a: NewTriBandDense(5, 2, Lower, []float64{ 446 0, 0, -0.34, 447 0, -0.49, -0.25, 448 -0.51, -0.5, -1.1, 449 1.03, 0.3, 1.69, 450 -0.82, 0.69, -0.62, 451 }), 452 b: NewDense(5, 2, []float64{ 453 0.44, 1.34, 454 0.07, -1.45, 455 -0.32, -0.88, 456 -0.09, -0.15, 457 -1.17, -0.19, 458 }), 459 }, 460 } { 461 a := test.a 462 for _, trans := range []bool{false, true} { 463 for _, dstSameAsB := range []bool{false, true} { 464 name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB) 465 466 n, nrhs := test.b.Dims() 467 var dst Dense 468 var err error 469 if dstSameAsB { 470 dst = *NewDense(n, nrhs, nil) 471 dst.Copy(test.b) 472 err = a.SolveTo(&dst, trans, &dst) 473 } else { 474 tmp := NewDense(n, nrhs, nil) 475 tmp.Copy(test.b) 476 err = a.SolveTo(&dst, trans, asBasicMatrix(tmp)) 477 } 478 479 if err != nil { 480 t.Fatalf("%v: unexpected error from SolveTo", name) 481 } 482 483 var resid Dense 484 if trans { 485 resid.Mul(a.T(), &dst) 486 } else { 487 resid.Mul(a, &dst) 488 } 489 resid.Sub(&resid, test.b) 490 diff := Norm(&resid, 1) 491 if diff > tol { 492 t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol) 493 } 494 } 495 } 496 } 497 } 498 499 func TestTriBandDenseSolveVecTo(t *testing.T) { 500 t.Parallel() 501 502 const tol = 1e-15 503 504 for tc, test := range []struct { 505 a *TriBandDense 506 b *VecDense 507 }{ 508 { 509 a: NewTriBandDense(5, 2, Upper, []float64{ 510 -0.34, -0.49, -0.51, 511 -0.25, -0.5, 1.03, 512 -1.1, 0.3, -0.82, 513 1.69, 0.69, -2.22, 514 -0.62, 1.22, -0.85, 515 }), 516 b: NewVecDense(5, []float64{ 517 0.44, 518 0.07, 519 -0.32, 520 -0.09, 521 -1.17, 522 }), 523 }, 524 { 525 a: NewTriBandDense(5, 2, Lower, []float64{ 526 0, 0, -0.34, 527 0, -0.49, -0.25, 528 -0.51, -0.5, -1.1, 529 1.03, 0.3, 1.69, 530 -0.82, 0.69, -0.62, 531 }), 532 b: NewVecDense(5, []float64{ 533 0.44, 534 0.07, 535 -0.32, 536 -0.09, 537 -1.17, 538 }), 539 }, 540 } { 541 a := test.a 542 for _, trans := range []bool{false, true} { 543 for _, dstSameAsB := range []bool{false, true} { 544 name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB) 545 546 n, _ := test.b.Dims() 547 var dst VecDense 548 var err error 549 if dstSameAsB { 550 dst = *NewVecDense(n, nil) 551 dst.CopyVec(test.b) 552 err = a.SolveVecTo(&dst, trans, &dst) 553 } else { 554 tmp := NewVecDense(n, nil) 555 tmp.CopyVec(test.b) 556 err = a.SolveVecTo(&dst, trans, asBasicVector(tmp)) 557 } 558 559 if err != nil { 560 t.Fatalf("%v: unexpected error from SolveVecTo", name) 561 } 562 563 var resid VecDense 564 if trans { 565 resid.MulVec(a.T(), &dst) 566 } else { 567 resid.MulVec(a, &dst) 568 } 569 resid.SubVec(&resid, test.b) 570 diff := Norm(&resid, 1) 571 if diff > tol { 572 t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol) 573 } 574 } 575 } 576 } 577 }