// Source: github.com/gopherd/gonum@v0.0.4/blas/gonum/level3cmplx128.go

// Copyright ©2019 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"math/cmplx"

	"github.com/gopherd/gonum/blas"
	"github.com/gopherd/gonum/internal/asm/c128"
)

// Compile-time check that Implementation satisfies blas.Complex128Level3.
var _ blas.Complex128Level3 = Implementation{}

// Zgemm performs one of the matrix-matrix operations
//
//	C = alpha * op(A) * op(B) + beta * C
//
// where op(X) is one of
//
//	op(X) = X  or  op(X) = Xᵀ  or  op(X) = Xᴴ,
//
// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
// op(B) a k×n matrix and C an m×n matrix.
//
// Matrices are stored row-major: element (i, j) of C is c[i*ldc+j], and A and
// B are indexed the same way with lda and ldb. C is overwritten with the
// result. Zgemm panics on an invalid tA or tB, a negative dimension, a leading
// dimension smaller than the stored row length, or a slice that is too short.
func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch {
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	}
	// rowA×colA and rowB×colB are the shapes of A and B as stored, i.e.
	// before op is applied.
	rowA, colA := m, k
	if tA != blas.NoTrans {
		rowA, colA = k, m
	}
	if lda < max(1, colA) {
		panic(badLdA)
	}
	rowB, colB := k, n
	if tB != blas.NoTrans {
		rowB, colB = n, k
	}
	if ldb < max(1, colB) {
		panic(badLdB)
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (rowA-1)*lda+colA {
		panic(shortA)
	}
	if len(b) < (rowB-1)*ldb+colB {
		panic(shortB)
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only the beta*C term remains. When beta is also zero C is
	// assigned zero rather than multiplied, so prior contents of C (including
	// NaN or Inf) do not affect the result.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					c[i*ldc+j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					c[i*ldc+j] *= beta
				}
			}
		}
		return
	}

	switch tA {
	case blas.NoTrans:
		// A is read row-wise here: each row i of C is first scaled by beta
		// (or cleared), then accumulates alpha*A[i,l] times row/column l of B.
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * A * B + beta * C.
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * b[l*ldb+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * A * Bᵀ + beta * C.
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * b[j*ldb+l]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * A * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				switch {
				case beta == 0:
					for j := 0; j < n; j++ {
						c[i*ldc+j] = 0
					}
				case beta != 1:
					for j := 0; j < n; j++ {
						c[i*ldc+j] *= beta
					}
				}
				for l := 0; l < k; l++ {
					tmp := alpha * a[i*lda+l]
					for j := 0; j < n; j++ {
						c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
					}
				}
			}
		}
	case blas.Trans:
		// A is stored k×m, so each C[i,j] is computed as a single dot product
		// over l and then combined with beta*C[i,j].
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * Aᵀ * B + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * b[l*ldb+j]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * Aᵀ * Bᵀ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * b[j*ldb+l]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * Aᵀ * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	case blas.ConjTrans:
		// Same dot-product scheme as the Trans case, with A conjugated.
		switch tB {
		case blas.NoTrans:
			// Form C = alpha * Aᴴ * B + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.Trans:
			// Form C = alpha * Aᴴ * Bᵀ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		case blas.ConjTrans:
			// Form C = alpha * Aᴴ * Bᴴ + beta * C.
			for i := 0; i < m; i++ {
				for j := 0; j < n; j++ {
					var tmp complex128
					for l := 0; l < k; l++ {
						tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
					}
					if beta == 0 {
						c[i*ldc+j] = alpha * tmp
					} else {
						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}

// Zhemm performs one of the matrix-matrix operations
//
//	C = alpha*A*B + beta*C  if side == blas.Left
//	C = alpha*B*A + beta*C  if side == blas.Right
//
// where alpha and beta are scalars, A is an m×m or n×n hermitian matrix and B
// and C are m×n matrices. The imaginary parts of the diagonal elements of A are
// assumed to be zero.
//
// Only the triangle of A selected by uplo is read; entries of the other
// triangle are obtained by conjugating their mirror image. The imaginary parts
// of A's diagonal are ignored. C is overwritten with the result.
func (Implementation) Zhemm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of A: m when A multiplies from the left, n otherwise.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible.
	if alpha == 0 && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains; beta == 0 clears C outright.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form C = alpha*A*B + beta*C.
		// Row i of C starts from the diagonal term alpha*real(A[i,i])*B[i,:]
		// (plus beta*C[i,:]); every off-diagonal A[i,k] then adds a scaled
		// copy of row k of B. c128.AxpyUnitary is assumed to compute y += a*x.
		for i := 0; i < m; i++ {
			atmp := alpha * complex(real(a[i*lda+i]), 0)
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			if uplo == blas.Upper {
				// A[i,k] for k < i lives below the diagonal: read A[k,i] and conjugate.
				for k := 0; k < i; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// A[i,k] for k > i lives above the diagonal: read A[k,i] and conjugate.
				for k := i + 1; k < m; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form C = alpha*B*A + beta*C.
		// For each B[i,j], row j of the stored triangle of A is used twice:
		// scattered into the already-finished entries of C's row i, and
		// gathered (conjugated) into tmp for C[i,j] itself.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Columns go right to left so that ci (columns > j) has already
				// been initialized with beta before it is accumulated into.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Columns go left to right so that ci (columns < j) has already
				// been initialized with beta before it is accumulated into.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}

// Zherk performs one of the hermitian rank-k operations
//
//	C = alpha*A*Aᴴ + beta*C  if trans == blas.NoTrans
//	C = alpha*Aᴴ*A + beta*C  if trans == blas.ConjTrans
//
// where alpha and beta are real scalars, C is an n×n hermitian matrix and A is
// an n×k matrix in the first case and a k×n matrix in the second case.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero.
//
// Only the triangle of C selected by uplo is read and written.
func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) {
	// rowA×colA is the shape of A as stored.
	var rowA, colA int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		rowA, colA = n, k
	case blas.ConjTrans:
		rowA, colA = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, colA):
		panic(badLdA)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (rowA-1)*lda+colA {
		panic(shortA)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains. Diagonal entries are rebuilt from
	// their real parts so that their imaginary parts end up zero.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					ci[0] = complex(beta*real(ci[0]), 0)
					if i != n-1 {
						c128.DscalUnitary(beta, ci[1:])
					}
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					if i != 0 {
						c128.DscalUnitary(beta, ci[:i])
					}
					ci[i] = complex(beta*real(ci[i]), 0)
				}
			}
		}
		return
	}

	calpha := complex(alpha, 0)
	if trans == blas.NoTrans {
		// Form C = alpha*A*Aᴴ + beta*C.
		// Each stored C[i,j] is a dot product of rows i and j of A.
		cbeta := complex(beta, 0)
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci[0] is the diagonal C[i,i]; ci[jc] is C[i, i+jc].
				ci := c[i*ldc+i : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				switch {
				case beta == 0:
					// Handle the i-th diagonal element of C.
					ci[0] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
					// Handle the remaining elements on the i-th row of C.
					for jc := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
					}
				case beta != 1:
					cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[0]
					ci[0] = complex(real(cii), 0)
					for jc, cij := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
					}
				default:
					cii := calpha*c128.DotcUnitary(ai, ai) + ci[0]
					ci[0] = complex(real(cii), 0)
					for jc, cij := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				// ci[j] is C[i,j] for j <= i; ci[i] is the diagonal.
				ci := c[i*ldc : i*ldc+i+1]
				ai := a[i*lda : i*lda+k]
				switch {
				case beta == 0:
					// Handle the first i elements (columns 0 to i-1) on the i-th row of C.
					for j := range ci[:i] {
						ci[j] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
					}
					// Handle the i-th diagonal element of C.
					ci[i] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
				case beta != 1:
					for j, cij := range ci[:i] {
						ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
					}
					cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[i]
					ci[i] = complex(real(cii), 0)
				default:
					for j, cij := range ci[:i] {
						ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
					}
					cii := calpha*c128.DotcUnitary(ai, ai) + ci[i]
					ci[i] = complex(real(cii), 0)
				}
			}
		}
	} else {
		// Form C = alpha*Aᴴ*A + beta*C.
		// Row i of C is first scaled by beta (or cleared), then for every row
		// j of A the stored part of that row of A, times alpha*conj(A[j,i]),
		// is added to it. The diagonal's imaginary part is cleared before and
		// after the accumulation.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[0] = complex(real(ci[0]), 0)
				default:
					ci[0] = complex(real(ci[0]), 0)
				}
				for j := 0; j < k; j++ {
					aji := cmplx.Conj(a[j*lda+i])
					if aji != 0 {
						c128.AxpyUnitary(calpha*aji, a[j*lda+i:j*lda+n], ci)
					}
				}
				c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[i] = complex(real(ci[i]), 0)
				default:
					ci[i] = complex(real(ci[i]), 0)
				}
				for j := 0; j < k; j++ {
					aji := cmplx.Conj(a[j*lda+i])
					if aji != 0 {
						c128.AxpyUnitary(calpha*aji, a[j*lda:j*lda+i+1], ci)
					}
				}
				c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
			}
		}
	}
}

