gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/lu.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/floats" 13 "gonum.org/v1/gonum/lapack" 14 "gonum.org/v1/gonum/lapack/lapack64" 15 ) 16 17 const ( 18 badSliceLength = "mat: improper slice length" 19 badLU = "mat: invalid LU factorization" 20 ) 21 22 // LU is a square n×n matrix represented by its LU factorization with partial 23 // pivoting. 24 // 25 // The factorization has the form 26 // 27 // A = P * L * U 28 // 29 // where P is a permutation matrix, L is lower triangular with unit diagonal 30 // elements, and U is upper triangular. 31 // 32 // Note that this matrix representation is useful for certain operations, in 33 // particular for solving linear systems of equations. It is very inefficient at 34 // other operations, in particular At is slow. 35 type LU struct { 36 lu *Dense 37 swaps []int 38 piv []int 39 cond float64 40 ok bool // Whether A is nonsingular 41 } 42 43 var _ Matrix = (*LU)(nil) 44 45 // Dims returns the dimensions of the matrix A. 46 func (lu *LU) Dims() (r, c int) { 47 if lu.lu == nil { 48 return 0, 0 49 } 50 return lu.lu.Dims() 51 } 52 53 // At returns the element of A at row i, column j. 54 func (lu *LU) At(i, j int) float64 { 55 n, _ := lu.Dims() 56 if uint(i) >= uint(n) { 57 panic(ErrRowAccess) 58 } 59 if uint(j) >= uint(n) { 60 panic(ErrColAccess) 61 } 62 63 i = lu.piv[i] 64 var val float64 65 for k := 0; k < min(i, j+1); k++ { 66 val += lu.lu.at(i, k) * lu.lu.at(k, j) 67 } 68 if i <= j { 69 val += lu.lu.at(i, j) 70 } 71 return val 72 } 73 74 // T performs an implicit transpose by returning the receiver inside a 75 // Transpose. 76 func (lu *LU) T() Matrix { 77 return Transpose{lu} 78 } 79 80 // updateCond updates the stored condition number of the matrix. anorm is the 81 // norm of the original matrix. If anorm is negative it will be estimated. 82 func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) { 83 n := lu.lu.mat.Cols 84 work := getFloat64s(4*n, false) 85 defer putFloat64s(work) 86 iwork := getInts(n, false) 87 defer putInts(iwork) 88 if anorm < 0 { 89 // This is an approximation. By the definition of a norm, 90 // |AB| <= |A| |B|. 91 // Since A = L*U, we get for the condition number κ that 92 // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|, 93 // so this will overestimate the condition number somewhat. 94 // The norm of the original factorized matrix cannot be stored 95 // because of update possibilities. 96 u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) 97 l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) 98 unorm := lapack64.Lantr(norm, u.mat, work) 99 lnorm := lapack64.Lantr(norm, l.mat, work) 100 anorm = unorm * lnorm 101 } 102 v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork) 103 lu.cond = 1 / v 104 } 105 106 // Factorize computes the LU factorization of the square matrix A and stores the 107 // result in the receiver. The LU decomposition will complete regardless of the 108 // singularity of a. 109 // 110 // The L and U matrix factors can be extracted from the factorization using the 111 // LTo and UTo methods. The matrix P can be extracted as a row permutation using 112 // the RowPivots method and applied using Dense.PermuteRows. 113 func (lu *LU) Factorize(a Matrix) { 114 lu.factorize(a, CondNorm) 115 } 116 117 func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) { 118 m, n := a.Dims() 119 if m != n { 120 panic(ErrSquare) 121 } 122 if lu.lu == nil { 123 lu.lu = NewDense(n, n, nil) 124 } else { 125 lu.lu.Reset() 126 lu.lu.reuseAsNonZeroed(n, n) 127 } 128 lu.lu.Copy(a) 129 lu.swaps = useInt(lu.swaps, n) 130 lu.piv = useInt(lu.piv, n) 131 work := getFloat64s(n, false) 132 anorm := lapack64.Lange(norm, lu.lu.mat, work) 133 putFloat64s(work) 134 lu.ok = lapack64.Getrf(lu.lu.mat, lu.swaps) 135 lu.updatePivots(lu.swaps) 136 lu.updateCond(anorm, norm) 137 } 138 139 func (lu *LU) updatePivots(swaps []int) { 140 // Replay the sequence of row swaps in order to find the row permutation. 141 for i := range lu.piv { 142 lu.piv[i] = i 143 } 144 n, _ := lu.Dims() 145 for i := n - 1; i >= 0; i-- { 146 v := swaps[i] 147 lu.piv[i], lu.piv[v] = lu.piv[v], lu.piv[i] 148 } 149 } 150 151 // isValid returns whether the receiver contains a factorization. 152 func (lu *LU) isValid() bool { 153 return lu.lu != nil && !lu.lu.IsEmpty() 154 } 155 156 // Cond returns the condition number for the factorized matrix. 157 // Cond will panic if the receiver does not contain a factorization. 158 func (lu *LU) Cond() float64 { 159 if !lu.isValid() { 160 panic(badLU) 161 } 162 return lu.cond 163 } 164 165 // Reset resets the factorization so that it can be reused as the receiver of a 166 // dimensionally restricted operation. 167 func (lu *LU) Reset() { 168 if lu.lu != nil { 169 lu.lu.Reset() 170 } 171 lu.swaps = lu.swaps[:0] 172 lu.piv = lu.piv[:0] 173 } 174 175 func (lu *LU) isZero() bool { 176 return len(lu.swaps) == 0 177 } 178 179 // Det returns the determinant of the matrix that has been factorized. In many 180 // expressions, using LogDet will be more numerically stable. 181 // Det will panic if the receiver does not contain a factorization. 182 func (lu *LU) Det() float64 { 183 if !lu.ok { 184 return 0 185 } 186 det, sign := lu.LogDet() 187 return math.Exp(det) * sign 188 } 189 190 // LogDet returns the log of the determinant and the sign of the determinant 191 // for the matrix that has been factorized. Numerical stability in product and 192 // division expressions is generally improved by working in log space. 193 // LogDet will panic if the receiver does not contain a factorization. 194 func (lu *LU) LogDet() (det float64, sign float64) { 195 if !lu.isValid() { 196 panic(badLU) 197 } 198 199 _, n := lu.lu.Dims() 200 logDiag := getFloat64s(n, false) 201 defer putFloat64s(logDiag) 202 sign = 1.0 203 for i := 0; i < n; i++ { 204 v := lu.lu.at(i, i) 205 if v < 0 { 206 sign *= -1 207 } 208 if lu.swaps[i] != i { 209 sign *= -1 210 } 211 logDiag[i] = math.Log(math.Abs(v)) 212 } 213 return floats.Sum(logDiag), sign 214 } 215 216 // RowPivots returns the row permutation that represents the permutation matrix 217 // P from the LU factorization 218 // 219 // A = P * L * U. 220 // 221 // If dst is nil, a new slice is allocated and returned. If dst is not nil and 222 // the length of dst does not equal the size of the factorized matrix, RowPivots 223 // will panic. RowPivots will panic if the receiver does not contain a 224 // factorization. 225 func (lu *LU) RowPivots(dst []int) []int { 226 if !lu.isValid() { 227 panic(badLU) 228 } 229 _, n := lu.lu.Dims() 230 if dst == nil { 231 dst = make([]int, n) 232 } 233 if len(dst) != n { 234 panic(badSliceLength) 235 } 236 copy(dst, lu.piv) 237 return dst 238 } 239 240 // Pivot returns the row pivots of the receiver. 241 // 242 // Deprecated: Use RowPivots instead. 243 func (lu *LU) Pivot(dst []int) []int { 244 return lu.RowPivots(dst) 245 } 246 247 // RankOne updates an LU factorization as if a rank-one update had been applied to 248 // the original matrix A, storing the result into the receiver. That is, if in 249 // the original LU decomposition P * L * U = A, in the updated decomposition 250 // P * L' * U' = A + alpha * x * yᵀ. 251 // RankOne will panic if orig does not contain a factorization. 252 func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) { 253 if !orig.isValid() { 254 panic(badLU) 255 } 256 257 // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix 258 // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. 259 // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf 260 _, n := orig.lu.Dims() 261 if r, c := x.Dims(); r != n || c != 1 { 262 panic(ErrShape) 263 } 264 if r, c := y.Dims(); r != n || c != 1 { 265 panic(ErrShape) 266 } 267 if orig != lu { 268 if lu.isZero() { 269 lu.swaps = useInt(lu.swaps, n) 270 lu.piv = useInt(lu.piv, n) 271 if lu.lu == nil { 272 lu.lu = NewDense(n, n, nil) 273 } else { 274 lu.lu.reuseAsNonZeroed(n, n) 275 } 276 } else if len(lu.swaps) != n { 277 panic(ErrShape) 278 } 279 copy(lu.swaps, orig.swaps) 280 lu.updatePivots(lu.swaps) 281 lu.lu.Copy(orig.lu) 282 } 283 284 xs := getFloat64s(n, false) 285 defer putFloat64s(xs) 286 ys := getFloat64s(n, false) 287 defer putFloat64s(ys) 288 for i := 0; i < n; i++ { 289 xs[i] = x.AtVec(i) 290 ys[i] = y.AtVec(i) 291 } 292 293 // Adjust for the pivoting in the LU factorization 294 for i, v := range lu.swaps { 295 xs[i], xs[v] = xs[v], xs[i] 296 } 297 298 lum := lu.lu.mat 299 omega := alpha 300 for j := 0; j < n; j++ { 301 ujj := lum.Data[j*lum.Stride+j] 302 ys[j] /= ujj 303 theta := 1 + xs[j]*ys[j]*omega 304 beta := omega * ys[j] / theta 305 gamma := omega * xs[j] 306 omega -= beta * gamma 307 lum.Data[j*lum.Stride+j] *= theta 308 for i := j + 1; i < n; i++ { 309 xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] 310 tmp := ys[i] 311 ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] 312 lum.Data[i*lum.Stride+j] += beta * xs[i] 313 lum.Data[j*lum.Stride+i] += gamma * tmp 314 } 315 } 316 lu.updateCond(-1, CondNorm) 317 } 318 319 // LTo extracts the lower triangular matrix from an LU factorization. 320 // 321 // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix. 322 // When dst is non-empty, LTo will panic if dst is not n×n or not Lower. 323 // LTo will also panic if the receiver does not contain a successful 324 // factorization. 325 func (lu *LU) LTo(dst *TriDense) *TriDense { 326 if !lu.isValid() { 327 panic(badLU) 328 } 329 330 _, n := lu.lu.Dims() 331 if dst.IsEmpty() { 332 dst.ReuseAsTri(n, Lower) 333 } else { 334 n2, kind := dst.Triangle() 335 if n != n2 { 336 panic(ErrShape) 337 } 338 if kind != Lower { 339 panic(ErrTriangle) 340 } 341 } 342 // Extract the lower triangular elements. 343 for i := 1; i < n; i++ { 344 copy(dst.mat.Data[i*dst.mat.Stride:i*dst.mat.Stride+i], lu.lu.mat.Data[i*lu.lu.mat.Stride:i*lu.lu.mat.Stride+i]) 345 } 346 // Set ones on the diagonal. 347 for i := 0; i < n; i++ { 348 dst.mat.Data[i*dst.mat.Stride+i] = 1 349 } 350 return dst 351 } 352 353 // UTo extracts the upper triangular matrix from an LU factorization. 354 // 355 // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix. 356 // When dst is non-empty, UTo will panic if dst is not n×n or not Upper. 357 // UTo will also panic if the receiver does not contain a successful 358 // factorization. 359 func (lu *LU) UTo(dst *TriDense) { 360 if !lu.isValid() { 361 panic(badLU) 362 } 363 364 _, n := lu.lu.Dims() 365 if dst.IsEmpty() { 366 dst.ReuseAsTri(n, Upper) 367 } else { 368 n2, kind := dst.Triangle() 369 if n != n2 { 370 panic(ErrShape) 371 } 372 if kind != Upper { 373 panic(ErrTriangle) 374 } 375 } 376 // Extract the upper triangular elements. 377 for i := 0; i < n; i++ { 378 copy(dst.mat.Data[i*dst.mat.Stride+i:i*dst.mat.Stride+n], lu.lu.mat.Data[i*lu.lu.mat.Stride+i:i*lu.lu.mat.Stride+n]) 379 } 380 } 381 382 // SolveTo solves a system of linear equations 383 // 384 // A * X = B if trans == false 385 // Aᵀ * X = B if trans == true 386 // 387 // using the LU factorization of A stored in the receiver. The solution matrix X 388 // is stored into dst. 389 // 390 // If A is singular or near-singular a Condition error is returned. See the 391 // documentation for Condition for more information. SolveTo will panic if the 392 // receiver does not contain a factorization. 393 func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error { 394 if !lu.isValid() { 395 panic(badLU) 396 } 397 398 _, n := lu.lu.Dims() 399 br, bc := b.Dims() 400 if br != n { 401 panic(ErrShape) 402 } 403 404 if !lu.ok { 405 return Condition(math.Inf(1)) 406 } 407 408 dst.reuseAsNonZeroed(n, bc) 409 bU, _ := untranspose(b) 410 if dst == bU { 411 var restore func() 412 dst, restore = dst.isolatedWorkspace(bU) 413 defer restore() 414 } else if rm, ok := bU.(RawMatrixer); ok { 415 dst.checkOverlap(rm.RawMatrix()) 416 } 417 418 dst.Copy(b) 419 t := blas.NoTrans 420 if trans { 421 t = blas.Trans 422 } 423 lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.swaps) 424 if lu.cond > ConditionTolerance { 425 return Condition(lu.cond) 426 } 427 return nil 428 } 429 430 // SolveVecTo solves a system of linear equations 431 // 432 // A * x = b if trans == false 433 // Aᵀ * x = b if trans == true 434 // 435 // using the LU factorization of A stored in the receiver. The solution matrix x 436 // is stored into dst. 437 // 438 // If A is singular or near-singular a Condition error is returned. See the 439 // documentation for Condition for more information. SolveVecTo will panic if the 440 // receiver does not contain a factorization. 441 func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 442 if !lu.isValid() { 443 panic(badLU) 444 } 445 446 _, n := lu.lu.Dims() 447 if br, bc := b.Dims(); br != n || bc != 1 { 448 panic(ErrShape) 449 } 450 451 switch rv := b.(type) { 452 default: 453 dst.reuseAsNonZeroed(n) 454 return lu.SolveTo(dst.asDense(), trans, b) 455 case RawVectorer: 456 if dst != b { 457 dst.checkOverlap(rv.RawVector()) 458 } 459 460 if !lu.ok { 461 return Condition(math.Inf(1)) 462 } 463 464 dst.reuseAsNonZeroed(n) 465 var restore func() 466 if dst == b { 467 dst, restore = dst.isolatedWorkspace(b) 468 defer restore() 469 } 470 dst.CopyVec(b) 471 vMat := blas64.General{ 472 Rows: n, 473 Cols: 1, 474 Stride: dst.mat.Inc, 475 Data: dst.mat.Data, 476 } 477 t := blas.NoTrans 478 if trans { 479 t = blas.Trans 480 } 481 lapack64.Getrs(t, lu.lu.mat, vMat, lu.swaps) 482 if lu.cond > ConditionTolerance { 483 return Condition(lu.cond) 484 } 485 return nil 486 } 487 }