gonum.org/v1/gonum@v0.14.0/blas/blas64/conv_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package blas64 6 7 import ( 8 "math" 9 "testing" 10 11 "gonum.org/v1/gonum/blas" 12 ) 13 14 func newGeneralFrom(a GeneralCols) General { 15 t := General{ 16 Rows: a.Rows, 17 Cols: a.Cols, 18 Stride: a.Cols, 19 Data: make([]float64, a.Rows*a.Cols), 20 } 21 t.From(a) 22 return t 23 } 24 25 func (m General) dims() (r, c int) { return m.Rows, m.Cols } 26 func (m General) at(i, j int) float64 { return m.Data[i*m.Stride+j] } 27 28 func newGeneralColsFrom(a General) GeneralCols { 29 t := GeneralCols{ 30 Rows: a.Rows, 31 Cols: a.Cols, 32 Stride: a.Rows, 33 Data: make([]float64, a.Rows*a.Cols), 34 } 35 t.From(a) 36 return t 37 } 38 39 func (m GeneralCols) dims() (r, c int) { return m.Rows, m.Cols } 40 func (m GeneralCols) at(i, j int) float64 { return m.Data[i+j*m.Stride] } 41 42 type general interface { 43 dims() (r, c int) 44 at(i, j int) float64 45 } 46 47 func sameGeneral(a, b general) bool { 48 ar, ac := a.dims() 49 br, bc := b.dims() 50 if ar != br || ac != bc { 51 return false 52 } 53 for i := 0; i < ar; i++ { 54 for j := 0; j < ac; j++ { 55 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 56 return false 57 } 58 } 59 } 60 return true 61 } 62 63 var generalTests = []General{ 64 {Rows: 2, Cols: 3, Stride: 3, Data: []float64{ 65 1, 2, 3, 66 4, 5, 6, 67 }}, 68 {Rows: 3, Cols: 2, Stride: 2, Data: []float64{ 69 1, 2, 70 3, 4, 71 5, 6, 72 }}, 73 {Rows: 3, Cols: 3, Stride: 3, Data: []float64{ 74 1, 2, 3, 75 4, 5, 6, 76 7, 8, 9, 77 }}, 78 {Rows: 2, Cols: 3, Stride: 5, Data: []float64{ 79 1, 2, 3, 0, 0, 80 4, 5, 6, 0, 0, 81 }}, 82 {Rows: 3, Cols: 2, Stride: 5, Data: []float64{ 83 1, 2, 0, 0, 0, 84 3, 4, 0, 0, 0, 85 5, 6, 0, 0, 0, 86 }}, 87 {Rows: 3, Cols: 3, Stride: 5, Data: []float64{ 88 1, 2, 3, 0, 0, 89 4, 5, 6, 0, 0, 90 7, 8, 9, 0, 0, 91 }}, 92 } 93 94 func TestConvertGeneral(t *testing.T) { 95 for _, test := range generalTests { 96 colmajor := newGeneralColsFrom(test) 97 if !sameGeneral(colmajor, test) { 98 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 99 colmajor, test) 100 } 101 rowmajor := newGeneralFrom(colmajor) 102 if !sameGeneral(rowmajor, test) { 103 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 104 rowmajor, test) 105 } 106 } 107 } 108 109 func newTriangularFrom(a TriangularCols) Triangular { 110 t := Triangular{ 111 N: a.N, 112 Stride: a.N, 113 Data: make([]float64, a.N*a.N), 114 Diag: a.Diag, 115 Uplo: a.Uplo, 116 } 117 t.From(a) 118 return t 119 } 120 121 func (m Triangular) n() int { return m.N } 122 func (m Triangular) at(i, j int) float64 { 123 if m.Diag == blas.Unit && i == j { 124 return 1 125 } 126 if m.Uplo == blas.Lower && i < j && j < m.N { 127 return 0 128 } 129 if m.Uplo == blas.Upper && i > j { 130 return 0 131 } 132 return m.Data[i*m.Stride+j] 133 } 134 func (m Triangular) uplo() blas.Uplo { return m.Uplo } 135 func (m Triangular) diag() blas.Diag { return m.Diag } 136 137 func newTriangularColsFrom(a Triangular) TriangularCols { 138 t := TriangularCols{ 139 N: a.N, 140 Stride: a.N, 141 Data: make([]float64, a.N*a.N), 142 Diag: a.Diag, 143 Uplo: a.Uplo, 144 } 145 t.From(a) 146 return t 147 } 148 149 func (m TriangularCols) n() int { return m.N } 150 func (m TriangularCols) at(i, j int) float64 { 151 if m.Diag == blas.Unit && i == j { 152 return 1 153 } 154 if m.Uplo == blas.Lower && i < j { 155 return 0 156 } 157 if m.Uplo == blas.Upper && i > j && i < m.N { 158 return 0 159 } 160 return m.Data[i+j*m.Stride] 161 } 162 func (m TriangularCols) uplo() blas.Uplo { return m.Uplo } 163 func (m TriangularCols) diag() blas.Diag { return m.Diag } 164 165 type triangular interface { 166 n() int 167 at(i, j int) float64 168 uplo() blas.Uplo 169 diag() blas.Diag 170 } 171 172 func sameTriangular(a, b triangular) bool { 173 an := a.n() 174 bn := b.n() 175 if an != bn { 176 return false 177 } 178 for i := 0; i < an; i++ { 179 for j := 0; j < an; j++ { 180 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 181 return false 182 } 183 } 184 } 185 return true 186 } 187 188 var triangularTests = []Triangular{ 189 {N: 3, Stride: 3, Data: []float64{ 190 1, 2, 3, 191 4, 5, 6, 192 7, 8, 9, 193 }}, 194 {N: 3, Stride: 5, Data: []float64{ 195 1, 2, 3, 0, 0, 196 4, 5, 6, 0, 0, 197 7, 8, 9, 0, 0, 198 }}, 199 } 200 201 func TestConvertTriangular(t *testing.T) { 202 for _, test := range triangularTests { 203 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower, blas.All} { 204 for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} { 205 test.Uplo = uplo 206 test.Diag = diag 207 colmajor := newTriangularColsFrom(test) 208 if !sameTriangular(colmajor, test) { 209 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 210 colmajor, test) 211 } 212 rowmajor := newTriangularFrom(colmajor) 213 if !sameTriangular(rowmajor, test) { 214 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 215 rowmajor, test) 216 } 217 } 218 } 219 } 220 } 221 222 func newBandFrom(a BandCols) Band { 223 t := Band{ 224 Rows: a.Rows, 225 Cols: a.Cols, 226 KL: a.KL, 227 KU: a.KU, 228 Stride: a.KL + a.KU + 1, 229 Data: make([]float64, a.Rows*(a.KL+a.KU+1)), 230 } 231 for i := range t.Data { 232 t.Data[i] = math.NaN() 233 } 234 t.From(a) 235 return t 236 } 237 238 func (m Band) dims() (r, c int) { return m.Rows, m.Cols } 239 func (m Band) at(i, j int) float64 { 240 pj := j + m.KL - i 241 if pj < 0 || m.KL+m.KU+1 <= pj { 242 return 0 243 } 244 return m.Data[i*m.Stride+pj] 245 } 246 func (m Band) bandwidth() (kl, ku int) { return m.KL, m.KU } 247 248 func newBandColsFrom(a Band) BandCols { 249 t := BandCols{ 250 Rows: a.Rows, 251 Cols: a.Cols, 252 KL: a.KL, 253 KU: a.KU, 254 Stride: a.KL + a.KU + 1, 255 Data: make([]float64, a.Cols*(a.KL+a.KU+1)), 256 } 257 for i := range t.Data { 258 t.Data[i] = math.NaN() 259 } 260 t.From(a) 261 return t 262 } 263 264 func (m BandCols) dims() (r, c int) { return m.Rows, m.Cols } 265 func (m BandCols) at(i, j int) float64 { 266 pj := i + m.KU - j 267 if pj < 0 || m.KL+m.KU+1 <= pj { 268 return 0 269 } 270 return m.Data[j*m.Stride+pj] 271 } 272 func (m BandCols) bandwidth() (kl, ku int) { return m.KL, m.KU } 273 274 type band interface { 275 dims() (r, c int) 276 at(i, j int) float64 277 bandwidth() (kl, ku int) 278 } 279 280 func sameBand(a, b band) bool { 281 ar, ac := a.dims() 282 br, bc := b.dims() 283 if ar != br || ac != bc { 284 return false 285 } 286 akl, aku := a.bandwidth() 287 bkl, bku := b.bandwidth() 288 if akl != bkl || aku != bku { 289 return false 290 } 291 for i := 0; i < ar; i++ { 292 for j := 0; j < ac; j++ { 293 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 294 return false 295 } 296 } 297 } 298 return true 299 } 300 301 var bandTests = []Band{ 302 {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 1, Data: []float64{ 303 1, 304 2, 305 3, 306 }}, 307 {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{ 308 1, 309 2, 310 3, 311 }}, 312 {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{ 313 1, 314 2, 315 3, 316 }}, 317 {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 2, Data: []float64{ 318 1, 2, 319 3, 4, 320 5, 6, 321 }}, 322 {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 2, Data: []float64{ 323 1, 2, 324 3, 4, 325 5, 6, 326 }}, 327 {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 3, Data: []float64{ 328 -1, 2, 3, 329 4, 5, 6, 330 7, 8, 9, 331 }}, 332 {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 3, Data: []float64{ 333 -1, 2, 3, 334 4, 5, 6, 335 7, 8, -2, 336 9, -3, -4, 337 }}, 338 {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 4, Data: []float64{ 339 -2, -1, 3, 4, 340 -3, 5, 6, 7, 341 8, 9, 10, 11, 342 }}, 343 {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 4, Data: []float64{ 344 -2, -1, 2, 3, 345 -3, 4, 5, 6, 346 7, 8, 9, -4, 347 10, 11, -5, -6, 348 }}, 349 350 {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 5, Data: []float64{ 351 1, 0, 0, 0, 0, 352 2, 0, 0, 0, 0, 353 3, 0, 0, 0, 0, 354 }}, 355 {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{ 356 1, 0, 0, 0, 0, 357 2, 0, 0, 0, 0, 358 3, 0, 0, 0, 0, 359 }}, 360 {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{ 361 1, 0, 0, 0, 0, 362 2, 0, 0, 0, 0, 363 3, 0, 0, 0, 0, 364 }}, 365 {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 5, Data: []float64{ 366 1, 2, 0, 0, 0, 367 3, 4, 0, 0, 0, 368 5, 6, 0, 0, 0, 369 }}, 370 {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 5, Data: []float64{ 371 1, 2, 0, 0, 0, 372 3, 4, 0, 0, 0, 373 5, 6, 0, 0, 0, 374 }}, 375 {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 5, Data: []float64{ 376 -1, 2, 3, 0, 0, 377 4, 5, 6, 0, 0, 378 7, 8, 9, 0, 0, 379 }}, 380 {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 5, Data: []float64{ 381 -1, 2, 3, 0, 0, 382 4, 5, 6, 0, 0, 383 7, 8, -2, 0, 0, 384 9, -3, -4, 0, 0, 385 }}, 386 {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 5, Data: []float64{ 387 -2, -1, 3, 4, 0, 388 -3, 5, 6, 7, 0, 389 8, 9, 10, 11, 0, 390 }}, 391 {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 5, Data: []float64{ 392 -2, -1, 2, 3, 0, 393 -3, 4, 5, 6, 0, 394 7, 8, 9, -4, 0, 395 10, 11, -5, -6, 0, 396 }}, 397 } 398 399 func TestConvertBand(t *testing.T) { 400 for _, test := range bandTests { 401 colmajor := newBandColsFrom(test) 402 if !sameBand(colmajor, test) { 403 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 404 colmajor, test) 405 } 406 rowmajor := newBandFrom(colmajor) 407 if !sameBand(rowmajor, test) { 408 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 409 rowmajor, test) 410 } 411 } 412 } 413 414 func newTriangularBandFrom(a TriangularBandCols) TriangularBand { 415 t := TriangularBand{ 416 N: a.N, 417 K: a.K, 418 Stride: a.K + 1, 419 Data: make([]float64, a.N*(a.K+1)), 420 Uplo: a.Uplo, 421 Diag: a.Diag, 422 } 423 for i := range t.Data { 424 t.Data[i] = math.NaN() 425 } 426 t.From(a) 427 return t 428 } 429 430 func (m TriangularBand) n() (n int) { return m.N } 431 func (m TriangularBand) at(i, j int) float64 { 432 if m.Diag == blas.Unit && i == j { 433 return 1 434 } 435 b := Band{ 436 Rows: m.N, Cols: m.N, 437 Stride: m.Stride, 438 Data: m.Data, 439 } 440 switch m.Uplo { 441 default: 442 panic("blas64: bad BLAS uplo") 443 case blas.Upper: 444 if i > j { 445 return 0 446 } 447 b.KU = m.K 448 case blas.Lower: 449 if i < j { 450 return 0 451 } 452 b.KL = m.K 453 } 454 return b.at(i, j) 455 } 456 func (m TriangularBand) bandwidth() (k int) { return m.K } 457 func (m TriangularBand) uplo() blas.Uplo { return m.Uplo } 458 func (m TriangularBand) diag() blas.Diag { return m.Diag } 459 460 func newTriangularBandColsFrom(a TriangularBand) TriangularBandCols { 461 t := TriangularBandCols{ 462 N: a.N, 463 K: a.K, 464 Stride: a.K + 1, 465 Data: make([]float64, a.N*(a.K+1)), 466 Uplo: a.Uplo, 467 Diag: a.Diag, 468 } 469 for i := range t.Data { 470 t.Data[i] = math.NaN() 471 } 472 t.From(a) 473 return t 474 } 475 476 func (m TriangularBandCols) n() (n int) { return m.N } 477 func (m TriangularBandCols) at(i, j int) float64 { 478 if m.Diag == blas.Unit && i == j { 479 return 1 480 } 481 b := BandCols{ 482 Rows: m.N, Cols: m.N, 483 Stride: m.Stride, 484 Data: m.Data, 485 } 486 switch m.Uplo { 487 default: 488 panic("blas64: bad BLAS uplo") 489 case blas.Upper: 490 if i > j { 491 return 0 492 } 493 b.KU = m.K 494 case blas.Lower: 495 if i < j { 496 return 0 497 } 498 b.KL = m.K 499 } 500 return b.at(i, j) 501 } 502 func (m TriangularBandCols) bandwidth() (k int) { return m.K } 503 func (m TriangularBandCols) uplo() blas.Uplo { return m.Uplo } 504 func (m TriangularBandCols) diag() blas.Diag { return m.Diag } 505 506 type triangularBand interface { 507 n() (n int) 508 at(i, j int) float64 509 bandwidth() (k int) 510 uplo() blas.Uplo 511 diag() blas.Diag 512 } 513 514 func sameTriangularBand(a, b triangularBand) bool { 515 an := a.n() 516 bn := b.n() 517 if an != bn { 518 return false 519 } 520 if a.uplo() != b.uplo() { 521 return false 522 } 523 if a.diag() != b.diag() { 524 return false 525 } 526 ak := a.bandwidth() 527 bk := b.bandwidth() 528 if ak != bk { 529 return false 530 } 531 for i := 0; i < an; i++ { 532 for j := 0; j < an; j++ { 533 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 534 return false 535 } 536 } 537 } 538 return true 539 } 540 541 var triangularBandTests = []TriangularBand{ 542 {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{ 543 1, 544 2, 545 3, 546 }}, 547 {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{ 548 1, 549 2, 550 3, 551 }}, 552 {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{ 553 1, 2, 554 3, 4, 555 5, -1, 556 }}, 557 {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{ 558 -1, 1, 559 2, 3, 560 4, 5, 561 }}, 562 {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{ 563 1, 2, 3, 564 4, 5, -1, 565 6, -2, -3, 566 }}, 567 {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{ 568 -2, -1, 1, 569 -3, 2, 4, 570 3, 5, 6, 571 }}, 572 573 {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{ 574 1, 0, 0, 0, 0, 575 2, 0, 0, 0, 0, 576 3, 0, 0, 0, 0, 577 }}, 578 {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{ 579 1, 0, 0, 0, 0, 580 2, 0, 0, 0, 0, 581 3, 0, 0, 0, 0, 582 }}, 583 {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{ 584 1, 2, 0, 0, 0, 585 3, 4, 0, 0, 0, 586 5, -1, 0, 0, 0, 587 }}, 588 {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{ 589 -1, 1, 0, 0, 0, 590 2, 3, 0, 0, 0, 591 4, 5, 0, 0, 0, 592 }}, 593 {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{ 594 1, 2, 3, 0, 0, 595 4, 5, -1, 0, 0, 596 6, -2, -3, 0, 0, 597 }}, 598 {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{ 599 -2, -1, 1, 0, 0, 600 -3, 2, 4, 0, 0, 601 3, 5, 6, 0, 0, 602 }}, 603 } 604 605 func TestConvertTriBand(t *testing.T) { 606 for _, test := range triangularBandTests { 607 colmajor := newTriangularBandColsFrom(test) 608 if !sameTriangularBand(colmajor, test) { 609 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 610 colmajor, test) 611 } 612 rowmajor := newTriangularBandFrom(colmajor) 613 if !sameTriangularBand(rowmajor, test) { 614 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 615 rowmajor, test) 616 } 617 } 618 }