// Zher2k performs one of the hermitian rank-2k operations
//
//	C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C  if trans == blas.NoTrans
//	C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C  if trans == blas.ConjTrans
//
// where alpha and beta are scalars with beta real, C is an n×n hermitian matrix
// and A and B are n×k matrices in the first case and k×n matrices in the second case.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero.
//
// Only the triangle of C selected by uplo is read and written.
func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
	// row×col is the shape of A and B as stored.
	var row, col int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		row, col = n, k
	case blas.ConjTrans:
		row, col = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, col):
		panic(badLdA)
	case ldb < max(1, col):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (row-1)*lda+col {
		panic(shortA)
	}
	if len(b) < (row-1)*ldb+col {
		panic(shortB)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains. Diagonal entries are rebuilt from
	// their real parts so that their imaginary parts end up zero.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					ci[0] = complex(beta*real(ci[0]), 0)
					if i != n-1 {
						c128.DscalUnitary(beta, ci[1:])
					}
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					if i != 0 {
						c128.DscalUnitary(beta, ci[:i])
					}
					ci[i] = complex(beta*real(ci[i]), 0)
				}
			}
		}
		return
	}

	conjalpha := cmplx.Conj(alpha)
	cbeta := complex(beta, 0)
	if trans == blas.NoTrans {
		// Form C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C.
		// Each stored C[i,j] combines two dot products of rows of A and B; the
		// diagonal keeps only the real part of its sum.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci excludes the diagonal: ci[jc] is C[i, i+1+jc].
				ci := c[i*ldc+i+1 : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					c[i*ldc+i] = complex(real(cii), 0)
					for jc := range ci {
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
				} else {
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					c[i*ldc+i] = complex(real(cii), 0)
					for jc, cij := range ci {
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				// ci excludes the diagonal: ci[j] is C[i,j] for j < i.
				ci := c[i*ldc : i*ldc+i]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					for j := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					c[i*ldc+i] = complex(real(cii), 0)
				} else {
					for j, cij := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					c[i*ldc+i] = complex(real(cii), 0)
				}
			}
		}
	} else {
		// Form C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C.
		// Row i of C is scaled by beta (or cleared), then each row j of A and
		// B contributes two scaled row updates; zero coefficients are skipped.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[0] = complex(real(ci[0]), 0)
				default:
					ci[0] = complex(real(ci[0]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci)
					}
				}
				ci[0] = complex(real(ci[0]), 0)
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[i] = complex(real(ci[i]), 0)
				default:
					ci[i] = complex(real(ci[i]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci)
					}
				}
				ci[i] = complex(real(ci[i]), 0)
			}
		}
	}
}

