gonum.org/v1/gonum@v0.14.0/mat/lu.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/floats" 13 "gonum.org/v1/gonum/lapack" 14 "gonum.org/v1/gonum/lapack/lapack64" 15 ) 16 17 const ( 18 badSliceLength = "mat: improper slice length" 19 badLU = "mat: invalid LU factorization" 20 ) 21 22 // LU is a type for creating and using the LU factorization of a matrix. 23 type LU struct { 24 lu *Dense 25 pivot []int 26 cond float64 27 } 28 29 // updateCond updates the stored condition number of the matrix. anorm is the 30 // norm of the original matrix. If anorm is negative it will be estimated. 31 func (lu *LU) updateCond(anorm float64, norm lapack.MatrixNorm) { 32 n := lu.lu.mat.Cols 33 work := getFloat64s(4*n, false) 34 defer putFloat64s(work) 35 iwork := getInts(n, false) 36 defer putInts(iwork) 37 if anorm < 0 { 38 // This is an approximation. By the definition of a norm, 39 // |AB| <= |A| |B|. 40 // Since A = L*U, we get for the condition number κ that 41 // κ(A) := |A| |A^-1| = |L*U| |A^-1| <= |L| |U| |A^-1|, 42 // so this will overestimate the condition number somewhat. 43 // The norm of the original factorized matrix cannot be stored 44 // because of update possibilities. 45 u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) 46 l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) 47 unorm := lapack64.Lantr(norm, u.mat, work) 48 lnorm := lapack64.Lantr(norm, l.mat, work) 49 anorm = unorm * lnorm 50 } 51 v := lapack64.Gecon(norm, lu.lu.mat, anorm, work, iwork) 52 lu.cond = 1 / v 53 } 54 55 // Factorize computes the LU factorization of the square matrix a and stores the 56 // result. The LU decomposition will complete regardless of the singularity of a. 57 // 58 // The LU factorization is computed with pivoting, and so really the decomposition 59 // is a PLU decomposition where P is a permutation matrix. The individual matrix 60 // factors can be extracted from the factorization using the Permutation method 61 // on Dense, and the LU.LTo and LU.UTo methods. 62 func (lu *LU) Factorize(a Matrix) { 63 lu.factorize(a, CondNorm) 64 } 65 66 func (lu *LU) factorize(a Matrix, norm lapack.MatrixNorm) { 67 r, c := a.Dims() 68 if r != c { 69 panic(ErrSquare) 70 } 71 if lu.lu == nil { 72 lu.lu = NewDense(r, r, nil) 73 } else { 74 lu.lu.Reset() 75 lu.lu.reuseAsNonZeroed(r, r) 76 } 77 lu.lu.Copy(a) 78 if cap(lu.pivot) < r { 79 lu.pivot = make([]int, r) 80 } 81 lu.pivot = lu.pivot[:r] 82 work := getFloat64s(r, false) 83 anorm := lapack64.Lange(norm, lu.lu.mat, work) 84 putFloat64s(work) 85 lapack64.Getrf(lu.lu.mat, lu.pivot) 86 lu.updateCond(anorm, norm) 87 } 88 89 // isValid returns whether the receiver contains a factorization. 90 func (lu *LU) isValid() bool { 91 return lu.lu != nil && !lu.lu.IsEmpty() 92 } 93 94 // Cond returns the condition number for the factorized matrix. 95 // Cond will panic if the receiver does not contain a factorization. 96 func (lu *LU) Cond() float64 { 97 if !lu.isValid() { 98 panic(badLU) 99 } 100 return lu.cond 101 } 102 103 // Reset resets the factorization so that it can be reused as the receiver of a 104 // dimensionally restricted operation. 105 func (lu *LU) Reset() { 106 if lu.lu != nil { 107 lu.lu.Reset() 108 } 109 lu.pivot = lu.pivot[:0] 110 } 111 112 func (lu *LU) isZero() bool { 113 return len(lu.pivot) == 0 114 } 115 116 // Det returns the determinant of the matrix that has been factorized. In many 117 // expressions, using LogDet will be more numerically stable. 118 // Det will panic if the receiver does not contain a factorization. 119 func (lu *LU) Det() float64 { 120 det, sign := lu.LogDet() 121 return math.Exp(det) * sign 122 } 123 124 // LogDet returns the log of the determinant and the sign of the determinant 125 // for the matrix that has been factorized. Numerical stability in product and 126 // division expressions is generally improved by working in log space. 127 // LogDet will panic if the receiver does not contain a factorization. 128 func (lu *LU) LogDet() (det float64, sign float64) { 129 if !lu.isValid() { 130 panic(badLU) 131 } 132 133 _, n := lu.lu.Dims() 134 logDiag := getFloat64s(n, false) 135 defer putFloat64s(logDiag) 136 sign = 1.0 137 for i := 0; i < n; i++ { 138 v := lu.lu.at(i, i) 139 if v < 0 { 140 sign *= -1 141 } 142 if lu.pivot[i] != i { 143 sign *= -1 144 } 145 logDiag[i] = math.Log(math.Abs(v)) 146 } 147 return floats.Sum(logDiag), sign 148 } 149 150 // Pivot returns pivot indices that enable the construction of the permutation 151 // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be 152 // allocated, otherwise the length of the input must be equal to the size of the 153 // factorized matrix. 154 // Pivot will panic if the receiver does not contain a factorization. 155 func (lu *LU) Pivot(swaps []int) []int { 156 if !lu.isValid() { 157 panic(badLU) 158 } 159 160 _, n := lu.lu.Dims() 161 if swaps == nil { 162 swaps = make([]int, n) 163 } 164 if len(swaps) != n { 165 panic(badSliceLength) 166 } 167 // Perform the inverse of the row swaps in order to find the final 168 // row swap position. 169 for i := range swaps { 170 swaps[i] = i 171 } 172 for i := n - 1; i >= 0; i-- { 173 v := lu.pivot[i] 174 swaps[i], swaps[v] = swaps[v], swaps[i] 175 } 176 return swaps 177 } 178 179 // RankOne updates an LU factorization as if a rank-one update had been applied to 180 // the original matrix A, storing the result into the receiver. That is, if in 181 // the original LU decomposition P * L * U = A, in the updated decomposition 182 // P * L * U = A + alpha * x * yᵀ. 183 // RankOne will panic if orig does not contain a factorization. 184 func (lu *LU) RankOne(orig *LU, alpha float64, x, y Vector) { 185 if !orig.isValid() { 186 panic(badLU) 187 } 188 189 // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix 190 // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. 191 // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf 192 _, n := orig.lu.Dims() 193 if r, c := x.Dims(); r != n || c != 1 { 194 panic(ErrShape) 195 } 196 if r, c := y.Dims(); r != n || c != 1 { 197 panic(ErrShape) 198 } 199 if orig != lu { 200 if lu.isZero() { 201 if cap(lu.pivot) < n { 202 lu.pivot = make([]int, n) 203 } 204 lu.pivot = lu.pivot[:n] 205 if lu.lu == nil { 206 lu.lu = NewDense(n, n, nil) 207 } else { 208 lu.lu.reuseAsNonZeroed(n, n) 209 } 210 } else if len(lu.pivot) != n { 211 panic(ErrShape) 212 } 213 copy(lu.pivot, orig.pivot) 214 lu.lu.Copy(orig.lu) 215 } 216 217 xs := getFloat64s(n, false) 218 defer putFloat64s(xs) 219 ys := getFloat64s(n, false) 220 defer putFloat64s(ys) 221 for i := 0; i < n; i++ { 222 xs[i] = x.AtVec(i) 223 ys[i] = y.AtVec(i) 224 } 225 226 // Adjust for the pivoting in the LU factorization 227 for i, v := range lu.pivot { 228 xs[i], xs[v] = xs[v], xs[i] 229 } 230 231 lum := lu.lu.mat 232 omega := alpha 233 for j := 0; j < n; j++ { 234 ujj := lum.Data[j*lum.Stride+j] 235 ys[j] /= ujj 236 theta := 1 + xs[j]*ys[j]*omega 237 beta := omega * ys[j] / theta 238 gamma := omega * xs[j] 239 omega -= beta * gamma 240 lum.Data[j*lum.Stride+j] *= theta 241 for i := j + 1; i < n; i++ { 242 xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] 243 tmp := ys[i] 244 ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] 245 lum.Data[i*lum.Stride+j] += beta * xs[i] 246 lum.Data[j*lum.Stride+i] += gamma * tmp 247 } 248 } 249 lu.updateCond(-1, CondNorm) 250 } 251 252 // LTo extracts the lower triangular matrix from an LU factorization. 253 // 254 // If dst is empty, LTo will resize dst to be a lower-triangular n×n matrix. 255 // When dst is non-empty, LTo will panic if dst is not n×n or not Lower. 256 // LTo will also panic if the receiver does not contain a successful 257 // factorization. 258 func (lu *LU) LTo(dst *TriDense) *TriDense { 259 if !lu.isValid() { 260 panic(badLU) 261 } 262 263 _, n := lu.lu.Dims() 264 if dst.IsEmpty() { 265 dst.ReuseAsTri(n, Lower) 266 } else { 267 n2, kind := dst.Triangle() 268 if n != n2 { 269 panic(ErrShape) 270 } 271 if kind != Lower { 272 panic(ErrTriangle) 273 } 274 } 275 // Extract the lower triangular elements. 276 for i := 0; i < n; i++ { 277 for j := 0; j < i; j++ { 278 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 279 } 280 } 281 // Set ones on the diagonal. 282 for i := 0; i < n; i++ { 283 dst.mat.Data[i*dst.mat.Stride+i] = 1 284 } 285 return dst 286 } 287 288 // UTo extracts the upper triangular matrix from an LU factorization. 289 // 290 // If dst is empty, UTo will resize dst to be an upper-triangular n×n matrix. 291 // When dst is non-empty, UTo will panic if dst is not n×n or not Upper. 292 // UTo will also panic if the receiver does not contain a successful 293 // factorization. 294 func (lu *LU) UTo(dst *TriDense) { 295 if !lu.isValid() { 296 panic(badLU) 297 } 298 299 _, n := lu.lu.Dims() 300 if dst.IsEmpty() { 301 dst.ReuseAsTri(n, Upper) 302 } else { 303 n2, kind := dst.Triangle() 304 if n != n2 { 305 panic(ErrShape) 306 } 307 if kind != Upper { 308 panic(ErrTriangle) 309 } 310 } 311 // Extract the upper triangular elements. 312 for i := 0; i < n; i++ { 313 for j := i; j < n; j++ { 314 dst.mat.Data[i*dst.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 315 } 316 } 317 } 318 319 // Permutation constructs an r×r permutation matrix with the given row swaps. 320 // A permutation matrix has exactly one element equal to one in each row and column 321 // and all other elements equal to zero. swaps[i] specifies the row with which 322 // i will be swapped, which is equivalent to the non-zero column of row i. 323 func (m *Dense) Permutation(r int, swaps []int) { 324 m.reuseAsNonZeroed(r, r) 325 for i := 0; i < r; i++ { 326 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]) 327 v := swaps[i] 328 if v < 0 || v >= r { 329 panic(ErrRowAccess) 330 } 331 m.mat.Data[i*m.mat.Stride+v] = 1 332 } 333 } 334 335 // SolveTo solves a system of linear equations using the LU decomposition of a matrix. 336 // It computes 337 // 338 // A * X = B if trans == false 339 // Aᵀ * X = B if trans == true 340 // 341 // In both cases, A is represented in LU factorized form, and the matrix X is 342 // stored into dst. 343 // 344 // If A is singular or near-singular a Condition error is returned. See 345 // the documentation for Condition for more information. 346 // SolveTo will panic if the receiver does not contain a factorization. 347 func (lu *LU) SolveTo(dst *Dense, trans bool, b Matrix) error { 348 if !lu.isValid() { 349 panic(badLU) 350 } 351 352 _, n := lu.lu.Dims() 353 br, bc := b.Dims() 354 if br != n { 355 panic(ErrShape) 356 } 357 // TODO(btracey): Should test the condition number instead of testing that 358 // the determinant is exactly zero. 359 if lu.Det() == 0 { 360 return Condition(math.Inf(1)) 361 } 362 363 dst.reuseAsNonZeroed(n, bc) 364 bU, _ := untranspose(b) 365 var restore func() 366 if dst == bU { 367 dst, restore = dst.isolatedWorkspace(bU) 368 defer restore() 369 } else if rm, ok := bU.(RawMatrixer); ok { 370 dst.checkOverlap(rm.RawMatrix()) 371 } 372 373 dst.Copy(b) 374 t := blas.NoTrans 375 if trans { 376 t = blas.Trans 377 } 378 lapack64.Getrs(t, lu.lu.mat, dst.mat, lu.pivot) 379 if lu.cond > ConditionTolerance { 380 return Condition(lu.cond) 381 } 382 return nil 383 } 384 385 // SolveVecTo solves a system of linear equations using the LU decomposition of a matrix. 386 // It computes 387 // 388 // A * x = b if trans == false 389 // Aᵀ * x = b if trans == true 390 // 391 // In both cases, A is represented in LU factorized form, and the vector x is 392 // stored into dst. 393 // 394 // If A is singular or near-singular a Condition error is returned. See 395 // the documentation for Condition for more information. 396 // SolveVecTo will panic if the receiver does not contain a factorization. 397 func (lu *LU) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 398 if !lu.isValid() { 399 panic(badLU) 400 } 401 402 _, n := lu.lu.Dims() 403 if br, bc := b.Dims(); br != n || bc != 1 { 404 panic(ErrShape) 405 } 406 switch rv := b.(type) { 407 default: 408 dst.reuseAsNonZeroed(n) 409 return lu.SolveTo(dst.asDense(), trans, b) 410 case RawVectorer: 411 if dst != b { 412 dst.checkOverlap(rv.RawVector()) 413 } 414 // TODO(btracey): Should test the condition number instead of testing that 415 // the determinant is exactly zero. 416 if lu.Det() == 0 { 417 return Condition(math.Inf(1)) 418 } 419 420 dst.reuseAsNonZeroed(n) 421 var restore func() 422 if dst == b { 423 dst, restore = dst.isolatedWorkspace(b) 424 defer restore() 425 } 426 dst.CopyVec(b) 427 vMat := blas64.General{ 428 Rows: n, 429 Cols: 1, 430 Stride: dst.mat.Inc, 431 Data: dst.mat.Data, 432 } 433 t := blas.NoTrans 434 if trans { 435 t = blas.Trans 436 } 437 lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) 438 if lu.cond > ConditionTolerance { 439 return Condition(lu.cond) 440 } 441 return nil 442 } 443 }