github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/lu.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "github.com/jingcheng-WU/gonum/blas" 11 "github.com/jingcheng-WU/gonum/blas/blas64" 12 "github.com/jingcheng-WU/gonum/floats" 13 "github.com/jingcheng-WU/gonum/lapack" 14 "github.com/jingcheng-WU/gonum/lapack/lapack64" 15 ) 16 17 const ( 18 badSliceLength = "mat: improper slice length" 19 badLU = "mat: invalid LU factorization" 20 ) 21 22 // LU is a type for creating and using the LU factorization of a matrix. 23 type LU struct { 24 lu *Dense 25 pivot []int 26 cond float64 27 } 28 29 // updateCond updates the stored condition number of the matrix. anorm is the 30 // norm of the original matrix. If anorm is negative it will be estimated. 31 func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) { 32 n := lu.lu.mat.Cols 33 work := getFloats(4*n, false) 34 defer putFloats(work) 35 iwork := getInts(n, false) 36 defer putInts(iwork) 37 if anorm < 0 { 38 // This is an approximation. By the definition of a norm, 39 // |AB| <= |A| |B|. 40 // Since A = L*U, we get for the condition number κ that 41 // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|, 42 // so this will overestimate the condition number somewhat. 43 // The norm of the original factorized matrix cannot be stored 44 // because of update possibilities. 45 u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) 46 l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) 47 unorm := lapack64.Lantr(norm, u.mat, work) 48 lnorm := lapack64.Lantr(norm, l.mat, work) 49 anorm = unorm * lnorm 50 } 51 v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork) 52 lu.cond = 1 / v 53 } 54 55 // Factorize computes the LU factorization of the square matrix a and stores the 56 // result. The LU decomposition will complete regardless of the singularity of a. 57 // 58 // The LU factorization is computed with pivoting, and so really the decomposition 59 // is a PLU decomposition where P is a permutation matrix. The individual matrix 60 // factors can be extracted from the factorization using the Permutation method 61 // on Dense, and the LU.LTo and LU.UTo methods. 62 func (lu *LU) Factorize(a Matrix) { 63 lu.factorize(a, CondNorm) 64 } 65 66 func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) { 67 r, c := a.Dims() 68 if r != c { 69 panic(ErrSquare) 70 } 71 if lu.lu == nil { 72 lu.lu = NewDense(r, r, nil) 73 } else { 74 lu.lu.Reset() 75 lu.lu.reuseAsNonZeroed(r, r) 76 } 77 lu.lu.Copy(a) 78 if cap(lu.pivot) < r { 79 lu.pivot = make([]int, r) 80 } 81 lu.pivot = lu.pivot[:r] 82 work := getFloats(r, false) 83 anorm := lapack64.Lange(norm, lu.lu.mat, work) 84 putFloats(work) 85 lapack64.Getrf(lu.lu.mat, lu.pivot) 86 lu.updateCond(anorm, norm) 87 } 88 89 // isValid returns whether the receiver contains a factorization. 90 func (lu *LU) isValid() bool { 91 return lu.lu != nil && !lu.lu.IsEmpty() 92 } 93 94 // Cond returns the condition number for the factorized matrix. 95 // Cond will panic if the receiver does not contain a factorization. 96 func (lu *LU) Cond() float64 { 97 if !lu.isValid() { 98 panic(badLU) 99 } 100 return lu.cond 101 } 102 103 // Reset resets the factorization so that it can be reused as the receiver of a 104 // dimensionally restricted operation. 105 func (lu *LU) Reset() { 106 if lu.lu != nil { 107 lu.lu.Reset() 108 } 109 lu.pivot = lu.pivot[:0] 110 } 111 112 func (lu *LU) isZero() bool { 113 return len(lu.pivot) == 0 114 } 115 116 // Det returns the determinant of the matrix that has been factorized. In many 117 // expressions, using LogDet will be more numerically stable. 118 // Det will panic if the receiver does not contain a factorization. 119 func (lu *LU) Det() float64 { 120 det, sign := lu.LogDet() 121 return math.Exp(det) * sign 122 } 123 124 // LogDet returns the log of the determinant and the sign of the determinant 125 // for the matrix that has been factorized. Numerical stability in product and 126 // division expressions is generally improved by working in log space. 127 // LogDet will panic if the receiver does not contain a factorization. 128 func (lu *LU) LogDet() (det float64, sign float64) { 129 if !lu.isValid() { 130 panic(badLU) 131 } 132 133 _, n := lu.lu.Dims() 134 logDiag := getFloats(n, false) 135 defer putFloats(logDiag) 136 sign = 1.0 137 for i := 0; i < n; i++ { 138 v := lu.lu.at(i, i) 139 if v < 0 { 140 sign *= -1 141 } 142 if lu.pivot[i] != i { 143 sign *= -1 144 } 145 logDiag[i] = math.Log(math.Abs(v)) 146 } 147 return floats.Sum(logDiag), sign 148 } 149 150 // Pivot returns pivot indices that enable the construction of the permutation 151 // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be 152 // allocated, otherwise the length of the input must be equal to the size of the 153 // factorized matrix. 154 // Pivot will panic if the receiver does not contain a factorization. 155 func (lu *LU) Pivot(swaps []int) []int { 156 if !lu.isValid() { 157 panic(badLU) 158 } 159 160 _, n := lu.lu.Dims() 161 if swaps == nil { 162 swaps = make([]int, n) 163 } 164 if len(swaps) != n { 165 panic(badSliceLength) 166 } 167 // Perform the inverse of the row swaps in order to find the final 168 // row swap position. 169 for i := range swaps { 170 swaps[i] = i 171 } 172 for i := n - 1; i >= 0; i-- { 173 v := lu.pivot[i] 174 swaps[i], swaps[v] = swaps[v], swaps[i] 175 } 176 return swaps 177 } 178 179 // RankOne updates an LU factorization as if a rank-one update had been applied to 180 // the original matrix A, storing the result into the receiver. That is, if in 181 // the original LU decomposition P * L * U = A, in the updated decomposition 182 // P * L * U = A + alpha * x * yᵀ. 183 // RankOne will panic if orig does not contain a factorization. 184 func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) { 185 if !orig.isValid() { 186 panic(badLU) 187 } 188 189 // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix 190 // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. 191 // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf 192 _, n := orig.lu.Dims() 193 if r, c := x.Dims(); r != n || c != 1 { 194 panic(ErrShape) 195 } 196 if r, c := y.Dims(); r != n || c != 1 { 197 panic(ErrShape) 198 } 199 if orig != lu { 200 if lu.isZero() { 201 if cap(lu.pivot) < n { 202 lu.pivot = make([]int, n) 203 } 204 lu.pivot = lu.pivot[:n] 205 if lu.lu == nil { 206 lu.lu = NewDense(n, n, nil) 207 } else { 208 lu.lu.reuseAsNonZeroed(n, n) 209 } 210 } else if len(lu.pivot) != n { 211 panic(ErrShape) 212 } 213 copy(lu.pivot, orig.pivot) 214 lu.lu.Copy(orig.lu) 215 } 216 217 xs := getFloats(n, false) 218 defer putFloats(xs) 219 ys := getFloats(n, false) 220 defer putFloats(ys) 221 for i := 0; i < n; i++ { 222 xs[i] = x.AtVec(i) 223 ys[i] = y.AtVec(i) 224 } 225 226 // Adjust for the pivoting in the LU factorization 227 for i, v := range lu.pivot { 228 xs[i], xs[v] = xs[v], xs[i] 229 } 230 231 lum := lu.lu.mat 232 omega := alpha 233 for j := 0; j < n; j++ { 234 ujj := lum.Data[j*lum.Stride+j] 235 ys[j] /= ujj 236 theta := 1 + xs[j]*ys[j]*omega 237 beta := omega * ys[j] / theta 238 gamma := omega * xs[j] 239 omega -= beta * gamma 240 lum.Data[j*lum.Stride+j] *= theta 241 for i := j + 1; i < n; i++ { 242 xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] 243 tmp := ys[i] 244 ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] 245 lum.Data[i*lum.Stride+j] += beta * xs[i] 246 lum.Data[j*lum.Stride+i] += gamma * tmp 247 } 248 } 249 lu.updateCond(-1, CondNorm) 250 } 251 252 // LTo extracts the lower triangular matrix from an LU factorization. 253 // 254 // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix. 255 // When dst is non-empty, LTo will panic if dst is not n×n or not Lower. 256 // LTo will also panic if the receiver does not contain a successful 257 // factorization. 258 func (lu *LU) LTo(dst *TriDense) *TriDense { 259 if !lu.isValid() { 260 panic(badLU) 261 } 262 263 _, n := lu.lu.Dims() 264 if dst.IsEmpty() { 265 dst.ReuseAsTri(n, Lower) 266 } else { 267 n2, kind := dst.Triangle() 268 if n != n2 { 269 panic(ErrShape) 270 } 271 if kind != Lower { 272 panic(ErrTriangle) 273 } 274 } 275 // Extract the lower triangular elements. 276 for i := 0; i < n; i++ { 277 for j := 0; j < i; j++ { 278 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 279 } 280 } 281 // Set ones on the diagonal. 282 for i := 0; i < n; i++ { 283 dst.mat.Data[i*dst.mat.Stride+i] = 1 284 } 285 return dst 286 } 287 288 // UTo extracts the upper triangular matrix from an LU factorization. 289 // 290 // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix. 291 // When dst is non-empty, UTo will panic if dst is not n×n or not Upper. 292 // UTo will also panic if the receiver does not contain a successful 293 // factorization. 294 func (lu *LU) UTo(dst *TriDense) { 295 if !lu.isValid() { 296 panic(badLU) 297 } 298 299 _, n := lu.lu.Dims() 300 if dst.IsEmpty() { 301 dst.ReuseAsTri(n, Upper) 302 } else { 303 n2, kind := dst.Triangle() 304 if n != n2 { 305 panic(ErrShape) 306 } 307 if kind != Upper { 308 panic(ErrTriangle) 309 } 310 } 311 // Extract the upper triangular elements. 312 for i := 0; i < n; i++ { 313 for j := i; j < n; j++ { 314 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 315 } 316 } 317 } 318 319 // Permutation constructs an r×r permutation matrix with the given row swaps. 320 // A permutation matrix has exactly one element equal to one in each row and column 321 // and all other elements equal to zero. swaps[i] specifies the row with which 322 // i will be swapped, which is equivalent to the non-zero column of row i. 323 func (m *Dense) Permutation(r int, swaps []int) { 324 m.reuseAsNonZeroed(r, r) 325 for i := 0; i < r; i++ { 326 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]) 327 v := swaps[i] 328 if v < 0 || v >= r { 329 panic(ErrRowAccess) 330 } 331 m.mat.Data[i*m.mat.Stride+v] = 1 332 } 333 } 334 335 // SolveTo solves a system of linear equations using the LU decomposition of a matrix. 336 // It computes 337 // A * X = B if trans == false 338 // Aᵀ * X = B if trans == true 339 // In both cases, A is represented in LU factorized form, and the matrix X is 340 // stored into dst. 341 // 342 // If A is singular or near-singular a Condition error is returned. See 343 // the documentation for Condition for more information. 344 // SolveTo will panic if the receiver does not contain a factorization. 345 func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error { 346 if !lu.isValid() { 347 panic(badLU) 348 } 349 350 _, n := lu.lu.Dims() 351 br, bc := b.Dims() 352 if br != n { 353 panic(ErrShape) 354 } 355 // TODO(btracey): Should test the condition number instead of testing that 356 // the determinant is exactly zero. 357 if lu.Det() == 0 { 358 return Condition(math.Inf(1)) 359 } 360 361 dst.reuseAsNonZeroed(n, bc) 362 bU, _ := untranspose(b) 363 var restore func() 364 if dst == bU { 365 dst, restore = dst.isolatedWorkspace(bU) 366 defer restore() 367 } else if rm, ok := bU.(RawMatrixer); ok { 368 dst.checkOverlap(rm.RawMatrix()) 369 } 370 371 dst.Copy(b) 372 t := blas.NoTrans 373 if trans { 374 t = blas.Trans 375 } 376 lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot) 377 if lu.cond > ConditionTolerance { 378 return Condition(lu.cond) 379 } 380 return nil 381 } 382 383 // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix. 384 // It computes 385 // A * x = b if trans == false 386 // Aᵀ * x = b if trans == true 387 // In both cases, A is represented in LU factorized form, and the vector x is 388 // stored into dst. 389 // 390 // If A is singular or near-singular a Condition error is returned. See 391 // the documentation for Condition for more information. 392 // SolveVecTo will panic if the receiver does not contain a factorization. 393 func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 394 if !lu.isValid() { 395 panic(badLU) 396 } 397 398 _, n := lu.lu.Dims() 399 if br, bc := b.Dims(); br != n || bc != 1 { 400 panic(ErrShape) 401 } 402 switch rv := b.(type) { 403 default: 404 dst.reuseAsNonZeroed(n) 405 return lu.SolveTo(dst.asDense(), trans, b) 406 case RawVectorer: 407 if dst != b { 408 dst.checkOverlap(rv.RawVector()) 409 } 410 // TODO(btracey): Should test the condition number instead of testing that 411 // the determinant is exactly zero. 412 if lu.Det() == 0 { 413 return Condition(math.Inf(1)) 414 } 415 416 dst.reuseAsNonZeroed(n) 417 var restore func() 418 if dst == b { 419 dst, restore = dst.isolatedWorkspace(b) 420 defer restore() 421 } 422 dst.CopyVec(b) 423 vMat := blas64.General{ 424 Rows: n, 425 Cols: 1, 426 Stride: dst.mat.Inc, 427 Data: dst.mat.Data, 428 } 429 t := blas.NoTrans 430 if trans { 431 t = blas.Trans 432 } 433 lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) 434 if lu.cond > ConditionTolerance { 435 return Condition(lu.cond) 436 } 437 return nil 438 } 439 }