// Zsymm performs one of the matrix-matrix operations
//
//	C = alpha*A*B + beta*C  if side == blas.Left
//	C = alpha*B*A + beta*C  if side == blas.Right
//
// where alpha and beta are scalars, A is an m×m or n×n symmetric matrix and B
// and C are m×n matrices.
//
// Only the triangle of A selected by uplo is read; the other triangle is taken
// from its mirror image (no conjugation, unlike Zhemm). C is overwritten.
func (Implementation) Zsymm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of A: m when A multiplies from the left, n otherwise.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible.
	if alpha == 0 && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains; beta == 0 clears C outright.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form C = alpha*A*B + beta*C.
		// Row i of C starts from alpha*A[i,i]*B[i,:] (plus beta*C[i,:]); every
		// off-diagonal A[i,k] then adds a scaled copy of row k of B.
		for i := 0; i < m; i++ {
			atmp := alpha * a[i*lda+i]
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			if uplo == blas.Upper {
				// A[i,k] for k < i is read from its stored mirror A[k,i].
				for k := 0; k < i; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// A[i,k] for k > i is read from its stored mirror A[k,i].
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form C = alpha*B*A + beta*C.
		// Same scatter/gather scheme as Zhemm's right-side case, without
		// conjugation and with the full (complex) diagonal of A.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Right to left, so ci (columns > j) already holds its beta term.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Left to right, so ci (columns < j) already holds its beta term.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}

// Zsyrk performs one of the symmetric rank-k operations
//
//	C = alpha*A*Aᵀ + beta*C  if trans == blas.NoTrans
//	C = alpha*Aᵀ*A + beta*C  if trans == blas.Trans
//
// where alpha and beta are scalars, C is an n×n symmetric matrix and A is
// an n×k matrix in the first case and a k×n matrix in the second case.
//
// Only the triangle of C selected by uplo is read and written.
func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) {
	// rowA×colA is the shape of A as stored.
	var rowA, colA int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		rowA, colA = n, k
	case blas.Trans:
		rowA, colA = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, colA):
		panic(badLdA)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (rowA-1)*lda+colA {
		panic(shortA)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains, applied to the stored triangle.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					c128.ScalUnitary(beta, ci)
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					c128.ScalUnitary(beta, ci)
				}
			}
		}
		return
	}

	if trans == blas.NoTrans {
		// Form C = alpha*A*Aᵀ + beta*C.
		// Each stored C[i,j] is an unconjugated dot product of rows i and j of A.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci[jc] is C[i, i+jc], starting at the diagonal.
				ci := c[i*ldc+i : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				if beta == 0 {
					for jc := range ci {
						j := i + jc
						ci[jc] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
					}
				} else {
					for jc, cij := range ci {
						j := i + jc
						ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				// ci[j] is C[i,j] for j <= i.
				ci := c[i*ldc : i*ldc+i+1]
				ai := a[i*lda : i*lda+k]
				if beta == 0 {
					for j := range ci {
						ci[j] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
					}
				} else {
					for j, cij := range ci {
						ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
					}
				}
			}
		}
	} else {
		// Form C = alpha*Aᵀ*A + beta*C.
		// Row i of C is scaled by beta (or cleared), then each row j of A adds
		// alpha*A[j,i] times the stored part of that row; zero A[j,i] is skipped.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					for jc := range ci {
						ci[jc] *= beta
					}
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci)
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					for j := range ci {
						ci[j] *= beta
					}
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci)
					}
				}
			}
		}
	}
}

