github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/list_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //nolint:deadcode,unused 6 package mat 7 8 import ( 9 "fmt" 10 "math" 11 "reflect" 12 "testing" 13 14 "golang.org/x/exp/rand" 15 16 "github.com/jingcheng-WU/gonum/blas" 17 "github.com/jingcheng-WU/gonum/blas/blas64" 18 "github.com/jingcheng-WU/gonum/floats" 19 "github.com/jingcheng-WU/gonum/floats/scalar" 20 ) 21 22 // legalSizeSameRectangular returns whether the two matrices have the same rectangular shape. 23 func legalSizeSameRectangular(ar, ac, br, bc int) bool { 24 if ar != br { 25 return false 26 } 27 if ac != bc { 28 return false 29 } 30 return true 31 } 32 33 // legalSizeSameSquare returns whether the two matrices have the same square shape. 34 func legalSizeSameSquare(ar, ac, br, bc int) bool { 35 if ar != br { 36 return false 37 } 38 if ac != bc { 39 return false 40 } 41 if ar != ac { 42 return false 43 } 44 return true 45 } 46 47 // legalSizeSameHeight returns whether the two matrices have the same number of rows. 48 func legalSizeSameHeight(ar, _, br, _ int) bool { 49 return ar == br 50 } 51 52 // legalSizeSameWidth returns whether the two matrices have the same number of columns. 53 func legalSizeSameWidth(_, ac, _, bc int) bool { 54 return ac == bc 55 } 56 57 // legalSizeSolve returns whether the two matrices can be used in a linear solve. 58 func legalSizeSolve(ar, ac, br, bc int) bool { 59 return ar == br 60 } 61 62 // legalSizeSameVec returns whether the two matrices are column vectors. 63 func legalSizeVector(_, ac, _, bc int) bool { 64 return ac == 1 && bc == 1 65 } 66 67 // legalSizeSameVec returns whether the two matrices are column vectors of the 68 // same dimension. 69 func legalSizeSameVec(ar, ac, br, bc int) bool { 70 return ac == 1 && bc == 1 && ar == br 71 } 72 73 // isAnySize returns true for all matrix sizes. 74 func isAnySize(ar, ac int) bool { 75 return true 76 } 77 78 // isAnySize2 returns true for all matrix sizes. 79 func isAnySize2(ar, ac, br, bc int) bool { 80 return true 81 } 82 83 // isAnyColumnVector returns true for any column vector sizes. 84 func isAnyColumnVector(ar, ac int) bool { 85 return ac == 1 86 } 87 88 // isSquare returns whether the input matrix is square. 89 func isSquare(r, c int) bool { 90 return r == c 91 } 92 93 // sameAnswerFloat returns whether the two inputs are both NaN or are equal. 94 func sameAnswerFloat(a, b interface{}) bool { 95 if math.IsNaN(a.(float64)) { 96 return math.IsNaN(b.(float64)) 97 } 98 return a.(float64) == b.(float64) 99 } 100 101 // sameAnswerFloatApproxTol returns a function that determines whether its two 102 // inputs are both NaN or within tol of each other. 103 func sameAnswerFloatApproxTol(tol float64) func(a, b interface{}) bool { 104 return func(a, b interface{}) bool { 105 if math.IsNaN(a.(float64)) { 106 return math.IsNaN(b.(float64)) 107 } 108 return scalar.EqualWithinAbsOrRel(a.(float64), b.(float64), tol, tol) 109 } 110 } 111 112 func sameAnswerF64SliceOfSlice(a, b interface{}) bool { 113 for i, v := range a.([][]float64) { 114 if same := floats.Same(v, b.([][]float64)[i]); !same { 115 return false 116 } 117 } 118 return true 119 } 120 121 // sameAnswerBool returns whether the two inputs have the same value. 122 func sameAnswerBool(a, b interface{}) bool { 123 return a.(bool) == b.(bool) 124 } 125 126 // isAnyType returns true for all Matrix types. 127 func isAnyType(Matrix) bool { 128 return true 129 } 130 131 // legalTypesAll returns true for all Matrix types. 132 func legalTypesAll(a, b Matrix) bool { 133 return true 134 } 135 136 // legalTypeSym returns whether a is a Symmetric. 137 func legalTypeSym(a Matrix) bool { 138 _, ok := a.(Symmetric) 139 return ok 140 } 141 142 // legalTypeTri returns whether a is a Triangular. 143 func legalTypeTri(a Matrix) bool { 144 _, ok := a.(Triangular) 145 return ok 146 } 147 148 // legalTypeTriLower returns whether a is a Triangular with kind == Lower. 149 func legalTypeTriLower(a Matrix) bool { 150 t, ok := a.(Triangular) 151 if !ok { 152 return false 153 } 154 _, kind := t.Triangle() 155 return kind == Lower 156 } 157 158 // legalTypeTriUpper returns whether a is a Triangular with kind == Upper. 159 func legalTypeTriUpper(a Matrix) bool { 160 t, ok := a.(Triangular) 161 if !ok { 162 return false 163 } 164 _, kind := t.Triangle() 165 return kind == Upper 166 } 167 168 // legalTypesSym returns whether both input arguments are Symmetric. 169 func legalTypesSym(a, b Matrix) bool { 170 if _, ok := a.(Symmetric); !ok { 171 return false 172 } 173 if _, ok := b.(Symmetric); !ok { 174 return false 175 } 176 return true 177 } 178 179 // legalTypeVector returns whether v is a Vector. 180 func legalTypeVector(v Matrix) bool { 181 _, ok := v.(Vector) 182 return ok 183 } 184 185 // legalTypeVec returns whether v is a *VecDense. 186 func legalTypeVecDense(v Matrix) bool { 187 _, ok := v.(*VecDense) 188 return ok 189 } 190 191 // legalTypesVectorVector returns whether both inputs are Vector 192 func legalTypesVectorVector(a, b Matrix) bool { 193 if _, ok := a.(Vector); !ok { 194 return false 195 } 196 if _, ok := b.(Vector); !ok { 197 return false 198 } 199 return true 200 } 201 202 // legalTypesVecDenseVecDense returns whether both inputs are *VecDense. 203 func legalTypesVecDenseVecDense(a, b Matrix) bool { 204 if _, ok := a.(*VecDense); !ok { 205 return false 206 } 207 if _, ok := b.(*VecDense); !ok { 208 return false 209 } 210 return true 211 } 212 213 // legalTypesMatrixVector returns whether the first input is an arbitrary Matrix 214 // and the second input is a Vector. 215 func legalTypesMatrixVector(a, b Matrix) bool { 216 _, ok := b.(Vector) 217 return ok 218 } 219 220 // legalTypesMatrixVecDense returns whether the first input is an arbitrary Matrix 221 // and the second input is a *VecDense. 222 func legalTypesMatrixVecDense(a, b Matrix) bool { 223 _, ok := b.(*VecDense) 224 return ok 225 } 226 227 // legalDims returns whether {m,n} is a valid dimension of the given matrix type. 228 func legalDims(a Matrix, m, n int) bool { 229 switch t := a.(type) { 230 default: 231 panic("legal dims type not coded") 232 case Untransposer: 233 return legalDims(t.Untranspose(), n, m) 234 case *Dense, *basicMatrix, *BandDense, *basicBanded: 235 if m < 0 || n < 0 { 236 return false 237 } 238 return true 239 case *SymDense, *TriDense, *basicSymmetric, *basicTriangular, 240 *SymBandDense, *basicSymBanded, *TriBandDense, *basicTriBanded, 241 *basicDiagonal, *DiagDense: 242 if m < 0 || n < 0 || m != n { 243 return false 244 } 245 return true 246 case *VecDense, *basicVector: 247 if m < 0 || n < 0 { 248 return false 249 } 250 return n == 1 251 } 252 } 253 254 // returnAs returns the matrix a with the type of t. Used for making a concrete 255 // type and changing to the basic form. 256 func returnAs(a, t Matrix) Matrix { 257 switch mat := a.(type) { 258 default: 259 panic("unknown type for a") 260 case *Dense: 261 switch t.(type) { 262 default: 263 panic("bad type") 264 case *Dense: 265 return mat 266 case *basicMatrix: 267 return asBasicMatrix(mat) 268 } 269 case *SymDense: 270 switch t.(type) { 271 default: 272 panic("bad type") 273 case *SymDense: 274 return mat 275 case *basicSymmetric: 276 return asBasicSymmetric(mat) 277 } 278 case *TriDense: 279 switch t.(type) { 280 default: 281 panic("bad type") 282 case *TriDense: 283 return mat 284 case *basicTriangular: 285 return asBasicTriangular(mat) 286 } 287 case *BandDense: 288 switch t.(type) { 289 default: 290 panic("bad type") 291 case *BandDense: 292 return mat 293 case *basicBanded: 294 return asBasicBanded(mat) 295 } 296 case *SymBandDense: 297 switch t.(type) { 298 default: 299 panic("bad type") 300 case *SymBandDense: 301 return mat 302 case *basicSymBanded: 303 return asBasicSymBanded(mat) 304 } 305 case *TriBandDense: 306 switch t.(type) { 307 default: 308 panic("bad type") 309 case *TriBandDense: 310 return mat 311 case *basicTriBanded: 312 return asBasicTriBanded(mat) 313 } 314 case *DiagDense: 315 switch t.(type) { 316 default: 317 panic("bad type") 318 case *DiagDense: 319 return mat 320 case *basicDiagonal: 321 return asBasicDiagonal(mat) 322 } 323 } 324 } 325 326 // retranspose returns the matrix m inside an Untransposer of the type 327 // of a. 328 func retranspose(a, m Matrix) Matrix { 329 switch a.(type) { 330 case TransposeTriBand: 331 return TransposeTriBand{m.(TriBanded)} 332 case TransposeBand: 333 return TransposeBand{m.(Banded)} 334 case TransposeTri: 335 return TransposeTri{m.(Triangular)} 336 case Transpose: 337 return Transpose{m} 338 case Untransposer: 339 panic("unknown transposer type") 340 default: 341 panic("a is not an untransposer") 342 } 343 } 344 345 // makeRandOf returns a new randomly filled m×n matrix of the underlying matrix type. 346 func makeRandOf(a Matrix, m, n int, src rand.Source) Matrix { 347 rnd := rand.New(src) 348 var rMatrix Matrix 349 switch t := a.(type) { 350 default: 351 panic("unknown type for make rand of") 352 case Untransposer: 353 rMatrix = retranspose(a, makeRandOf(t.Untranspose(), n, m, src)) 354 case *Dense, *basicMatrix: 355 var mat = &Dense{} 356 if m != 0 && n != 0 { 357 mat = NewDense(m, n, nil) 358 } 359 for i := 0; i < m; i++ { 360 for j := 0; j < n; j++ { 361 mat.Set(i, j, rnd.NormFloat64()) 362 } 363 } 364 rMatrix = returnAs(mat, t) 365 case *VecDense: 366 if m == 0 && n == 0 { 367 return &VecDense{} 368 } 369 if n != 1 { 370 panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n)) 371 } 372 length := m 373 inc := 1 374 if t.mat.Inc != 0 { 375 inc = t.mat.Inc 376 } 377 mat := &VecDense{ 378 mat: blas64.Vector{ 379 N: length, 380 Inc: inc, 381 Data: make([]float64, inc*(length-1)+1), 382 }, 383 } 384 for i := 0; i < length; i++ { 385 mat.SetVec(i, rnd.NormFloat64()) 386 } 387 return mat 388 case *basicVector: 389 if n != 1 { 390 panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n)) 391 } 392 if m == 0 { 393 return &basicVector{} 394 } 395 mat := NewVecDense(m, nil) 396 for i := 0; i < m; i++ { 397 mat.SetVec(i, rnd.NormFloat64()) 398 } 399 return asBasicVector(mat) 400 case *SymDense, *basicSymmetric: 401 if m != n { 402 panic("bad size") 403 } 404 mat := &SymDense{} 405 if n != 0 { 406 mat = NewSymDense(n, nil) 407 } 408 for i := 0; i < m; i++ { 409 for j := i; j < n; j++ { 410 mat.SetSym(i, j, rnd.NormFloat64()) 411 } 412 } 413 rMatrix = returnAs(mat, t) 414 case *TriDense, *basicTriangular: 415 if m != n { 416 panic("bad size") 417 } 418 419 // This is necessary because we are making 420 // a triangle from the zero value, which 421 // always returns upper as true. 422 var triKind TriKind 423 switch t := t.(type) { 424 case *TriDense: 425 triKind = t.triKind() 426 case *basicTriangular: 427 triKind = (*TriDense)(t).triKind() 428 } 429 430 if n == 0 { 431 uplo := blas.Upper 432 if triKind == Lower { 433 uplo = blas.Lower 434 } 435 return returnAs(&TriDense{mat: blas64.Triangular{Uplo: uplo}}, t) 436 } 437 438 mat := NewTriDense(n, triKind, nil) 439 if triKind == Upper { 440 for i := 0; i < m; i++ { 441 for j := i; j < n; j++ { 442 mat.SetTri(i, j, rnd.NormFloat64()) 443 } 444 } 445 } else { 446 for i := 0; i < m; i++ { 447 for j := 0; j <= i; j++ { 448 mat.SetTri(i, j, rnd.NormFloat64()) 449 } 450 } 451 } 452 rMatrix = returnAs(mat, t) 453 case *BandDense, *basicBanded: 454 var kl, ku int 455 switch t := t.(type) { 456 case *BandDense: 457 kl = t.mat.KL 458 ku = t.mat.KU 459 case *basicBanded: 460 ku = (*BandDense)(t).mat.KU 461 kl = (*BandDense)(t).mat.KL 462 } 463 ku = min(ku, n-1) 464 kl = min(kl, m-1) 465 data := make([]float64, min(m, n+kl)*(kl+ku+1)) 466 for i := range data { 467 data[i] = rnd.NormFloat64() 468 } 469 mat := NewBandDense(m, n, kl, ku, data) 470 rMatrix = returnAs(mat, t) 471 case *SymBandDense, *basicSymBanded: 472 if m != n { 473 panic("bad size") 474 } 475 var k int 476 switch t := t.(type) { 477 case *SymBandDense: 478 k = t.mat.K 479 case *basicSymBanded: 480 k = (*SymBandDense)(t).mat.K 481 } 482 k = min(k, m-1) // Special case for small sizes. 483 data := make([]float64, m*(k+1)) 484 for i := range data { 485 data[i] = rnd.NormFloat64() 486 } 487 mat := NewSymBandDense(n, k, data) 488 rMatrix = returnAs(mat, t) 489 case *TriBandDense, *basicTriBanded: 490 if m != n { 491 panic("bad size") 492 } 493 var k int 494 var triKind TriKind 495 switch t := t.(type) { 496 case *TriBandDense: 497 k = t.mat.K 498 triKind = t.triKind() 499 case *basicTriBanded: 500 k = (*TriBandDense)(t).mat.K 501 triKind = (*TriBandDense)(t).triKind() 502 } 503 k = min(k, m-1) // Special case for small sizes. 504 data := make([]float64, m*(k+1)) 505 for i := range data { 506 data[i] = rnd.NormFloat64() 507 } 508 mat := NewTriBandDense(n, k, triKind, data) 509 rMatrix = returnAs(mat, t) 510 case *DiagDense, *basicDiagonal: 511 if m != n { 512 panic("bad size") 513 } 514 var inc int 515 switch t := t.(type) { 516 case *DiagDense: 517 inc = t.mat.Inc 518 case *basicDiagonal: 519 inc = (*DiagDense)(t).mat.Inc 520 } 521 if inc == 0 { 522 inc = 1 523 } 524 mat := &DiagDense{ 525 mat: blas64.Vector{ 526 N: n, 527 Inc: inc, 528 Data: make([]float64, inc*(n-1)+1), 529 }, 530 } 531 for i := 0; i < n; i++ { 532 mat.SetDiag(i, rnd.Float64()) 533 } 534 rMatrix = returnAs(mat, t) 535 } 536 if mr, mc := rMatrix.Dims(); mr != m || mc != n { 537 panic(fmt.Sprintf("makeRandOf for %T returns wrong size: %d×%d != %d×%d", a, m, n, mr, mc)) 538 } 539 return rMatrix 540 } 541 542 // makeNaNOf returns a new m×n matrix of the underlying matrix type filled with NaN values. 543 func makeNaNOf(a Matrix, m, n int) Matrix { 544 var rMatrix Matrix 545 switch t := a.(type) { 546 default: 547 panic("unknown type for makeNaNOf") 548 case Untransposer: 549 rMatrix = retranspose(a, makeNaNOf(t.Untranspose(), n, m)) 550 case *Dense, *basicMatrix: 551 var mat = &Dense{} 552 if m != 0 && n != 0 { 553 mat = NewDense(m, n, nil) 554 } 555 for i := 0; i < m; i++ { 556 for j := 0; j < n; j++ { 557 mat.Set(i, j, math.NaN()) 558 } 559 } 560 rMatrix = returnAs(mat, t) 561 case *VecDense: 562 if m == 0 && n == 0 { 563 return &VecDense{} 564 } 565 if n != 1 { 566 panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n)) 567 } 568 length := m 569 inc := 1 570 if t.mat.Inc != 0 { 571 inc = t.mat.Inc 572 } 573 mat := &VecDense{ 574 mat: blas64.Vector{ 575 N: length, 576 Inc: inc, 577 Data: make([]float64, inc*(length-1)+1), 578 }, 579 } 580 for i := 0; i < length; i++ { 581 mat.SetVec(i, math.NaN()) 582 } 583 return mat 584 case *basicVector: 585 if n != 1 { 586 panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n)) 587 } 588 if m == 0 { 589 return &basicVector{} 590 } 591 mat := NewVecDense(m, nil) 592 for i := 0; i < m; i++ { 593 mat.SetVec(i, math.NaN()) 594 } 595 return asBasicVector(mat) 596 case *SymDense, *basicSymmetric: 597 if m != n { 598 panic("bad size") 599 } 600 mat := &SymDense{} 601 if n != 0 { 602 mat = NewSymDense(n, nil) 603 } 604 for i := 0; i < m; i++ { 605 for j := i; j < n; j++ { 606 mat.SetSym(i, j, math.NaN()) 607 } 608 } 609 rMatrix = returnAs(mat, t) 610 case *TriDense, *basicTriangular: 611 if m != n { 612 panic("bad size") 613 } 614 615 // This is necessary because we are making 616 // a triangle from the zero value, which 617 // always returns upper as true. 618 var triKind TriKind 619 switch t := t.(type) { 620 case *TriDense: 621 triKind = t.triKind() 622 case *basicTriangular: 623 triKind = (*TriDense)(t).triKind() 624 } 625 626 if n == 0 { 627 uplo := blas.Upper 628 if triKind == Lower { 629 uplo = blas.Lower 630 } 631 return returnAs(&TriDense{mat: blas64.Triangular{Uplo: uplo}}, t) 632 } 633 634 mat := NewTriDense(n, triKind, nil) 635 if triKind == Upper { 636 for i := 0; i < m; i++ { 637 for j := i; j < n; j++ { 638 mat.SetTri(i, j, math.NaN()) 639 } 640 } 641 } else { 642 for i := 0; i < m; i++ { 643 for j := 0; j <= i; j++ { 644 mat.SetTri(i, j, math.NaN()) 645 } 646 } 647 } 648 rMatrix = returnAs(mat, t) 649 case *BandDense, *basicBanded: 650 var kl, ku int 651 switch t := t.(type) { 652 case *BandDense: 653 kl = t.mat.KL 654 ku = t.mat.KU 655 case *basicBanded: 656 ku = (*BandDense)(t).mat.KU 657 kl = (*BandDense)(t).mat.KL 658 } 659 ku = min(ku, n-1) 660 kl = min(kl, m-1) 661 data := make([]float64, min(m, n+kl)*(kl+ku+1)) 662 for i := range data { 663 data[i] = math.NaN() 664 } 665 mat := NewBandDense(m, n, kl, ku, data) 666 rMatrix = returnAs(mat, t) 667 case *SymBandDense, *basicSymBanded: 668 if m != n { 669 panic("bad size") 670 } 671 var k int 672 switch t := t.(type) { 673 case *SymBandDense: 674 k = t.mat.K 675 case *basicSymBanded: 676 k = (*SymBandDense)(t).mat.K 677 } 678 k = min(k, m-1) // Special case for small sizes. 679 data := make([]float64, m*(k+1)) 680 for i := range data { 681 data[i] = math.NaN() 682 } 683 mat := NewSymBandDense(n, k, data) 684 rMatrix = returnAs(mat, t) 685 case *TriBandDense, *basicTriBanded: 686 if m != n { 687 panic("bad size") 688 } 689 var k int 690 var triKind TriKind 691 switch t := t.(type) { 692 case *TriBandDense: 693 k = t.mat.K 694 triKind = t.triKind() 695 case *basicTriBanded: 696 k = (*TriBandDense)(t).mat.K 697 triKind = (*TriBandDense)(t).triKind() 698 } 699 k = min(k, m-1) // Special case for small sizes. 700 data := make([]float64, m*(k+1)) 701 for i := range data { 702 data[i] = math.NaN() 703 } 704 mat := NewTriBandDense(n, k, triKind, data) 705 rMatrix = returnAs(mat, t) 706 case *DiagDense, *basicDiagonal: 707 if m != n { 708 panic("bad size") 709 } 710 var inc int 711 switch t := t.(type) { 712 case *DiagDense: 713 inc = t.mat.Inc 714 case *basicDiagonal: 715 inc = (*DiagDense)(t).mat.Inc 716 } 717 if inc == 0 { 718 inc = 1 719 } 720 mat := &DiagDense{ 721 mat: blas64.Vector{ 722 N: n, 723 Inc: inc, 724 Data: make([]float64, inc*(n-1)+1), 725 }, 726 } 727 for i := 0; i < n; i++ { 728 mat.SetDiag(i, math.NaN()) 729 } 730 rMatrix = returnAs(mat, t) 731 } 732 if mr, mc := rMatrix.Dims(); mr != m || mc != n { 733 panic(fmt.Sprintf("makeNaNOf for %T returns wrong size: %d×%d != %d×%d", a, m, n, mr, mc)) 734 } 735 return rMatrix 736 } 737 738 // makeCopyOf returns a copy of the matrix. 739 func makeCopyOf(a Matrix) Matrix { 740 switch t := a.(type) { 741 default: 742 panic("unknown type in makeCopyOf") 743 case Untransposer: 744 return retranspose(a, makeCopyOf(t.Untranspose())) 745 case *Dense, *basicMatrix: 746 var m Dense 747 m.CloneFrom(a) 748 return returnAs(&m, t) 749 case *SymDense, *basicSymmetric: 750 n := t.(Symmetric).Symmetric() 751 m := NewSymDense(n, nil) 752 m.CopySym(t.(Symmetric)) 753 return returnAs(m, t) 754 case *TriDense, *basicTriangular: 755 n, upper := t.(Triangular).Triangle() 756 m := NewTriDense(n, upper, nil) 757 if upper { 758 for i := 0; i < n; i++ { 759 for j := i; j < n; j++ { 760 m.SetTri(i, j, t.At(i, j)) 761 } 762 } 763 } else { 764 for i := 0; i < n; i++ { 765 for j := 0; j <= i; j++ { 766 m.SetTri(i, j, t.At(i, j)) 767 } 768 } 769 } 770 return returnAs(m, t) 771 case *BandDense, *basicBanded: 772 var band *BandDense 773 switch s := t.(type) { 774 case *BandDense: 775 band = s 776 case *basicBanded: 777 band = (*BandDense)(s) 778 } 779 m := &BandDense{ 780 mat: blas64.Band{ 781 Rows: band.mat.Rows, 782 Cols: band.mat.Cols, 783 KL: band.mat.KL, 784 KU: band.mat.KU, 785 Data: make([]float64, len(band.mat.Data)), 786 Stride: band.mat.Stride, 787 }, 788 } 789 copy(m.mat.Data, band.mat.Data) 790 return returnAs(m, t) 791 case *SymBandDense, *basicSymBanded: 792 var sym *SymBandDense 793 switch s := t.(type) { 794 case *SymBandDense: 795 sym = s 796 case *basicSymBanded: 797 sym = (*SymBandDense)(s) 798 } 799 m := &SymBandDense{ 800 mat: blas64.SymmetricBand{ 801 Uplo: blas.Upper, 802 N: sym.mat.N, 803 K: sym.mat.K, 804 Data: make([]float64, len(sym.mat.Data)), 805 Stride: sym.mat.Stride, 806 }, 807 } 808 copy(m.mat.Data, sym.mat.Data) 809 return returnAs(m, t) 810 case *TriBandDense, *basicTriBanded: 811 var tri *TriBandDense 812 switch s := t.(type) { 813 case *TriBandDense: 814 tri = s 815 case *basicTriBanded: 816 tri = (*TriBandDense)(s) 817 } 818 m := &TriBandDense{ 819 mat: blas64.TriangularBand{ 820 Uplo: tri.mat.Uplo, 821 Diag: tri.mat.Diag, 822 N: tri.mat.N, 823 K: tri.mat.K, 824 Data: make([]float64, len(tri.mat.Data)), 825 Stride: tri.mat.Stride, 826 }, 827 } 828 copy(m.mat.Data, tri.mat.Data) 829 return returnAs(m, t) 830 case *VecDense: 831 var m VecDense 832 m.CloneFromVec(t) 833 return &m 834 case *basicVector: 835 var m VecDense 836 m.CloneFromVec(t) 837 return asBasicVector(&m) 838 case *DiagDense, *basicDiagonal: 839 var diag *DiagDense 840 switch s := t.(type) { 841 case *DiagDense: 842 diag = s 843 case *basicDiagonal: 844 diag = (*DiagDense)(s) 845 } 846 d := &DiagDense{ 847 mat: blas64.Vector{N: diag.mat.N, Inc: diag.mat.Inc, Data: make([]float64, len(diag.mat.Data))}, 848 } 849 copy(d.mat.Data, diag.mat.Data) 850 return returnAs(d, t) 851 } 852 } 853 854 // sameType returns true if a and b have the same underlying type. 855 func sameType(a, b Matrix) bool { 856 return reflect.ValueOf(a).Type() == reflect.ValueOf(b).Type() 857 } 858 859 // maybeSame returns true if the two matrices could be represented by the same 860 // pointer. 861 func maybeSame(receiver, a Matrix) bool { 862 rr, rc := receiver.Dims() 863 u, trans := a.(Untransposer) 864 if trans { 865 a = u.Untranspose() 866 } 867 if !sameType(receiver, a) { 868 return false 869 } 870 ar, ac := a.Dims() 871 if rr != ar || rc != ac { 872 return false 873 } 874 if _, ok := a.(Triangular); ok { 875 // They are both triangular types. The TriType needs to match 876 _, aKind := a.(Triangular).Triangle() 877 _, rKind := receiver.(Triangular).Triangle() 878 if aKind != rKind { 879 return false 880 } 881 } 882 return true 883 } 884 885 // equalApprox returns whether the elements of a and b are the same to within 886 // the tolerance. If ignoreNaN is true the test is relaxed such that NaN == NaN. 887 func equalApprox(a, b Matrix, tol float64, ignoreNaN bool) bool { 888 ar, ac := a.Dims() 889 br, bc := b.Dims() 890 if ar != br { 891 return false 892 } 893 if ac != bc { 894 return false 895 } 896 for i := 0; i < ar; i++ { 897 for j := 0; j < ac; j++ { 898 if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), tol, tol) { 899 if ignoreNaN && math.IsNaN(a.At(i, j)) && math.IsNaN(b.At(i, j)) { 900 continue 901 } 902 return false 903 } 904 } 905 } 906 return true 907 } 908 909 // equal returns true if the matrices have equal entries. 910 func equal(a, b Matrix) bool { 911 ar, ac := a.Dims() 912 br, bc := b.Dims() 913 if ar != br { 914 return false 915 } 916 if ac != bc { 917 return false 918 } 919 for i := 0; i < ar; i++ { 920 for j := 0; j < ac; j++ { 921 if a.At(i, j) != b.At(i, j) { 922 return false 923 } 924 } 925 } 926 return true 927 } 928 929 // isDiagonal returns whether a is a diagonal matrix. 930 func isDiagonal(a Matrix) bool { 931 r, c := a.Dims() 932 for i := 0; i < r; i++ { 933 for j := 0; j < c; j++ { 934 if a.At(i, j) != 0 && i != j { 935 return false 936 } 937 } 938 } 939 return true 940 } 941 942 // equalDiagonal returns whether a and b are equal on the diagonal. 943 func equalDiagonal(a, b Matrix) bool { 944 ar, ac := a.Dims() 945 br, bc := a.Dims() 946 if min(ar, ac) != min(br, bc) { 947 return false 948 } 949 for i := 0; i < min(ar, ac); i++ { 950 if a.At(i, i) != b.At(i, i) { 951 return false 952 } 953 } 954 return true 955 } 956 957 // underlyingData extracts the underlying data of the matrix a. 958 func underlyingData(a Matrix) []float64 { 959 switch t := a.(type) { 960 default: 961 panic("matrix type not implemented for extracting underlying data") 962 case Untransposer: 963 return underlyingData(t.Untranspose()) 964 case *Dense: 965 return t.mat.Data 966 case *SymDense: 967 return t.mat.Data 968 case *TriDense: 969 return t.mat.Data 970 case *VecDense: 971 return t.mat.Data 972 } 973 } 974 975 // testMatrices is a list of matrix types to test. 976 // This test relies on the fact that the implementations of Triangle do not 977 // corrupt the value of Uplo when they are empty. This test will fail 978 // if that changes (and some mechanism will need to be used to force the 979 // correct TriKind to be read). 980 var testMatrices = []Matrix{ 981 &Dense{}, 982 &basicMatrix{}, 983 Transpose{&Dense{}}, 984 985 &VecDense{mat: blas64.Vector{Inc: 1}}, 986 &VecDense{mat: blas64.Vector{Inc: 10}}, 987 &basicVector{}, 988 Transpose{&VecDense{mat: blas64.Vector{Inc: 1}}}, 989 Transpose{&VecDense{mat: blas64.Vector{Inc: 10}}}, 990 Transpose{&basicVector{}}, 991 992 &BandDense{mat: blas64.Band{KL: 2, KU: 1}}, 993 &BandDense{mat: blas64.Band{KL: 1, KU: 2}}, 994 Transpose{&BandDense{mat: blas64.Band{KL: 2, KU: 1}}}, 995 Transpose{&BandDense{mat: blas64.Band{KL: 1, KU: 2}}}, 996 TransposeBand{&BandDense{mat: blas64.Band{KL: 2, KU: 1}}}, 997 TransposeBand{&BandDense{mat: blas64.Band{KL: 1, KU: 2}}}, 998 999 &SymDense{}, 1000 &basicSymmetric{}, 1001 Transpose{&basicSymmetric{}}, 1002 1003 &TriDense{mat: blas64.Triangular{Uplo: blas.Upper}}, 1004 &TriDense{mat: blas64.Triangular{Uplo: blas.Lower}}, 1005 &basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}}, 1006 &basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}}, 1007 Transpose{&TriDense{mat: blas64.Triangular{Uplo: blas.Upper}}}, 1008 Transpose{&TriDense{mat: blas64.Triangular{Uplo: blas.Lower}}}, 1009 TransposeTri{&TriDense{mat: blas64.Triangular{Uplo: blas.Upper}}}, 1010 TransposeTri{&TriDense{mat: blas64.Triangular{Uplo: blas.Lower}}}, 1011 Transpose{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}}}, 1012 Transpose{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}}}, 1013 TransposeTri{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}}}, 1014 TransposeTri{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}}}, 1015 1016 &SymBandDense{}, 1017 &basicSymBanded{}, 1018 Transpose{&basicSymBanded{}}, 1019 1020 &SymBandDense{mat: blas64.SymmetricBand{K: 2}}, 1021 &basicSymBanded{mat: blas64.SymmetricBand{K: 2}}, 1022 Transpose{&basicSymBanded{mat: blas64.SymmetricBand{K: 2}}}, 1023 TransposeBand{&basicSymBanded{mat: blas64.SymmetricBand{K: 2}}}, 1024 1025 &TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}, 1026 &TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}, 1027 &basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}, 1028 &basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}, 1029 Transpose{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1030 Transpose{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1031 Transpose{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1032 Transpose{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1033 TransposeTri{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1034 TransposeTri{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1035 TransposeTri{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1036 TransposeTri{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1037 TransposeBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1038 TransposeBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1039 TransposeBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1040 TransposeBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1041 TransposeTriBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1042 TransposeTriBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1043 TransposeTriBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}}, 1044 TransposeTriBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}}, 1045 1046 &DiagDense{}, 1047 &DiagDense{mat: blas64.Vector{Inc: 10}}, 1048 Transpose{&DiagDense{}}, 1049 Transpose{&DiagDense{mat: blas64.Vector{Inc: 10}}}, 1050 TransposeTri{&DiagDense{}}, 1051 TransposeTri{&DiagDense{mat: blas64.Vector{Inc: 10}}}, 1052 TransposeBand{&DiagDense{}}, 1053 TransposeBand{&DiagDense{mat: blas64.Vector{Inc: 10}}}, 1054 TransposeTriBand{&DiagDense{}}, 1055 TransposeTriBand{&DiagDense{mat: blas64.Vector{Inc: 10}}}, 1056 &basicDiagonal{}, 1057 Transpose{&basicDiagonal{}}, 1058 TransposeTri{&basicDiagonal{}}, 1059 TransposeBand{&basicDiagonal{}}, 1060 TransposeTriBand{&basicDiagonal{}}, 1061 } 1062 1063 var sizes = []struct { 1064 ar, ac int 1065 }{ 1066 {1, 1}, 1067 {1, 3}, 1068 {3, 1}, 1069 1070 {6, 6}, 1071 {6, 11}, 1072 {11, 6}, 1073 } 1074 1075 func testOneInputFunc(t *testing.T, 1076 // name is the name of the function being tested. 1077 name string, 1078 1079 // f is the function being tested. 1080 f func(a Matrix) interface{}, 1081 1082 // denseComparison performs the same operation, but using Dense matrices for 1083 // comparison. 1084 denseComparison func(a *Dense) interface{}, 1085 1086 // sameAnswer compares the result from two different evaluations of the function 1087 // and returns true if they are the same. The specific function being tested 1088 // determines the definition of "same". It may mean identical or it may mean 1089 // approximately equal. 1090 sameAnswer func(a, b interface{}) bool, 1091 1092 // legalType returns true if the type of the input is a legal type for the 1093 // input of the function. 1094 legalType func(a Matrix) bool, 1095 1096 // legalSize returns true if the size is valid for the function. 1097 legalSize func(r, c int) bool, 1098 ) { 1099 src := rand.NewSource(1) 1100 for _, aMat := range testMatrices { 1101 for _, test := range sizes { 1102 // Skip the test if the argument would not be assignable to the 1103 // method's corresponding input parameter or it is not possible 1104 // to construct an argument of the requested size. 1105 if !legalType(aMat) { 1106 continue 1107 } 1108 if !legalDims(aMat, test.ar, test.ac) { 1109 continue 1110 } 1111 a := makeRandOf(aMat, test.ar, test.ac, src) 1112 1113 // Compute the true answer if the sizes are legal. 1114 dimsOK := legalSize(test.ar, test.ac) 1115 var want interface{} 1116 if dimsOK { 1117 var aDense Dense 1118 aDense.CloneFrom(a) 1119 want = denseComparison(&aDense) 1120 } 1121 aCopy := makeCopyOf(a) 1122 // Test the method for a zero-value of the receiver. 1123 aType, aTrans := untranspose(a) 1124 errStr := fmt.Sprintf("%v(%T), size: %#v, atrans %t", name, aType, test, aTrans) 1125 var got interface{} 1126 panicked, err := panics(func() { got = f(a) }) 1127 if !dimsOK && !panicked { 1128 t.Errorf("Did not panic with illegal size: %s", errStr) 1129 continue 1130 } 1131 if dimsOK && panicked { 1132 t.Errorf("Panicked with legal size: %s: %v", errStr, err) 1133 continue 1134 } 1135 if !equal(a, aCopy) { 1136 t.Errorf("First input argument changed in call: %s", errStr) 1137 } 1138 if !dimsOK { 1139 continue 1140 } 1141 if !sameAnswer(want, got) { 1142 t.Errorf("Answer mismatch: %s", errStr) 1143 } 1144 } 1145 } 1146 } 1147 1148 var sizePairs = []struct { 1149 ar, ac, br, bc int 1150 }{ 1151 {1, 1, 1, 1}, 1152 {6, 6, 6, 6}, 1153 {7, 7, 7, 7}, 1154 1155 {1, 1, 1, 5}, 1156 {1, 1, 5, 1}, 1157 {1, 5, 1, 1}, 1158 {5, 1, 1, 1}, 1159 1160 {5, 5, 5, 1}, 1161 {5, 5, 1, 5}, 1162 {5, 1, 5, 5}, 1163 {1, 5, 5, 5}, 1164 1165 {6, 6, 6, 11}, 1166 {6, 6, 11, 6}, 1167 {6, 11, 6, 6}, 1168 {11, 6, 6, 6}, 1169 {11, 11, 11, 6}, 1170 {11, 11, 6, 11}, 1171 {11, 6, 11, 11}, 1172 {6, 11, 11, 11}, 1173 1174 {1, 1, 5, 5}, 1175 {1, 5, 1, 5}, 1176 {1, 5, 5, 1}, 1177 {5, 1, 1, 5}, 1178 {5, 1, 5, 1}, 1179 {5, 5, 1, 1}, 1180 {6, 6, 11, 11}, 1181 {6, 11, 6, 11}, 1182 {6, 11, 11, 6}, 1183 {11, 6, 6, 11}, 1184 {11, 6, 11, 6}, 1185 {11, 11, 6, 6}, 1186 1187 {1, 1, 17, 11}, 1188 {1, 1, 11, 17}, 1189 {1, 11, 1, 17}, 1190 {1, 17, 1, 11}, 1191 {1, 11, 17, 1}, 1192 {1, 17, 11, 1}, 1193 {11, 1, 1, 17}, 1194 {17, 1, 1, 11}, 1195 {11, 1, 17, 1}, 1196 {17, 1, 11, 1}, 1197 {11, 17, 1, 1}, 1198 {17, 11, 1, 1}, 1199 1200 {6, 6, 1, 11}, 1201 {6, 6, 11, 1}, 1202 {6, 11, 6, 1}, 1203 {6, 1, 6, 11}, 1204 {6, 11, 1, 6}, 1205 {6, 1, 11, 6}, 1206 {11, 6, 6, 1}, 1207 {1, 6, 6, 11}, 1208 {11, 6, 1, 6}, 1209 {1, 6, 11, 6}, 1210 {11, 1, 6, 6}, 1211 {1, 11, 6, 6}, 1212 1213 {6, 6, 17, 1}, 1214 {6, 6, 1, 17}, 1215 {6, 1, 6, 17}, 1216 {6, 17, 6, 1}, 1217 {6, 1, 17, 6}, 1218 {6, 17, 1, 6}, 1219 {1, 6, 6, 17}, 1220 {17, 6, 6, 1}, 1221 {1, 6, 17, 6}, 1222 {17, 6, 1, 6}, 1223 {1, 17, 6, 6}, 1224 {17, 1, 6, 6}, 1225 1226 {6, 6, 17, 11}, 1227 {6, 6, 11, 17}, 1228 {6, 11, 6, 17}, 1229 {6, 17, 6, 11}, 1230 {6, 11, 17, 6}, 1231 {6, 17, 11, 6}, 1232 {11, 6, 6, 17}, 1233 {17, 6, 6, 11}, 1234 {11, 6, 17, 6}, 1235 {17, 6, 11, 6}, 1236 {11, 17, 6, 6}, 1237 {17, 11, 6, 6}, 1238 } 1239 1240 func testTwoInputFunc(t *testing.T, 1241 // name is the name of the function being tested. 1242 name string, 1243 1244 // f is the function being tested. 1245 f func(a, b Matrix) interface{}, 1246 1247 // denseComparison performs the same operation, but using Dense matrices for 1248 // comparison. 1249 denseComparison func(a, b *Dense) interface{}, 1250 1251 // sameAnswer compares the result from two different evaluations of the function 1252 // and returns true if they are the same. The specific function being tested 1253 // determines the definition of "same". It may mean identical or it may mean 1254 // approximately equal. 1255 sameAnswer func(a, b interface{}) bool, 1256 1257 // legalType returns true if the types of the inputs are legal for the 1258 // input of the function. 1259 legalType func(a, b Matrix) bool, 1260 1261 // legalSize returns true if the sizes are valid for the function. 1262 legalSize func(ar, ac, br, bc int) bool, 1263 ) { 1264 src := rand.NewSource(1) 1265 for _, aMat := range testMatrices { 1266 for _, bMat := range testMatrices { 1267 // Loop over all of the size combinations (bigger, smaller, etc.). 1268 for _, test := range sizePairs { 1269 // Skip the test if the argument would not be assignable to the 1270 // method's corresponding input parameter or it is not possible 1271 // to construct an argument of the requested size. 1272 if !legalType(aMat, bMat) { 1273 continue 1274 } 1275 if !legalDims(aMat, test.ar, test.ac) { 1276 continue 1277 } 1278 if !legalDims(bMat, test.br, test.bc) { 1279 continue 1280 } 1281 a := makeRandOf(aMat, test.ar, test.ac, src) 1282 b := makeRandOf(bMat, test.br, test.bc, src) 1283 1284 // Compute the true answer if the sizes are legal. 1285 dimsOK := legalSize(test.ar, test.ac, test.br, test.bc) 1286 var want interface{} 1287 if dimsOK { 1288 var aDense, bDense Dense 1289 aDense.CloneFrom(a) 1290 bDense.CloneFrom(b) 1291 want = denseComparison(&aDense, &bDense) 1292 } 1293 aCopy := makeCopyOf(a) 1294 bCopy := makeCopyOf(b) 1295 // Test the method for a zero-value of the receiver. 1296 aType, aTrans := untranspose(a) 1297 bType, bTrans := untranspose(b) 1298 errStr := fmt.Sprintf("%v(%T, %T), size: %#v, atrans %t, btrans %t", name, aType, bType, test, aTrans, bTrans) 1299 var got interface{} 1300 panicked, err := panics(func() { got = f(a, b) }) 1301 if !dimsOK && !panicked { 1302 t.Errorf("Did not panic with illegal size: %s", errStr) 1303 continue 1304 } 1305 if dimsOK && panicked { 1306 t.Errorf("Panicked with legal size: %s: %v", errStr, err) 1307 continue 1308 } 1309 if !equal(a, aCopy) { 1310 t.Errorf("First input argument changed in call: %s", errStr) 1311 } 1312 if !equal(b, bCopy) { 1313 t.Errorf("First input argument changed in call: %s", errStr) 1314 } 1315 if !dimsOK { 1316 continue 1317 } 1318 if !sameAnswer(want, got) { 1319 t.Errorf("Answer mismatch: %s", errStr) 1320 } 1321 } 1322 } 1323 } 1324 } 1325 1326 // testOneInput tests a method that has one matrix input argument 1327 func testOneInput(t *testing.T, 1328 // name is the name of the method being tested. 1329 name string, 1330 1331 // receiver is a value of the receiver type. 1332 receiver Matrix, 1333 1334 // method is the generalized receiver.Method(a). 1335 method func(receiver, a Matrix), 1336 1337 // denseComparison performs the same operation as method, but with dense 1338 // matrices for comparison with the result. 1339 denseComparison func(receiver, a *Dense), 1340 1341 // legalTypes returns whether the concrete types in Matrix are valid for 1342 // the method. 1343 legalType func(a Matrix) bool, 1344 1345 // legalSize returns whether the matrix sizes are valid for the method. 1346 legalSize func(ar, ac int) bool, 1347 1348 // tol is the tolerance for equality when comparing method results. 1349 tol float64, 1350 ) { 1351 src := rand.NewSource(1) 1352 for _, aMat := range testMatrices { 1353 for _, test := range sizes { 1354 // Skip the test if the argument would not be assignable to the 1355 // method's corresponding input parameter or it is not possible 1356 // to construct an argument of the requested size. 1357 if !legalType(aMat) { 1358 continue 1359 } 1360 if !legalDims(aMat, test.ar, test.ac) { 1361 continue 1362 } 1363 a := makeRandOf(aMat, test.ar, test.ac, src) 1364 1365 // Compute the true answer if the sizes are legal. 1366 dimsOK := legalSize(test.ar, test.ac) 1367 var want Dense 1368 if dimsOK { 1369 var aDense Dense 1370 aDense.CloneFrom(a) 1371 denseComparison(&want, &aDense) 1372 } 1373 aCopy := makeCopyOf(a) 1374 1375 // Test the method for a zero-value of the receiver. 1376 aType, aTrans := untranspose(a) 1377 errStr := fmt.Sprintf("%T.%s(%T), size: %#v, atrans %v", receiver, name, aType, test, aTrans) 1378 empty := makeRandOf(receiver, 0, 0, src) 1379 panicked, err := panics(func() { method(empty, a) }) 1380 if !dimsOK && !panicked { 1381 t.Errorf("Did not panic with illegal size: %s", errStr) 1382 continue 1383 } 1384 if dimsOK && panicked { 1385 t.Errorf("Panicked with legal size: %s: %v", errStr, err) 1386 continue 1387 } 1388 if !equal(a, aCopy) { 1389 t.Errorf("First input argument changed in call: %s", errStr) 1390 } 1391 if !dimsOK { 1392 continue 1393 } 1394 if !equalApprox(empty, &want, tol, false) { 1395 t.Errorf("Answer mismatch with empty receiver: %s.\nGot:\n% v\nWant:\n% v\n", errStr, Formatted(empty), Formatted(&want)) 1396 continue 1397 } 1398 1399 // Test the method with a non-empty-value of the receiver. 1400 // The receiver has been overwritten in place so use its size 1401 // to construct a new random matrix. 1402 rr, rc := empty.Dims() 1403 neverEmpty := makeRandOf(receiver, rr, rc, src) 1404 panicked, message := panics(func() { method(neverEmpty, a) }) 1405 if panicked { 1406 t.Errorf("Panicked with non-empty receiver: %s: %s", errStr, message) 1407 } 1408 if !equalApprox(neverEmpty, &want, tol, false) { 1409 t.Errorf("Answer mismatch non-empty receiver: %s", errStr) 1410 } 1411 1412 // Test the method with a NaN-filled-value of the receiver. 1413 // The receiver has been overwritten in place so use its size 1414 // to construct a new NaN matrix. 1415 nanMatrix := makeNaNOf(receiver, rr, rc) 1416 panicked, message = panics(func() { method(nanMatrix, a) }) 1417 if panicked { 1418 t.Errorf("Panicked with NaN-filled receiver: %s: %s", errStr, message) 1419 } 1420 if !equalApprox(nanMatrix, &want, tol, false) { 1421 t.Errorf("Answer mismatch NaN-filled receiver: %s", errStr) 1422 } 1423 1424 // Test with an incorrectly sized matrix. 1425 switch receiver.(type) { 1426 default: 1427 panic("matrix type not coded for incorrect receiver size") 1428 case *Dense: 1429 wrongSize := makeRandOf(receiver, rr+1, rc, src) 1430 panicked, _ = panics(func() { method(wrongSize, a) }) 1431 if !panicked { 1432 t.Errorf("Did not panic with wrong number of rows: %s", errStr) 1433 } 1434 wrongSize = makeRandOf(receiver, rr, rc+1, src) 1435 panicked, _ = panics(func() { method(wrongSize, a) }) 1436 if !panicked { 1437 t.Errorf("Did not panic with wrong number of columns: %s", errStr) 1438 } 1439 case *TriDense, *SymDense: 1440 // Add to the square size. 1441 wrongSize := makeRandOf(receiver, rr+1, rc+1, src) 1442 panicked, _ = panics(func() { method(wrongSize, a) }) 1443 if !panicked { 1444 t.Errorf("Did not panic with wrong size: %s", errStr) 1445 } 1446 case *VecDense: 1447 // Add to the column length. 1448 wrongSize := makeRandOf(receiver, rr+1, rc, src) 1449 panicked, _ = panics(func() { method(wrongSize, a) }) 1450 if !panicked { 1451 t.Errorf("Did not panic with wrong number of rows: %s", errStr) 1452 } 1453 } 1454 1455 // The receiver and the input may share a matrix pointer 1456 // if the type and size of the receiver and one of the 1457 // arguments match. Test the method works properly 1458 // when this is the case. 1459 aMaybeSame := maybeSame(neverEmpty, a) 1460 if aMaybeSame { 1461 aSame := makeCopyOf(a) 1462 receiver = aSame 1463 u, ok := aSame.(Untransposer) 1464 if ok { 1465 receiver = u.Untranspose() 1466 } 1467 preData := underlyingData(receiver) 1468 panicked, err = panics(func() { method(receiver, aSame) }) 1469 if panicked { 1470 t.Errorf("Panics when a maybeSame: %s: %v", errStr, err) 1471 } else { 1472 if !equalApprox(receiver, &want, tol, false) { 1473 t.Errorf("Wrong answer when a maybeSame: %s", errStr) 1474 } 1475 postData := underlyingData(receiver) 1476 if !floats.Equal(preData, postData) { 1477 t.Errorf("Original data slice not modified when a maybeSame: %s", errStr) 1478 } 1479 } 1480 } 1481 } 1482 } 1483 } 1484 1485 // testTwoInput tests a method that has two input arguments. 1486 func testTwoInput(t *testing.T, 1487 // name is the name of the method being tested. 1488 name string, 1489 1490 // receiver is a value of the receiver type. 1491 receiver Matrix, 1492 1493 // method is the generalized receiver.Method(a, b). 1494 method func(receiver, a, b Matrix), 1495 1496 // denseComparison performs the same operation as method, but with dense 1497 // matrices for comparison with the result. 1498 denseComparison func(receiver, a, b *Dense), 1499 1500 // legalTypes returns whether the concrete types in Matrix are valid for 1501 // the method. 1502 legalTypes func(a, b Matrix) bool, 1503 1504 // legalSize returns whether the matrix sizes are valid for the method. 1505 legalSize func(ar, ac, br, bc int) bool, 1506 1507 // tol is the tolerance for equality when comparing method results. 1508 tol float64, 1509 ) { 1510 src := rand.NewSource(1) 1511 for _, aMat := range testMatrices { 1512 for _, bMat := range testMatrices { 1513 // Loop over all of the size combinations (bigger, smaller, etc.). 1514 for _, test := range sizePairs { 1515 // Skip the test if any argument would not be assignable to the 1516 // method's corresponding input parameter or it is not possible 1517 // to construct an argument of the requested size. 1518 if !legalTypes(aMat, bMat) { 1519 continue 1520 } 1521 if !legalDims(aMat, test.ar, test.ac) { 1522 continue 1523 } 1524 if !legalDims(bMat, test.br, test.bc) { 1525 continue 1526 } 1527 a := makeRandOf(aMat, test.ar, test.ac, src) 1528 b := makeRandOf(bMat, test.br, test.bc, src) 1529 1530 // Compute the true answer if the sizes are legal. 1531 dimsOK := legalSize(test.ar, test.ac, test.br, test.bc) 1532 var want Dense 1533 if dimsOK { 1534 var aDense, bDense Dense 1535 aDense.CloneFrom(a) 1536 bDense.CloneFrom(b) 1537 denseComparison(&want, &aDense, &bDense) 1538 } 1539 aCopy := makeCopyOf(a) 1540 bCopy := makeCopyOf(b) 1541 1542 // Test the method for a empty-value of the receiver. 1543 aType, aTrans := untranspose(a) 1544 bType, bTrans := untranspose(b) 1545 errStr := fmt.Sprintf("%T.%s(%T, %T), sizes: %#v, atrans %v, btrans %v", receiver, name, aType, bType, test, aTrans, bTrans) 1546 empty := makeRandOf(receiver, 0, 0, src) 1547 panicked, err := panics(func() { method(empty, a, b) }) 1548 if !dimsOK && !panicked { 1549 t.Errorf("Did not panic with illegal size: %s", errStr) 1550 continue 1551 } 1552 if dimsOK && panicked { 1553 t.Errorf("Panicked with legal size: %s: %v", errStr, err) 1554 continue 1555 } 1556 if !equal(a, aCopy) { 1557 t.Errorf("First input argument changed in call: %s", errStr) 1558 } 1559 if !equal(b, bCopy) { 1560 t.Errorf("Second input argument changed in call: %s", errStr) 1561 } 1562 if !dimsOK { 1563 continue 1564 } 1565 wasEmpty, empty := empty, nil // Nil-out empty so we detect illegal use. 1566 // NaN equality is allowed because of 0/0 in DivElem test. 1567 if !equalApprox(wasEmpty, &want, tol, true) { 1568 t.Errorf("Answer mismatch with empty receiver: %s", errStr) 1569 continue 1570 } 1571 1572 // Test the method with a non-empty-value of the receiver. 1573 // The receiver has been overwritten in place so use its size 1574 // to construct a new random matrix. 1575 rr, rc := wasEmpty.Dims() 1576 neverEmpty := makeRandOf(receiver, rr, rc, src) 1577 panicked, message := panics(func() { method(neverEmpty, a, b) }) 1578 if panicked { 1579 t.Errorf("Panicked with non-empty receiver: %s: %s", errStr, message) 1580 } 1581 // NaN equality is allowed because of 0/0 in DivElem test. 1582 if !equalApprox(neverEmpty, &want, tol, true) { 1583 t.Errorf("Answer mismatch non-empty receiver: %s", errStr) 1584 } 1585 1586 // Test the method with a NaN-filled value of the receiver. 1587 // The receiver has been overwritten in place so use its size 1588 // to construct a new NaN matrix. 1589 nanMatrix := makeNaNOf(receiver, rr, rc) 1590 panicked, message = panics(func() { method(nanMatrix, a, b) }) 1591 if panicked { 1592 t.Errorf("Panicked with NaN-filled receiver: %s: %s", errStr, message) 1593 } 1594 // NaN equality is allowed because of 0/0 in DivElem test. 1595 if !equalApprox(nanMatrix, &want, tol, true) { 1596 t.Errorf("Answer mismatch NaN-filled receiver: %s", errStr) 1597 } 1598 1599 // Test with an incorrectly sized matrix. 1600 switch receiver.(type) { 1601 default: 1602 panic("matrix type not coded for incorrect receiver size") 1603 case *Dense: 1604 wrongSize := makeRandOf(receiver, rr+1, rc, src) 1605 panicked, _ = panics(func() { method(wrongSize, a, b) }) 1606 if !panicked { 1607 t.Errorf("Did not panic with wrong number of rows: %s", errStr) 1608 } 1609 wrongSize = makeRandOf(receiver, rr, rc+1, src) 1610 panicked, _ = panics(func() { method(wrongSize, a, b) }) 1611 if !panicked { 1612 t.Errorf("Did not panic with wrong number of columns: %s", errStr) 1613 } 1614 case *TriDense, *SymDense: 1615 // Add to the square size. 1616 wrongSize := makeRandOf(receiver, rr+1, rc+1, src) 1617 panicked, _ = panics(func() { method(wrongSize, a, b) }) 1618 if !panicked { 1619 t.Errorf("Did not panic with wrong size: %s", errStr) 1620 } 1621 case *VecDense: 1622 // Add to the column length. 1623 wrongSize := makeRandOf(receiver, rr+1, rc, src) 1624 panicked, _ = panics(func() { method(wrongSize, a, b) }) 1625 if !panicked { 1626 t.Errorf("Did not panic with wrong number of rows: %s", errStr) 1627 } 1628 } 1629 1630 // The receiver and an input may share a matrix pointer 1631 // if the type and size of the receiver and one of the 1632 // arguments match. Test the method works properly 1633 // when this is the case. 1634 aMaybeSame := maybeSame(neverEmpty, a) 1635 bMaybeSame := maybeSame(neverEmpty, b) 1636 if aMaybeSame { 1637 aSame := makeCopyOf(a) 1638 receiver = aSame 1639 u, ok := aSame.(Untransposer) 1640 if ok { 1641 receiver = u.Untranspose() 1642 } 1643 preData := underlyingData(receiver) 1644 panicked, err = panics(func() { method(receiver, aSame, b) }) 1645 if panicked { 1646 t.Errorf("Panics when a maybeSame: %s: %v", errStr, err) 1647 } else { 1648 if !equalApprox(receiver, &want, tol, false) { 1649 t.Errorf("Wrong answer when a maybeSame: %s", errStr) 1650 } 1651 postData := underlyingData(receiver) 1652 if !floats.Equal(preData, postData) { 1653 t.Errorf("Original data slice not modified when a maybeSame: %s", errStr) 1654 } 1655 } 1656 } 1657 if bMaybeSame { 1658 bSame := makeCopyOf(b) 1659 receiver = bSame 1660 u, ok := bSame.(Untransposer) 1661 if ok { 1662 receiver = u.Untranspose() 1663 } 1664 preData := underlyingData(receiver) 1665 panicked, err = panics(func() { method(receiver, a, bSame) }) 1666 if panicked { 1667 t.Errorf("Panics when b maybeSame: %s: %v", errStr, err) 1668 } else { 1669 if !equalApprox(receiver, &want, tol, false) { 1670 t.Errorf("Wrong answer when b maybeSame: %s", errStr) 1671 } 1672 postData := underlyingData(receiver) 1673 if !floats.Equal(preData, postData) { 1674 t.Errorf("Original data slice not modified when b maybeSame: %s", errStr) 1675 } 1676 } 1677 } 1678 if aMaybeSame && bMaybeSame { 1679 aSame := makeCopyOf(a) 1680 receiver = aSame 1681 u, ok := aSame.(Untransposer) 1682 if ok { 1683 receiver = u.Untranspose() 1684 } 1685 // Ensure that b is the correct transpose type if applicable. 1686 // The receiver is always a concrete type so use it. 1687 bSame := receiver 1688 _, ok = b.(Untransposer) 1689 if ok { 1690 bSame = retranspose(b, receiver) 1691 } 1692 // Compute the real answer for this case. It is different 1693 // from the initial answer since now a and b have the 1694 // same data. 1695 empty = makeRandOf(wasEmpty, 0, 0, src) 1696 method(empty, aSame, bSame) 1697 wasEmpty, empty = empty, nil // Nil-out empty so we detect illegal use. 1698 preData := underlyingData(receiver) 1699 panicked, err = panics(func() { method(receiver, aSame, bSame) }) 1700 if panicked { 1701 t.Errorf("Panics when both maybeSame: %s: %v", errStr, err) 1702 } else { 1703 if !equalApprox(receiver, wasEmpty, tol, false) { 1704 t.Errorf("Wrong answer when both maybeSame: %s", errStr) 1705 } 1706 postData := underlyingData(receiver) 1707 if !floats.Equal(preData, postData) { 1708 t.Errorf("Original data slice not modified when both maybeSame: %s", errStr) 1709 } 1710 } 1711 } 1712 } 1713 } 1714 } 1715 }