github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/dense_arithmetic.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "github.com/jingcheng-WU/gonum/blas" 11 "github.com/jingcheng-WU/gonum/blas/blas64" 12 "github.com/jingcheng-WU/gonum/lapack/lapack64" 13 ) 14 15 // Add adds a and b element-wise, placing the result in the receiver. Add 16 // will panic if the two matrices do not have the same shape. 17 func (m *Dense) Add(a, b Matrix) { 18 ar, ac := a.Dims() 19 br, bc := b.Dims() 20 if ar != br || ac != bc { 21 panic(ErrShape) 22 } 23 24 aU, aTrans := untransposeExtract(a) 25 bU, bTrans := untransposeExtract(b) 26 m.reuseAsNonZeroed(ar, ac) 27 28 if arm, ok := a.(*Dense); ok { 29 if brm, ok := b.(*Dense); ok { 30 amat, bmat := arm.mat, brm.mat 31 if m != aU { 32 m.checkOverlap(amat) 33 } 34 if m != bU { 35 m.checkOverlap(bmat) 36 } 37 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride { 38 for i, v := range amat.Data[ja : ja+ac] { 39 m.mat.Data[i+jm] = v + bmat.Data[i+jb] 40 } 41 } 42 return 43 } 44 } 45 46 m.checkOverlapMatrix(aU) 47 m.checkOverlapMatrix(bU) 48 var restore func() 49 if aTrans && m == aU { 50 m, restore = m.isolatedWorkspace(aU) 51 defer restore() 52 } else if bTrans && m == bU { 53 m, restore = m.isolatedWorkspace(bU) 54 defer restore() 55 } 56 57 for r := 0; r < ar; r++ { 58 for c := 0; c < ac; c++ { 59 m.set(r, c, a.At(r, c)+b.At(r, c)) 60 } 61 } 62 } 63 64 // Sub subtracts the matrix b from a, placing the result in the receiver. Sub 65 // will panic if the two matrices do not have the same shape. 
66 func (m *Dense) Sub(a, b Matrix) { 67 ar, ac := a.Dims() 68 br, bc := b.Dims() 69 if ar != br || ac != bc { 70 panic(ErrShape) 71 } 72 73 aU, aTrans := untransposeExtract(a) 74 bU, bTrans := untransposeExtract(b) 75 m.reuseAsNonZeroed(ar, ac) 76 77 if arm, ok := a.(*Dense); ok { 78 if brm, ok := b.(*Dense); ok { 79 amat, bmat := arm.mat, brm.mat 80 if m != aU { 81 m.checkOverlap(amat) 82 } 83 if m != bU { 84 m.checkOverlap(bmat) 85 } 86 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride { 87 for i, v := range amat.Data[ja : ja+ac] { 88 m.mat.Data[i+jm] = v - bmat.Data[i+jb] 89 } 90 } 91 return 92 } 93 } 94 95 m.checkOverlapMatrix(aU) 96 m.checkOverlapMatrix(bU) 97 var restore func() 98 if aTrans && m == aU { 99 m, restore = m.isolatedWorkspace(aU) 100 defer restore() 101 } else if bTrans && m == bU { 102 m, restore = m.isolatedWorkspace(bU) 103 defer restore() 104 } 105 106 for r := 0; r < ar; r++ { 107 for c := 0; c < ac; c++ { 108 m.set(r, c, a.At(r, c)-b.At(r, c)) 109 } 110 } 111 } 112 113 // MulElem performs element-wise multiplication of a and b, placing the result 114 // in the receiver. MulElem will panic if the two matrices do not have the same 115 // shape. 
116 func (m *Dense) MulElem(a, b Matrix) { 117 ar, ac := a.Dims() 118 br, bc := b.Dims() 119 if ar != br || ac != bc { 120 panic(ErrShape) 121 } 122 123 aU, aTrans := untransposeExtract(a) 124 bU, bTrans := untransposeExtract(b) 125 m.reuseAsNonZeroed(ar, ac) 126 127 if arm, ok := a.(*Dense); ok { 128 if brm, ok := b.(*Dense); ok { 129 amat, bmat := arm.mat, brm.mat 130 if m != aU { 131 m.checkOverlap(amat) 132 } 133 if m != bU { 134 m.checkOverlap(bmat) 135 } 136 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride { 137 for i, v := range amat.Data[ja : ja+ac] { 138 m.mat.Data[i+jm] = v * bmat.Data[i+jb] 139 } 140 } 141 return 142 } 143 } 144 145 m.checkOverlapMatrix(aU) 146 m.checkOverlapMatrix(bU) 147 var restore func() 148 if aTrans && m == aU { 149 m, restore = m.isolatedWorkspace(aU) 150 defer restore() 151 } else if bTrans && m == bU { 152 m, restore = m.isolatedWorkspace(bU) 153 defer restore() 154 } 155 156 for r := 0; r < ar; r++ { 157 for c := 0; c < ac; c++ { 158 m.set(r, c, a.At(r, c)*b.At(r, c)) 159 } 160 } 161 } 162 163 // DivElem performs element-wise division of a by b, placing the result 164 // in the receiver. DivElem will panic if the two matrices do not have the same 165 // shape. 
166 func (m *Dense) DivElem(a, b Matrix) { 167 ar, ac := a.Dims() 168 br, bc := b.Dims() 169 if ar != br || ac != bc { 170 panic(ErrShape) 171 } 172 173 aU, aTrans := untransposeExtract(a) 174 bU, bTrans := untransposeExtract(b) 175 m.reuseAsNonZeroed(ar, ac) 176 177 if arm, ok := a.(*Dense); ok { 178 if brm, ok := b.(*Dense); ok { 179 amat, bmat := arm.mat, brm.mat 180 if m != aU { 181 m.checkOverlap(amat) 182 } 183 if m != bU { 184 m.checkOverlap(bmat) 185 } 186 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride { 187 for i, v := range amat.Data[ja : ja+ac] { 188 m.mat.Data[i+jm] = v / bmat.Data[i+jb] 189 } 190 } 191 return 192 } 193 } 194 195 m.checkOverlapMatrix(aU) 196 m.checkOverlapMatrix(bU) 197 var restore func() 198 if aTrans && m == aU { 199 m, restore = m.isolatedWorkspace(aU) 200 defer restore() 201 } else if bTrans && m == bU { 202 m, restore = m.isolatedWorkspace(bU) 203 defer restore() 204 } 205 206 for r := 0; r < ar; r++ { 207 for c := 0; c < ac; c++ { 208 m.set(r, c, a.At(r, c)/b.At(r, c)) 209 } 210 } 211 } 212 213 // Inverse computes the inverse of the matrix a, storing the result into the 214 // receiver. If a is ill-conditioned, a Condition error will be returned. 215 // Note that matrix inversion is numerically unstable, and should generally 216 // be avoided where possible, for example by using the Solve routines. 217 func (m *Dense) Inverse(a Matrix) error { 218 // TODO(btracey): Special case for RawTriangular, etc. 219 r, c := a.Dims() 220 if r != c { 221 panic(ErrSquare) 222 } 223 m.reuseAsNonZeroed(a.Dims()) 224 aU, aTrans := untransposeExtract(a) 225 switch rm := aU.(type) { 226 case *Dense: 227 if m != aU || aTrans { 228 if m == aU || m.checkOverlap(rm.mat) { 229 tmp := getWorkspace(r, c, false) 230 tmp.Copy(a) 231 m.Copy(tmp) 232 putWorkspace(tmp) 233 break 234 } 235 m.Copy(a) 236 } 237 default: 238 m.Copy(a) 239 } 240 // Compute the norm of A. 
241 work := getFloats(4*r, false) // Length must be at least 4*r for Gecon. 242 norm := lapack64.Lange(CondNorm, m.mat, work) 243 // Compute the LU factorization of A. 244 ipiv := getInts(r, false) 245 defer putInts(ipiv) 246 ok := lapack64.Getrf(m.mat, ipiv) 247 if !ok { 248 // A is exactly singular. 249 return Condition(math.Inf(1)) 250 } 251 // Compute the condition number of A using the LU factorization. 252 iwork := getInts(r, false) 253 defer putInts(iwork) 254 rcond := lapack64.Gecon(CondNorm, m.mat, norm, work, iwork) 255 // Compute A^{-1} from the LU factorization regardless of the value of rcond. 256 lapack64.Getri(m.mat, ipiv, work, -1) 257 if int(work[0]) > len(work) { 258 l := int(work[0]) 259 putFloats(work) 260 work = getFloats(l, false) 261 } 262 defer putFloats(work) 263 ok = lapack64.Getri(m.mat, ipiv, work, len(work)) 264 if !ok || rcond == 0 { 265 // A is exactly singular. 266 return Condition(math.Inf(1)) 267 } 268 // Check whether A is singular for computational purposes. 269 cond := 1 / rcond 270 if cond > ConditionTolerance { 271 return Condition(cond) 272 } 273 return nil 274 } 275 276 // Mul takes the matrix product of a and b, placing the result in the receiver. 277 // If the number of columns in a does not equal the number of rows in b, Mul will panic. 278 func (m *Dense) Mul(a, b Matrix) { 279 ar, ac := a.Dims() 280 br, bc := b.Dims() 281 282 if ac != br { 283 panic(ErrShape) 284 } 285 286 aU, aTrans := untransposeExtract(a) 287 bU, bTrans := untransposeExtract(b) 288 m.reuseAsNonZeroed(ar, bc) 289 var restore func() 290 if m == aU { 291 m, restore = m.isolatedWorkspace(aU) 292 defer restore() 293 } else if m == bU { 294 m, restore = m.isolatedWorkspace(bU) 295 defer restore() 296 } 297 aT := blas.NoTrans 298 if aTrans { 299 aT = blas.Trans 300 } 301 bT := blas.NoTrans 302 if bTrans { 303 bT = blas.Trans 304 } 305 306 // Some of the cases do not have a transpose option, so create 307 // temporary memory. 
308 // C = Aᵀ * B = (Bᵀ * A)ᵀ 309 // Cᵀ = Bᵀ * A. 310 if aU, ok := aU.(*Dense); ok { 311 if restore == nil { 312 m.checkOverlap(aU.mat) 313 } 314 switch bU := bU.(type) { 315 case *Dense: 316 if restore == nil { 317 m.checkOverlap(bU.mat) 318 } 319 blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat) 320 return 321 322 case *SymDense: 323 if aTrans { 324 c := getWorkspace(ac, ar, false) 325 blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat) 326 strictCopy(m, c.T()) 327 putWorkspace(c) 328 return 329 } 330 blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat) 331 return 332 333 case *TriDense: 334 // Trmm updates in place, so copy aU first. 335 if aTrans { 336 c := getWorkspace(ac, ar, false) 337 var tmp Dense 338 tmp.SetRawMatrix(aU.mat) 339 c.Copy(&tmp) 340 bT := blas.Trans 341 if bTrans { 342 bT = blas.NoTrans 343 } 344 blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat) 345 strictCopy(m, c.T()) 346 putWorkspace(c) 347 return 348 } 349 m.Copy(a) 350 blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat) 351 return 352 353 case *VecDense: 354 m.checkOverlap(bU.asGeneral()) 355 bvec := bU.RawVector() 356 if bTrans { 357 // {ar,1} x {1,bc}, which is not a vector. 358 // Instead, construct B as a General. 359 bmat := blas64.General{ 360 Rows: bc, 361 Cols: 1, 362 Stride: bvec.Inc, 363 Data: bvec.Data, 364 } 365 blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat) 366 return 367 } 368 cvec := blas64.Vector{ 369 Inc: m.mat.Stride, 370 Data: m.mat.Data, 371 } 372 blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec) 373 return 374 } 375 } 376 if bU, ok := bU.(*Dense); ok { 377 if restore == nil { 378 m.checkOverlap(bU.mat) 379 } 380 switch aU := aU.(type) { 381 case *SymDense: 382 if bTrans { 383 c := getWorkspace(bc, br, false) 384 blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat) 385 strictCopy(m, c.T()) 386 putWorkspace(c) 387 return 388 } 389 blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat) 390 return 391 392 case *TriDense: 393 // Trmm updates in place, so copy bU first. 
394 if bTrans { 395 c := getWorkspace(bc, br, false) 396 var tmp Dense 397 tmp.SetRawMatrix(bU.mat) 398 c.Copy(&tmp) 399 aT := blas.Trans 400 if aTrans { 401 aT = blas.NoTrans 402 } 403 blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat) 404 strictCopy(m, c.T()) 405 putWorkspace(c) 406 return 407 } 408 m.Copy(b) 409 blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat) 410 return 411 412 case *VecDense: 413 m.checkOverlap(aU.asGeneral()) 414 avec := aU.RawVector() 415 if aTrans { 416 // {1,ac} x {ac, bc} 417 // Transpose B so that the vector is on the right. 418 cvec := blas64.Vector{ 419 Inc: 1, 420 Data: m.mat.Data, 421 } 422 bT := blas.Trans 423 if bTrans { 424 bT = blas.NoTrans 425 } 426 blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec) 427 return 428 } 429 // {ar,1} x {1,bc} which is not a vector result. 430 // Instead, construct A as a General. 431 amat := blas64.General{ 432 Rows: ar, 433 Cols: 1, 434 Stride: avec.Inc, 435 Data: avec.Data, 436 } 437 blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat) 438 return 439 } 440 } 441 442 m.checkOverlapMatrix(aU) 443 m.checkOverlapMatrix(bU) 444 row := getFloats(ac, false) 445 defer putFloats(row) 446 for r := 0; r < ar; r++ { 447 for i := range row { 448 row[i] = a.At(r, i) 449 } 450 for c := 0; c < bc; c++ { 451 var v float64 452 for i, e := range row { 453 v += e * b.At(i, c) 454 } 455 m.mat.Data[r*m.mat.Stride+c] = v 456 } 457 } 458 } 459 460 // strictCopy copies a into m panicking if the shape of a and m differ. 461 func strictCopy(m *Dense, a Matrix) { 462 r, c := m.Copy(a) 463 if r != m.mat.Rows || c != m.mat.Cols { 464 // Panic with a string since this 465 // is not a user-facing panic. 466 panic(ErrShape.Error()) 467 } 468 } 469 470 // Exp calculates the exponential of the matrix a, e^a, placing the result 471 // in the receiver. Exp will panic with matrix.ErrShape if a is not square. 
func (m *Dense) Exp(a Matrix) {
	// The implementation used here is from Functions of Matrices: Theory and Computation
	// Chapter 10, Algorithm 10.20. https://doi.org/10.1137/1.9780898717778.ch10

	r, c := a.Dims()
	if r != c {
		panic(ErrShape)
	}

	m.reuseAsNonZeroed(r, r)
	if r == 1 {
		// A 1×1 matrix reduces to the scalar exponential.
		m.mat.Data[0] = math.Exp(a.At(0, 0))
		return
	}

	// Table of lower-order Padé approximants: each entry is used when
	// the 1-norm of a does not exceed theta, with coefficients b.
	pade := []struct {
		theta float64
		b     []float64
	}{
		{theta: 0.015, b: []float64{
			120, 60, 12, 1,
		}},
		{theta: 0.25, b: []float64{
			30240, 15120, 3360, 420, 30, 1,
		}},
		{theta: 0.95, b: []float64{
			17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1,
		}},
		{theta: 2.1, b: []float64{
			17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1,
		}},
	}

	// a1 aliases the receiver, which holds a copy of a until the final
	// Solve call overwrites it with the result.
	a1 := m
	a1.Copy(a)
	// v and u accumulate the two polynomial parts via Axpy, so they are
	// requested with the third getWorkspace argument set to true
	// (presumably a cleared workspace — TODO confirm). The *vec values
	// view each workspace's data as a flat vector of n = r*r elements.
	v := getWorkspace(r, r, true)
	vraw := v.RawMatrix()
	n := r * r
	vvec := blas64.Vector{N: n, Inc: 1, Data: vraw.Data}
	defer putWorkspace(v)

	u := getWorkspace(r, r, true)
	uraw := u.RawMatrix()
	uvec := blas64.Vector{N: n, Inc: 1, Data: uraw.Data}
	defer putWorkspace(u)

	a2 := getWorkspace(r, r, false)
	defer putWorkspace(a2)

	n1 := Norm(a, 1)
	for i, t := range pade {
		if n1 > t.theta {
			continue
		}

		// This loop only executes once, so
		// this is not as horrible as it looks.
		p := getWorkspace(r, r, true)
		praw := p.RawMatrix()
		pvec := blas64.Vector{N: n, Inc: 1, Data: praw.Data}
		defer putWorkspace(p)

		// Start with p = I, v = b[0]*I and u = b[1]*I.
		for k := 0; k < r; k++ {
			p.set(k, k, 1)
			v.set(k, k, t.b[0])
			u.set(k, k, t.b[1])
		}

		// p steps through successive powers of a2 = a1*a1; the
		// even-indexed coefficients accumulate into v and the
		// odd-indexed ones into u.
		a2.Mul(a1, a1)
		for j := 0; j <= i; j++ {
			p.Mul(p, a2)
			blas64.Axpy(t.b[2*j+2], pvec, vvec)
			blas64.Axpy(t.b[2*j+3], pvec, uvec)
		}
		u.Mul(a1, u)

		// Use p as a workspace here and
		// rename u for the second call's
		// receiver.
		vmu, vpu := u, p
		vpu.Add(v, u)
		vmu.Sub(v, u)

		// The error returned by Solve is discarded.
		_ = m.Solve(vmu, vpu)
		return
	}

	// Remaining Padé table line.
	const theta13 = 5.4
	b := [...]float64{
		64764752532480000, 32382376266240000, 7771770303897600, 1187353796428800,
		129060195264000, 10559470521600, 670442572800, 33522128640,
		1323241920, 40840800, 960960, 16380, 182, 1,
	}

	// Scale a1 by 2^-s when its norm is at least theta13; the result is
	// squared s times after the solve below.
	s := math.Log2(n1 / theta13)
	if s >= 0 {
		s = math.Ceil(s)
		a1.Scale(1/math.Pow(2, s), a1)
	}
	a2.Mul(a1, a1)

	// i is the r×r identity matrix.
	i := getWorkspace(r, r, true)
	for j := 0; j < r; j++ {
		i.set(j, j, 1)
	}
	iraw := i.RawMatrix()
	ivec := blas64.Vector{N: n, Inc: 1, Data: iraw.Data}
	defer putWorkspace(i)

	a2raw := a2.RawMatrix()
	a2vec := blas64.Vector{N: n, Inc: 1, Data: a2raw.Data}

	a4 := getWorkspace(r, r, false)
	a4raw := a4.RawMatrix()
	a4vec := blas64.Vector{N: n, Inc: 1, Data: a4raw.Data}
	defer putWorkspace(a4)
	a4.Mul(a2, a2)

	a6 := getWorkspace(r, r, false)
	a6raw := a6.RawMatrix()
	a6vec := blas64.Vector{N: n, Inc: 1, Data: a6raw.Data}
	defer putWorkspace(a6)
	a6.Mul(a2, a4)

	// V = A_6(b_12*A_6 + b_10*A_4 + b_8*A_2) + b_6*A_6 + b_4*A_4 + b_2*A_2 +b_0*I
	blas64.Axpy(b[12], a6vec, vvec)
	blas64.Axpy(b[10], a4vec, vvec)
	blas64.Axpy(b[8], a2vec, vvec)
	v.Mul(v, a6)
	blas64.Axpy(b[6], a6vec, vvec)
	blas64.Axpy(b[4], a4vec, vvec)
	blas64.Axpy(b[2], a2vec, vvec)
	blas64.Axpy(b[0], ivec, vvec)

	// U = A(A_6(b_13*A_6 + b_11*A_4 + b_9*A_2) + b_7*A_6 + b_5*A_4 + b_3*A_2 +b_1*I)
	blas64.Axpy(b[13], a6vec, uvec)
	blas64.Axpy(b[11], a4vec, uvec)
	blas64.Axpy(b[9], a2vec, uvec)
	u.Mul(u, a6)
	blas64.Axpy(b[7], a6vec, uvec)
	blas64.Axpy(b[5], a4vec, uvec)
	blas64.Axpy(b[3], a2vec, uvec)
	blas64.Axpy(b[1], ivec, uvec)
	u.Mul(u, a1)

	// Use i as a workspace here and
	// rename u for the second call's
	// receiver.
	vmu, vpu := u, i
	vpu.Add(v, u)
	vmu.Sub(v, u)

	// The error returned by Solve is discarded.
	_ = m.Solve(vmu, vpu)

	// Undo the scaling by squaring the result s times.
	for ; s > 0; s-- {
		m.Mul(m, m)
	}
}

// Pow calculates the integral power of the matrix a to n, placing the result
// in the receiver. Pow will panic if n is negative or if a is not square.
//
// A negative n panics with the string "mat: illegal power"; a non-square a
// panics with ErrShape.
func (m *Dense) Pow(a Matrix, n int) {
	if n < 0 {
		panic("mat: illegal power")
	}
	r, c := a.Dims()
	if r != c {
		panic(ErrShape)
	}

	m.reuseAsNonZeroed(r, c)

	// Take possible fast paths.
	switch n {
	case 0:
		// a^0 is the identity: clear each row and set its diagonal element.
		for i := 0; i < r; i++ {
			zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
			m.mat.Data[i*m.mat.Stride+i] = 1
		}
		return
	case 1:
		m.Copy(a)
		return
	case 2:
		m.Mul(a, a)
		return
	}

	// Perform iterative exponentiation by squaring in work space.
	// w holds the running product, s the current repeated square of a,
	// and x is scratch space swapped with whichever was just computed.
	w := getWorkspace(r, r, false)
	w.Copy(a)
	s := getWorkspace(r, r, false)
	s.Copy(a)
	x := getWorkspace(r, r, false)
	// w already holds one factor of a, hence the initial n--.
	for n--; n > 0; n >>= 1 {
		if n&1 != 0 {
			x.Mul(w, s)
			w, x = x, w
		}
		// Skip the final squaring, whose result would not be used.
		if n != 1 {
			x.Mul(s, s)
			s, x = x, s
		}
	}
	m.Copy(w)
	putWorkspace(w)
	putWorkspace(s)
	putWorkspace(x)
}

// Kronecker calculates the Kronecker product of a and b, placing the result in
// the receiver.
//
// For a of size ra×ca and b of size rb×cb the receiver is sized
// (ra*rb)×(ca*cb).
func (m *Dense) Kronecker(a, b Matrix) {
	ra, ca := a.Dims()
	rb, cb := b.Dims()

	m.reuseAsNonZeroed(ra*rb, ca*cb)
	for i := 0; i < ra; i++ {
		for j := 0; j < ca; j++ {
			// Block (i, j) of the receiver is a.At(i, j) times b.
			m.slice(i*rb, (i+1)*rb, j*cb, (j+1)*cb).Scale(a.At(i, j), b)
		}
	}
}

// Scale multiplies the elements of a by f, placing the result in the receiver.
//
// See the Scaler interface for more information.
700 func (m *Dense) Scale(f float64, a Matrix) { 701 ar, ac := a.Dims() 702 703 m.reuseAsNonZeroed(ar, ac) 704 705 aU, aTrans := untransposeExtract(a) 706 if rm, ok := aU.(*Dense); ok { 707 amat := rm.mat 708 if m == aU || m.checkOverlap(amat) { 709 var restore func() 710 m, restore = m.isolatedWorkspace(a) 711 defer restore() 712 } 713 if !aTrans { 714 for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride { 715 for i, v := range amat.Data[ja : ja+ac] { 716 m.mat.Data[i+jm] = v * f 717 } 718 } 719 } else { 720 for ja, jm := 0, 0; ja < ac*amat.Stride; ja, jm = ja+amat.Stride, jm+1 { 721 for i, v := range amat.Data[ja : ja+ar] { 722 m.mat.Data[i*m.mat.Stride+jm] = v * f 723 } 724 } 725 } 726 return 727 } 728 729 m.checkOverlapMatrix(a) 730 for r := 0; r < ar; r++ { 731 for c := 0; c < ac; c++ { 732 m.set(r, c, f*a.At(r, c)) 733 } 734 } 735 } 736 737 // Apply applies the function fn to each of the elements of a, placing the 738 // resulting matrix in the receiver. The function fn takes a row/column 739 // index and element value and returns some function of that tuple. 
740 func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) { 741 ar, ac := a.Dims() 742 743 m.reuseAsNonZeroed(ar, ac) 744 745 aU, aTrans := untransposeExtract(a) 746 if rm, ok := aU.(*Dense); ok { 747 amat := rm.mat 748 if m == aU || m.checkOverlap(amat) { 749 var restore func() 750 m, restore = m.isolatedWorkspace(a) 751 defer restore() 752 } 753 if !aTrans { 754 for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride { 755 for i, v := range amat.Data[ja : ja+ac] { 756 m.mat.Data[i+jm] = fn(j, i, v) 757 } 758 } 759 } else { 760 for j, ja, jm := 0, 0, 0; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+1 { 761 for i, v := range amat.Data[ja : ja+ar] { 762 m.mat.Data[i*m.mat.Stride+jm] = fn(i, j, v) 763 } 764 } 765 } 766 return 767 } 768 769 m.checkOverlapMatrix(a) 770 for r := 0; r < ar; r++ { 771 for c := 0; c < ac; c++ { 772 m.set(r, c, fn(r, c, a.At(r, c))) 773 } 774 } 775 } 776 777 // RankOne performs a rank-one update to the matrix a with the vectors x and 778 // y, where x and y are treated as column vectors. The result is stored in the 779 // receiver. The Outer method can be used instead of RankOne if a is not needed. 
780 // m = a + alpha * x * yᵀ 781 func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) { 782 ar, ac := a.Dims() 783 if x.Len() != ar { 784 panic(ErrShape) 785 } 786 if y.Len() != ac { 787 panic(ErrShape) 788 } 789 790 if a != m { 791 aU, _ := untransposeExtract(a) 792 if rm, ok := aU.(*Dense); ok { 793 m.checkOverlap(rm.RawMatrix()) 794 } 795 } 796 797 var xmat, ymat blas64.Vector 798 fast := true 799 xU, _ := untransposeExtract(x) 800 if rv, ok := xU.(*VecDense); ok { 801 r, c := xU.Dims() 802 xmat = rv.mat 803 m.checkOverlap(generalFromVector(xmat, r, c)) 804 } else { 805 fast = false 806 } 807 yU, _ := untransposeExtract(y) 808 if rv, ok := yU.(*VecDense); ok { 809 r, c := yU.Dims() 810 ymat = rv.mat 811 m.checkOverlap(generalFromVector(ymat, r, c)) 812 } else { 813 fast = false 814 } 815 816 if fast { 817 if m != a { 818 m.reuseAsNonZeroed(ar, ac) 819 m.Copy(a) 820 } 821 blas64.Ger(alpha, xmat, ymat, m.mat) 822 return 823 } 824 825 m.reuseAsNonZeroed(ar, ac) 826 for i := 0; i < ar; i++ { 827 for j := 0; j < ac; j++ { 828 m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j)) 829 } 830 } 831 } 832 833 // Outer calculates the outer product of the vectors x and y, where x and y 834 // are treated as column vectors, and stores the result in the receiver. 835 // m = alpha * x * yᵀ 836 // In order to update an existing matrix, see RankOne. 
837 func (m *Dense) Outer(alpha float64, x, y Vector) { 838 r, c := x.Len(), y.Len() 839 840 m.reuseAsZeroed(r, c) 841 842 var xmat, ymat blas64.Vector 843 fast := true 844 xU, _ := untransposeExtract(x) 845 if rv, ok := xU.(*VecDense); ok { 846 r, c := xU.Dims() 847 xmat = rv.mat 848 m.checkOverlap(generalFromVector(xmat, r, c)) 849 } else { 850 fast = false 851 } 852 yU, _ := untransposeExtract(y) 853 if rv, ok := yU.(*VecDense); ok { 854 r, c := yU.Dims() 855 ymat = rv.mat 856 m.checkOverlap(generalFromVector(ymat, r, c)) 857 } else { 858 fast = false 859 } 860 861 if fast { 862 for i := 0; i < r; i++ { 863 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c]) 864 } 865 blas64.Ger(alpha, xmat, ymat, m.mat) 866 return 867 } 868 869 for i := 0; i < r; i++ { 870 for j := 0; j < c; j++ { 871 m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j)) 872 } 873 } 874 }