// Zsyr2k performs one of the symmetric rank-2k operations
//
//	C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C  if trans == blas.NoTrans
//	C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C  if trans == blas.Trans
//
// where alpha and beta are scalars, C is an n×n symmetric matrix and A and B
// are n×k matrices in the first case and k×n matrices in the second case.
//
// Only the triangle of C selected by uplo is read and written.
func (Implementation) Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// row×col is the shape of A and B as stored.
	var row, col int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		row, col = n, k
	case blas.Trans:
		row, col = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, col):
		panic(badLdA)
	case ldb < max(1, col):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (row-1)*lda+col {
		panic(shortA)
	}
	if len(b) < (row-1)*ldb+col {
		panic(shortB)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha == 0 only beta*C remains, applied to the stored triangle.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					c128.ScalUnitary(beta, ci)
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					c128.ScalUnitary(beta, ci)
				}
			}
		}
		return
	}

	if trans == blas.NoTrans {
		// Form C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C.
		// Each stored C[i,j] is the sum of two unconjugated row dot products.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					for jc := range ci {
						j := i + jc
						ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
					}
				} else {
					for jc, cij := range ci {
						j := i + jc
						ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					for j := range ci {
						ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
					}
				} else {
					for j, cij := range ci {
						ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
					}
				}
			}
		}
	} else {
		// Form C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C.
		// Row i of C is scaled by beta (or cleared), then each row j of A and
		// B contributes two scaled row updates; zero coefficients are skipped.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					for jc := range ci {
						ci[jc] *= beta
					}
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*aji, b[j*ldb+i:j*ldb+n], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(alpha*bji, a[j*lda+i:j*lda+n], ci)
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					for j := range ci {
						ci[j] *= beta
					}
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*aji, b[j*ldb:j*ldb+i+1], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(alpha*bji, a[j*lda:j*lda+i+1], ci)
					}
				}
			}
		}
	}
}

