github.com/gopherd/gonum@v0.0.4/blas/cblas64/conv_test.go (about) 1 // Code generated by "go generate github.com/gopherd/gonum/blas”; DO NOT EDIT. 2 3 // Copyright ©2015 The Gonum Authors. All rights reserved. 4 // Use of this source code is governed by a BSD-style 5 // license that can be found in the LICENSE file. 6 7 package cblas64 8 9 import ( 10 math "github.com/gopherd/gonum/internal/cmplx64" 11 "testing" 12 13 "github.com/gopherd/gonum/blas" 14 ) 15 16 func newGeneralFrom(a GeneralCols) General { 17 t := General{ 18 Rows: a.Rows, 19 Cols: a.Cols, 20 Stride: a.Cols, 21 Data: make([]complex64, a.Rows*a.Cols), 22 } 23 t.From(a) 24 return t 25 } 26 27 func (m General) dims() (r, c int) { return m.Rows, m.Cols } 28 func (m General) at(i, j int) complex64 { return m.Data[i*m.Stride+j] } 29 30 func newGeneralColsFrom(a General) GeneralCols { 31 t := GeneralCols{ 32 Rows: a.Rows, 33 Cols: a.Cols, 34 Stride: a.Rows, 35 Data: make([]complex64, a.Rows*a.Cols), 36 } 37 t.From(a) 38 return t 39 } 40 41 func (m GeneralCols) dims() (r, c int) { return m.Rows, m.Cols } 42 func (m GeneralCols) at(i, j int) complex64 { return m.Data[i+j*m.Stride] } 43 44 type general interface { 45 dims() (r, c int) 46 at(i, j int) complex64 47 } 48 49 func sameGeneral(a, b general) bool { 50 ar, ac := a.dims() 51 br, bc := b.dims() 52 if ar != br || ac != bc { 53 return false 54 } 55 for i := 0; i < ar; i++ { 56 for j := 0; j < ac; j++ { 57 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 58 return false 59 } 60 } 61 } 62 return true 63 } 64 65 var generalTests = []General{ 66 {Rows: 2, Cols: 3, Stride: 3, Data: []complex64{ 67 1, 2, 3, 68 4, 5, 6, 69 }}, 70 {Rows: 3, Cols: 2, Stride: 2, Data: []complex64{ 71 1, 2, 72 3, 4, 73 5, 6, 74 }}, 75 {Rows: 3, Cols: 3, Stride: 3, Data: []complex64{ 76 1, 2, 3, 77 4, 5, 6, 78 7, 8, 9, 79 }}, 80 {Rows: 2, Cols: 3, Stride: 5, Data: []complex64{ 81 1, 2, 3, 0, 0, 82 4, 5, 6, 0, 0, 83 }}, 84 {Rows: 3, Cols: 2, Stride: 5, Data: []complex64{ 85 1, 2, 0, 0, 0, 86 3, 4, 0, 0, 0, 87 5, 6, 0, 0, 0, 88 }}, 89 {Rows: 3, Cols: 3, Stride: 5, Data: []complex64{ 90 1, 2, 3, 0, 0, 91 4, 5, 6, 0, 0, 92 7, 8, 9, 0, 0, 93 }}, 94 } 95 96 func TestConvertGeneral(t *testing.T) { 97 for _, test := range generalTests { 98 colmajor := newGeneralColsFrom(test) 99 if !sameGeneral(colmajor, test) { 100 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 101 colmajor, test) 102 } 103 rowmajor := newGeneralFrom(colmajor) 104 if !sameGeneral(rowmajor, test) { 105 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 106 rowmajor, test) 107 } 108 } 109 } 110 111 func newTriangularFrom(a TriangularCols) Triangular { 112 t := Triangular{ 113 N: a.N, 114 Stride: a.N, 115 Data: make([]complex64, a.N*a.N), 116 Diag: a.Diag, 117 Uplo: a.Uplo, 118 } 119 t.From(a) 120 return t 121 } 122 123 func (m Triangular) n() int { return m.N } 124 func (m Triangular) at(i, j int) complex64 { 125 if m.Diag == blas.Unit && i == j { 126 return 1 127 } 128 if m.Uplo == blas.Lower && i < j && j < m.N { 129 return 0 130 } 131 if m.Uplo == blas.Upper && i > j { 132 return 0 133 } 134 return m.Data[i*m.Stride+j] 135 } 136 func (m Triangular) uplo() blas.Uplo { return m.Uplo } 137 func (m Triangular) diag() blas.Diag { return m.Diag } 138 139 func newTriangularColsFrom(a Triangular) TriangularCols { 140 t := TriangularCols{ 141 N: a.N, 142 Stride: a.N, 143 Data: make([]complex64, a.N*a.N), 144 Diag: a.Diag, 145 Uplo: a.Uplo, 146 } 147 t.From(a) 148 return t 149 } 150 151 func (m TriangularCols) n() int { return m.N } 152 func (m TriangularCols) at(i, j int) complex64 { 153 if m.Diag == blas.Unit && i == j { 154 return 1 155 } 156 if m.Uplo == blas.Lower && i < j { 157 return 0 158 } 159 if m.Uplo == blas.Upper && i > j && i < m.N { 160 return 0 161 } 162 return m.Data[i+j*m.Stride] 163 } 164 func (m TriangularCols) uplo() blas.Uplo { return m.Uplo } 165 func (m TriangularCols) diag() blas.Diag { return m.Diag } 166 167 type triangular interface { 168 n() int 169 at(i, j int) complex64 170 uplo() blas.Uplo 171 diag() blas.Diag 172 } 173 174 func sameTriangular(a, b triangular) bool { 175 an := a.n() 176 bn := b.n() 177 if an != bn { 178 return false 179 } 180 for i := 0; i < an; i++ { 181 for j := 0; j < an; j++ { 182 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 183 return false 184 } 185 } 186 } 187 return true 188 } 189 190 var triangularTests = []Triangular{ 191 {N: 3, Stride: 3, Data: []complex64{ 192 1, 2, 3, 193 4, 5, 6, 194 7, 8, 9, 195 }}, 196 {N: 3, Stride: 5, Data: []complex64{ 197 1, 2, 3, 0, 0, 198 4, 5, 6, 0, 0, 199 7, 8, 9, 0, 0, 200 }}, 201 } 202 203 func TestConvertTriangular(t *testing.T) { 204 for _, test := range triangularTests { 205 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower, blas.All} { 206 for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} { 207 test.Uplo = uplo 208 test.Diag = diag 209 colmajor := newTriangularColsFrom(test) 210 if !sameTriangular(colmajor, test) { 211 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 212 colmajor, test) 213 } 214 rowmajor := newTriangularFrom(colmajor) 215 if !sameTriangular(rowmajor, test) { 216 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 217 rowmajor, test) 218 } 219 } 220 } 221 } 222 } 223 224 func newBandFrom(a BandCols) Band { 225 t := Band{ 226 Rows: a.Rows, 227 Cols: a.Cols, 228 KL: a.KL, 229 KU: a.KU, 230 Stride: a.KL + a.KU + 1, 231 Data: make([]complex64, a.Rows*(a.KL+a.KU+1)), 232 } 233 for i := range t.Data { 234 t.Data[i] = math.NaN() 235 } 236 t.From(a) 237 return t 238 } 239 240 func (m Band) dims() (r, c int) { return m.Rows, m.Cols } 241 func (m Band) at(i, j int) complex64 { 242 pj := j + m.KL - i 243 if pj < 0 || m.KL+m.KU+1 <= pj { 244 return 0 245 } 246 return m.Data[i*m.Stride+pj] 247 } 248 func (m Band) bandwidth() (kl, ku int) { return m.KL, m.KU } 249 250 func newBandColsFrom(a Band) BandCols { 251 t := BandCols{ 252 Rows: a.Rows, 253 Cols: a.Cols, 254 KL: a.KL, 255 KU: a.KU, 256 Stride: a.KL + a.KU + 1, 257 Data: make([]complex64, a.Cols*(a.KL+a.KU+1)), 258 } 259 for i := range t.Data { 260 t.Data[i] = math.NaN() 261 } 262 t.From(a) 263 return t 264 } 265 266 func (m BandCols) dims() (r, c int) { return m.Rows, m.Cols } 267 func (m BandCols) at(i, j int) complex64 { 268 pj := i + m.KU - j 269 if pj < 0 || m.KL+m.KU+1 <= pj { 270 return 0 271 } 272 return m.Data[j*m.Stride+pj] 273 } 274 func (m BandCols) bandwidth() (kl, ku int) { return m.KL, m.KU } 275 276 type band interface { 277 dims() (r, c int) 278 at(i, j int) complex64 279 bandwidth() (kl, ku int) 280 } 281 282 func sameBand(a, b band) bool { 283 ar, ac := a.dims() 284 br, bc := b.dims() 285 if ar != br || ac != bc { 286 return false 287 } 288 akl, aku := a.bandwidth() 289 bkl, bku := b.bandwidth() 290 if akl != bkl || aku != bku { 291 return false 292 } 293 for i := 0; i < ar; i++ { 294 for j := 0; j < ac; j++ { 295 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 296 return false 297 } 298 } 299 } 300 return true 301 } 302 303 var bandTests = []Band{ 304 {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 1, Data: []complex64{ 305 1, 306 2, 307 3, 308 }}, 309 {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []complex64{ 310 1, 311 2, 312 3, 313 }}, 314 {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []complex64{ 315 1, 316 2, 317 3, 318 }}, 319 {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 2, Data: []complex64{ 320 1, 2, 321 3, 4, 322 5, 6, 323 }}, 324 {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 2, Data: []complex64{ 325 1, 2, 326 3, 4, 327 5, 6, 328 }}, 329 {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 3, Data: []complex64{ 330 -1, 2, 3, 331 4, 5, 6, 332 7, 8, 9, 333 }}, 334 {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 3, Data: []complex64{ 335 -1, 2, 3, 336 4, 5, 6, 337 7, 8, -2, 338 9, -3, -4, 339 }}, 340 {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 4, Data: []complex64{ 341 -2, -1, 3, 4, 342 -3, 5, 6, 7, 343 8, 9, 10, 11, 344 }}, 345 {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 4, Data: []complex64{ 346 -2, -1, 2, 3, 347 -3, 4, 5, 6, 348 7, 8, 9, -4, 349 10, 11, -5, -6, 350 }}, 351 352 {Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 5, Data: []complex64{ 353 1, 0, 0, 0, 0, 354 2, 0, 0, 0, 0, 355 3, 0, 0, 0, 0, 356 }}, 357 {Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []complex64{ 358 1, 0, 0, 0, 0, 359 2, 0, 0, 0, 0, 360 3, 0, 0, 0, 0, 361 }}, 362 {Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []complex64{ 363 1, 0, 0, 0, 0, 364 2, 0, 0, 0, 0, 365 3, 0, 0, 0, 0, 366 }}, 367 {Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 5, Data: []complex64{ 368 1, 2, 0, 0, 0, 369 3, 4, 0, 0, 0, 370 5, 6, 0, 0, 0, 371 }}, 372 {Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 5, Data: []complex64{ 373 1, 2, 0, 0, 0, 374 3, 4, 0, 0, 0, 375 5, 6, 0, 0, 0, 376 }}, 377 {Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 5, Data: []complex64{ 378 -1, 2, 3, 0, 0, 379 4, 5, 6, 0, 0, 380 7, 8, 9, 0, 0, 381 }}, 382 {Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 5, Data: []complex64{ 383 -1, 2, 3, 0, 0, 384 4, 5, 6, 0, 0, 385 7, 8, -2, 0, 0, 386 9, -3, -4, 0, 0, 387 }}, 388 {Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 5, Data: []complex64{ 389 -2, -1, 3, 4, 0, 390 -3, 5, 6, 7, 0, 391 8, 9, 10, 11, 0, 392 }}, 393 {Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 5, Data: []complex64{ 394 -2, -1, 2, 3, 0, 395 -3, 4, 5, 6, 0, 396 7, 8, 9, -4, 0, 397 10, 11, -5, -6, 0, 398 }}, 399 } 400 401 func TestConvertBand(t *testing.T) { 402 for _, test := range bandTests { 403 colmajor := newBandColsFrom(test) 404 if !sameBand(colmajor, test) { 405 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 406 colmajor, test) 407 } 408 rowmajor := newBandFrom(colmajor) 409 if !sameBand(rowmajor, test) { 410 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 411 rowmajor, test) 412 } 413 } 414 } 415 416 func newTriangularBandFrom(a TriangularBandCols) TriangularBand { 417 t := TriangularBand{ 418 N: a.N, 419 K: a.K, 420 Stride: a.K + 1, 421 Data: make([]complex64, a.N*(a.K+1)), 422 Uplo: a.Uplo, 423 Diag: a.Diag, 424 } 425 for i := range t.Data { 426 t.Data[i] = math.NaN() 427 } 428 t.From(a) 429 return t 430 } 431 432 func (m TriangularBand) n() (n int) { return m.N } 433 func (m TriangularBand) at(i, j int) complex64 { 434 if m.Diag == blas.Unit && i == j { 435 return 1 436 } 437 b := Band{ 438 Rows: m.N, Cols: m.N, 439 Stride: m.Stride, 440 Data: m.Data, 441 } 442 switch m.Uplo { 443 default: 444 panic("cblas64: bad BLAS uplo") 445 case blas.Upper: 446 if i > j { 447 return 0 448 } 449 b.KU = m.K 450 case blas.Lower: 451 if i < j { 452 return 0 453 } 454 b.KL = m.K 455 } 456 return b.at(i, j) 457 } 458 func (m TriangularBand) bandwidth() (k int) { return m.K } 459 func (m TriangularBand) uplo() blas.Uplo { return m.Uplo } 460 func (m TriangularBand) diag() blas.Diag { return m.Diag } 461 462 func newTriangularBandColsFrom(a TriangularBand) TriangularBandCols { 463 t := TriangularBandCols{ 464 N: a.N, 465 K: a.K, 466 Stride: a.K + 1, 467 Data: make([]complex64, a.N*(a.K+1)), 468 Uplo: a.Uplo, 469 Diag: a.Diag, 470 } 471 for i := range t.Data { 472 t.Data[i] = math.NaN() 473 } 474 t.From(a) 475 return t 476 } 477 478 func (m TriangularBandCols) n() (n int) { return m.N } 479 func (m TriangularBandCols) at(i, j int) complex64 { 480 if m.Diag == blas.Unit && i == j { 481 return 1 482 } 483 b := BandCols{ 484 Rows: m.N, Cols: m.N, 485 Stride: m.Stride, 486 Data: m.Data, 487 } 488 switch m.Uplo { 489 default: 490 panic("cblas64: bad BLAS uplo") 491 case blas.Upper: 492 if i > j { 493 return 0 494 } 495 b.KU = m.K 496 case blas.Lower: 497 if i < j { 498 return 0 499 } 500 b.KL = m.K 501 } 502 return b.at(i, j) 503 } 504 func (m TriangularBandCols) bandwidth() (k int) { return m.K } 505 func (m TriangularBandCols) uplo() blas.Uplo { return m.Uplo } 506 func (m TriangularBandCols) diag() blas.Diag { return m.Diag } 507 508 type triangularBand interface { 509 n() (n int) 510 at(i, j int) complex64 511 bandwidth() (k int) 512 uplo() blas.Uplo 513 diag() blas.Diag 514 } 515 516 func sameTriangularBand(a, b triangularBand) bool { 517 an := a.n() 518 bn := b.n() 519 if an != bn { 520 return false 521 } 522 if a.uplo() != b.uplo() { 523 return false 524 } 525 if a.diag() != b.diag() { 526 return false 527 } 528 ak := a.bandwidth() 529 bk := b.bandwidth() 530 if ak != bk { 531 return false 532 } 533 for i := 0; i < an; i++ { 534 for j := 0; j < an; j++ { 535 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 536 return false 537 } 538 } 539 } 540 return true 541 } 542 543 var triangularBandTests = []TriangularBand{ 544 {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []complex64{ 545 1, 546 2, 547 3, 548 }}, 549 {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []complex64{ 550 1, 551 2, 552 3, 553 }}, 554 {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []complex64{ 555 1, 2, 556 3, 4, 557 5, -1, 558 }}, 559 {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []complex64{ 560 -1, 1, 561 2, 3, 562 4, 5, 563 }}, 564 {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []complex64{ 565 1, 2, 3, 566 4, 5, -1, 567 6, -2, -3, 568 }}, 569 {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []complex64{ 570 -2, -1, 1, 571 -3, 2, 4, 572 3, 5, 6, 573 }}, 574 575 {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []complex64{ 576 1, 0, 0, 0, 0, 577 2, 0, 0, 0, 0, 578 3, 0, 0, 0, 0, 579 }}, 580 {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []complex64{ 581 1, 0, 0, 0, 0, 582 2, 0, 0, 0, 0, 583 3, 0, 0, 0, 0, 584 }}, 585 {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []complex64{ 586 1, 2, 0, 0, 0, 587 3, 4, 0, 0, 0, 588 5, -1, 0, 0, 0, 589 }}, 590 {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []complex64{ 591 -1, 1, 0, 0, 0, 592 2, 3, 0, 0, 0, 593 4, 5, 0, 0, 0, 594 }}, 595 {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []complex64{ 596 1, 2, 3, 0, 0, 597 4, 5, -1, 0, 0, 598 6, -2, -3, 0, 0, 599 }}, 600 {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []complex64{ 601 -2, -1, 1, 0, 0, 602 -3, 2, 4, 0, 0, 603 3, 5, 6, 0, 0, 604 }}, 605 } 606 607 func TestConvertTriBand(t *testing.T) { 608 for _, test := range triangularBandTests { 609 colmajor := newTriangularBandColsFrom(test) 610 if !sameTriangularBand(colmajor, test) { 611 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 612 colmajor, test) 613 } 614 rowmajor := newTriangularBandFrom(colmajor) 615 if !sameTriangularBand(rowmajor, test) { 616 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 617 rowmajor, test) 618 } 619 } 620 }