// Ztrmm performs one of the matrix-matrix operations
//
//	B = alpha * op(A) * B  if side == blas.Left,
//	B = alpha * B * op(A)  if side == blas.Right,
//
// where alpha is a scalar, B is an m×n matrix, A is a unit, or non-unit,
// upper or lower triangular matrix and op(A) is one of
//
//	op(A) = A   if trans == blas.NoTrans,
//	op(A) = Aᵀ  if trans == blas.Trans,
//	op(A) = Aᴴ  if trans == blas.ConjTrans.
//
// B is overwritten with the result. When diag == blas.Unit the diagonal
// elements of A are not read and are treated as one.
func (Implementation) Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of A: m when A multiplies from the left, n otherwise.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}

	// Quick return if possible: with alpha == 0 the result is the zero
	// matrix, so B is cleared without reading A.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			bi := b[i*ldb : i*ldb+n]
			for j := range bi {
				bi[j] = 0
			}
		}
		return
	}

	// noConj is true for NoTrans and Trans; noUnit says whether A's diagonal
	// must be read.
	noConj := trans != blas.ConjTrans
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if trans == blas.NoTrans {
			// Form B = alpha*A*B.
			// B is updated in place, so rows are visited in the order that
			// keeps the rows still to be read unmodified.
			if uplo == blas.Upper {
				// Top to bottom: row i reads only rows j > i, not yet updated.
				for i := 0; i < m; i++ {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						j := ja + i + 1
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			} else {
				// Bottom to top: row i reads only rows j < i, not yet updated.
				for i := m - 1; i >= 0; i-- {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			}
		} else {
			// Form B = alpha*Aᵀ*B or B = alpha*Aᴴ*B.
			// Row k of A is used as column k of op(A): the still-original row
			// k of B is scattered into the other rows it contributes to, and
			// only then is row k itself scaled by alpha*op(A)[k,k].
			if uplo == blas.Upper {
				for k := m - 1; k >= 0; k-- {
					bk := b[k*ldb : k*ldb+n]
					for ja, ajk := range a[k*lda+k+1 : k*lda+m] {
						if ajk == 0 {
							continue
						}
						j := k + 1 + ja
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			} else {
				for k := 0; k < m; k++ {
					bk := b[k*ldb : k*ldb+n]
					for j, ajk := range a[k*lda : k*lda+k] {
						if ajk == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			}
		}
	} else {
		if trans == blas.NoTrans {
			// Form B = alpha*B*A.
			// Within each row of B, bi[k] is read before any update touches
			// it, then scattered along row k of A into the other columns.
			// A zero alpha*bi[k] contributes nothing and is skipped.
			if uplo == blas.Upper {
				// Right to left: updates go to columns > k only.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := n - 1; k >= 0; k-- {
						abik := alpha * bi[k]
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda+k+1:k*lda+n], bi[k+1:])
					}
				}
			} else {
				// Left to right: updates go to columns < k only.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := 0; k < n; k++ {
						abik := alpha * bi[k]
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda:k*lda+k], bi[:k])
					}
				}
			}
		} else {
			// Form B = alpha*B*Aᵀ or B = alpha*B*Aᴴ.
			// Each new B[i,j] is a dot product of row j of A with the part of
			// row i of B that has not been overwritten yet.
			if uplo == blas.Upper {
				// Left to right: bi[j+1:] still holds original values.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						}
						bi[j] = alpha * bij
					}
				}
			} else {
				// Right to left: bi[:j] still holds original values.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := bi[j]
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							bij += c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
						}
						bi[j] = alpha * bij
					}
				}
			}
		}
	}
}

// Ztrsm solves one of the matrix equations
//
//	op(A) * X = alpha * B  if side == blas.Left,
//	X * op(A) = alpha * B  if side == blas.Right,
//
// where alpha is a scalar, X and B are m×n matrices, A is a unit or
// non-unit, upper or lower triangular matrix and op(A) is one of
//
//	op(A) = A   if transA == blas.NoTrans,
//	op(A) = Aᵀ  if transA == blas.Trans,
//	op(A) = Aᴴ  if
transA == blas.ConjTrans. 1512 // On return the matrix X is overwritten on B. 1513 func (Implementation) Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) { 1514 na := m 1515 if side == blas.Right { 1516 na = n 1517 } 1518 switch { 1519 case side != blas.Left && side != blas.Right: 1520 panic(badSide) 1521 case uplo != blas.Lower && uplo != blas.Upper: 1522 panic(badUplo) 1523 case transA != blas.NoTrans && transA != blas.Trans && transA != blas.ConjTrans: 1524 panic(badTranspose) 1525 case diag != blas.Unit && diag != blas.NonUnit: 1526 panic(badDiag) 1527 case m < 0: 1528 panic(mLT0) 1529 case n < 0: 1530 panic(nLT0) 1531 case lda < max(1, na): 1532 panic(badLdA) 1533 case ldb < max(1, n): 1534 panic(badLdB) 1535 } 1536 1537 // Quick return if possible. 1538 if m == 0 || n == 0 { 1539 return 1540 } 1541 1542 // For zero matrix size the following slice length checks are trivially satisfied. 1543 if len(a) < (na-1)*lda+na { 1544 panic(shortA) 1545 } 1546 if len(b) < (m-1)*ldb+n { 1547 panic(shortB) 1548 } 1549 1550 if alpha == 0 { 1551 for i := 0; i < m; i++ { 1552 for j := 0; j < n; j++ { 1553 b[i*ldb+j] = 0 1554 } 1555 } 1556 return 1557 } 1558 1559 noConj := transA != blas.ConjTrans 1560 noUnit := diag == blas.NonUnit 1561 if side == blas.Left { 1562 if transA == blas.NoTrans { 1563 // Form B = alpha*inv(A)*B. 
1564 if uplo == blas.Upper { 1565 for i := m - 1; i >= 0; i-- { 1566 bi := b[i*ldb : i*ldb+n] 1567 if alpha != 1 { 1568 c128.ScalUnitary(alpha, bi) 1569 } 1570 for ka, aik := range a[i*lda+i+1 : i*lda+m] { 1571 k := i + 1 + ka 1572 if aik != 0 { 1573 c128.AxpyUnitary(-aik, b[k*ldb:k*ldb+n], bi) 1574 } 1575 } 1576 if noUnit { 1577 c128.ScalUnitary(1/a[i*lda+i], bi) 1578 } 1579 } 1580 } else { 1581 for i := 0; i < m; i++ { 1582 bi := b[i*ldb : i*ldb+n] 1583 if alpha != 1 { 1584 c128.ScalUnitary(alpha, bi) 1585 } 1586 for j, aij := range a[i*lda : i*lda+i] { 1587 if aij != 0 { 1588 c128.AxpyUnitary(-aij, b[j*ldb:j*ldb+n], bi) 1589 } 1590 } 1591 if noUnit { 1592 c128.ScalUnitary(1/a[i*lda+i], bi) 1593 } 1594 } 1595 } 1596 } else { 1597 // Form B = alpha*inv(Aᵀ)*B or B = alpha*inv(Aᴴ)*B. 1598 if uplo == blas.Upper { 1599 for i := 0; i < m; i++ { 1600 bi := b[i*ldb : i*ldb+n] 1601 if noUnit { 1602 if noConj { 1603 c128.ScalUnitary(1/a[i*lda+i], bi) 1604 } else { 1605 c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi) 1606 } 1607 } 1608 for ja, aij := range a[i*lda+i+1 : i*lda+m] { 1609 if aij == 0 { 1610 continue 1611 } 1612 j := i + 1 + ja 1613 if noConj { 1614 c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n]) 1615 } else { 1616 c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n]) 1617 } 1618 } 1619 if alpha != 1 { 1620 c128.ScalUnitary(alpha, bi) 1621 } 1622 } 1623 } else { 1624 for i := m - 1; i >= 0; i-- { 1625 bi := b[i*ldb : i*ldb+n] 1626 if noUnit { 1627 if noConj { 1628 c128.ScalUnitary(1/a[i*lda+i], bi) 1629 } else { 1630 c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi) 1631 } 1632 } 1633 for j, aij := range a[i*lda : i*lda+i] { 1634 if aij == 0 { 1635 continue 1636 } 1637 if noConj { 1638 c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n]) 1639 } else { 1640 c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n]) 1641 } 1642 } 1643 if alpha != 1 { 1644 c128.ScalUnitary(alpha, bi) 1645 } 1646 } 1647 } 1648 } 1649 } else { 1650 if transA == blas.NoTrans { 1651 // Form B = 
alpha*B*inv(A). 1652 if uplo == blas.Upper { 1653 for i := 0; i < m; i++ { 1654 bi := b[i*ldb : i*ldb+n] 1655 if alpha != 1 { 1656 c128.ScalUnitary(alpha, bi) 1657 } 1658 for j, bij := range bi { 1659 if bij == 0 { 1660 continue 1661 } 1662 if noUnit { 1663 bi[j] /= a[j*lda+j] 1664 } 1665 c128.AxpyUnitary(-bi[j], a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1666 } 1667 } 1668 } else { 1669 for i := 0; i < m; i++ { 1670 bi := b[i*ldb : i*ldb+n] 1671 if alpha != 1 { 1672 c128.ScalUnitary(alpha, bi) 1673 } 1674 for j := n - 1; j >= 0; j-- { 1675 if bi[j] == 0 { 1676 continue 1677 } 1678 if noUnit { 1679 bi[j] /= a[j*lda+j] 1680 } 1681 c128.AxpyUnitary(-bi[j], a[j*lda:j*lda+j], bi[:j]) 1682 } 1683 } 1684 } 1685 } else { 1686 // Form B = alpha*B*inv(Aᵀ) or B = alpha*B*inv(Aᴴ). 1687 if uplo == blas.Upper { 1688 for i := 0; i < m; i++ { 1689 bi := b[i*ldb : i*ldb+n] 1690 for j := n - 1; j >= 0; j-- { 1691 bij := alpha * bi[j] 1692 if noConj { 1693 bij -= c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1694 if noUnit { 1695 bij /= a[j*lda+j] 1696 } 1697 } else { 1698 bij -= c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n]) 1699 if noUnit { 1700 bij /= cmplx.Conj(a[j*lda+j]) 1701 } 1702 } 1703 bi[j] = bij 1704 } 1705 } 1706 } else { 1707 for i := 0; i < m; i++ { 1708 bi := b[i*ldb : i*ldb+n] 1709 for j, bij := range bi { 1710 bij *= alpha 1711 if noConj { 1712 bij -= c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j]) 1713 if noUnit { 1714 bij /= a[j*lda+j] 1715 } 1716 } else { 1717 bij -= c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j]) 1718 if noUnit { 1719 bij /= cmplx.Conj(a[j*lda+j]) 1720 } 1721 } 1722 bi[j] = bij 1723 } 1724 } 1725 } 1726 } 1727 } 1728 }