github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/lu.go (about) 1 // Copyright ©2013 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat64 6 7 import ( 8 "math" 9 10 "github.com/gonum/blas" 11 "github.com/gonum/blas/blas64" 12 "github.com/gonum/floats" 13 "github.com/gonum/lapack/lapack64" 14 "github.com/gonum/matrix" 15 ) 16 17 const badSliceLength = "mat64: improper slice length" 18 19 // LU is a type for creating and using the LU factorization of a matrix. 20 type LU struct { 21 lu *Dense 22 pivot []int 23 cond float64 24 } 25 26 // updateCond updates the stored condition number of the matrix. Norm is the 27 // norm of the original matrix. If norm is negative it will be estimated. 28 func (lu *LU) updateCond(norm float64) { 29 n := lu.lu.mat.Cols 30 work := make([]float64, 4*n) 31 iwork := make([]int, n) 32 if norm < 0 { 33 // This is an approximation. By the definition of a norm, ||AB|| <= ||A|| ||B||. 34 // The condition number is ||A|| || A^-1||, so this will underestimate 35 // the condition number somewhat. 36 // The norm of the original factorized matrix cannot be stored because of 37 // update possibilities, e.g. RankOne. 38 u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper) 39 l := lu.lu.asTriDense(n, blas.Unit, blas.Lower) 40 unorm := lapack64.Lantr(matrix.CondNorm, u.mat, work) 41 lnorm := lapack64.Lantr(matrix.CondNorm, l.mat, work) 42 norm = unorm * lnorm 43 } 44 v := lapack64.Gecon(matrix.CondNorm, lu.lu.mat, norm, work, iwork) 45 lu.cond = 1 / v 46 } 47 48 // Factorize computes the LU factorization of the square matrix a and stores the 49 // result. The LU decomposition will complete regardless of the singularity of a. 50 // 51 // The LU factorization is computed with pivoting, and so really the decomposition 52 // is a PLU decomposition where P is a permutation matrix. The individual matrix 53 // factors can be extracted from the factorization using the Permutation method 54 // on Dense, and the LFrom and UFrom methods on TriDense. 55 func (lu *LU) Factorize(a Matrix) { 56 r, c := a.Dims() 57 if r != c { 58 panic(matrix.ErrSquare) 59 } 60 if lu.lu == nil { 61 lu.lu = NewDense(r, r, nil) 62 } else { 63 lu.lu.Reset() 64 lu.lu.reuseAs(r, r) 65 } 66 lu.lu.Copy(a) 67 if cap(lu.pivot) < r { 68 lu.pivot = make([]int, r) 69 } 70 lu.pivot = lu.pivot[:r] 71 work := make([]float64, r) 72 anorm := lapack64.Lange(matrix.CondNorm, lu.lu.mat, work) 73 lapack64.Getrf(lu.lu.mat, lu.pivot) 74 lu.updateCond(anorm) 75 } 76 77 // Reset resets the factorization so that it can be reused as the receiver of a 78 // dimensionally restricted operation. 79 func (lu *LU) Reset() { 80 if lu.lu != nil { 81 lu.lu.Reset() 82 } 83 lu.pivot = lu.pivot[:0] 84 } 85 86 func (lu *LU) isZero() bool { 87 return len(lu.pivot) == 0 88 } 89 90 // Det returns the determinant of the matrix that has been factorized. In many 91 // expressions, using LogDet will be more numerically stable. 92 func (lu *LU) Det() float64 { 93 det, sign := lu.LogDet() 94 return math.Exp(det) * sign 95 } 96 97 // LogDet returns the log of the determinant and the sign of the determinant 98 // for the matrix that has been factorized. Numerical stability in product and 99 // division expressions is generally improved by working in log space. 100 func (lu *LU) LogDet() (det float64, sign float64) { 101 _, n := lu.lu.Dims() 102 logDiag := make([]float64, n) 103 sign = 1.0 104 for i := 0; i < n; i++ { 105 v := lu.lu.at(i, i) 106 if v < 0 { 107 sign *= -1 108 } 109 if lu.pivot[i] != i { 110 sign *= -1 111 } 112 logDiag[i] = math.Log(math.Abs(v)) 113 } 114 return floats.Sum(logDiag), sign 115 } 116 117 // Pivot returns pivot indices that enable the construction of the permutation 118 // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be 119 // allocated, otherwise the length of the input must be equal to the size of the 120 // factorized matrix. 121 func (lu *LU) Pivot(swaps []int) []int { 122 _, n := lu.lu.Dims() 123 if swaps == nil { 124 swaps = make([]int, n) 125 } 126 if len(swaps) != n { 127 panic(badSliceLength) 128 } 129 // Perform the inverse of the row swaps in order to find the final 130 // row swap position. 131 for i := range swaps { 132 swaps[i] = i 133 } 134 for i := n - 1; i >= 0; i-- { 135 v := lu.pivot[i] 136 swaps[i], swaps[v] = swaps[v], swaps[i] 137 } 138 return swaps 139 } 140 141 // RankOne updates an LU factorization as if a rank-one update had been applied to 142 // the original matrix A, storing the result into the receiver. That is, if in 143 // the original LU decomposition P * L * U = A, in the updated decomposition 144 // P * L * U = A + alpha * x * y^T. 145 func (lu *LU) RankOne(orig *LU, alpha float64, x, y *Vector) { 146 // RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix 147 // Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng. 148 // http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf 149 _, n := orig.lu.Dims() 150 if x.Len() != n { 151 panic(matrix.ErrShape) 152 } 153 if y.Len() != n { 154 panic(matrix.ErrShape) 155 } 156 if orig != lu { 157 if lu.isZero() { 158 if cap(lu.pivot) < n { 159 lu.pivot = make([]int, n) 160 } 161 lu.pivot = lu.pivot[:n] 162 if lu.lu == nil { 163 lu.lu = NewDense(n, n, nil) 164 } else { 165 lu.lu.reuseAs(n, n) 166 } 167 } else if len(lu.pivot) != n { 168 panic(matrix.ErrShape) 169 } 170 copy(lu.pivot, orig.pivot) 171 lu.lu.Copy(orig.lu) 172 } 173 174 xs := make([]float64, n) 175 ys := make([]float64, n) 176 for i := 0; i < n; i++ { 177 xs[i] = x.at(i) 178 ys[i] = y.at(i) 179 } 180 181 // Adjust for the pivoting in the LU factorization 182 for i, v := range lu.pivot { 183 xs[i], xs[v] = xs[v], xs[i] 184 } 185 186 lum := lu.lu.mat 187 omega := alpha 188 for j := 0; j < n; j++ { 189 ujj := lum.Data[j*lum.Stride+j] 190 ys[j] /= ujj 191 theta := 1 + xs[j]*ys[j]*omega 192 beta := omega * ys[j] / theta 193 gamma := omega * xs[j] 194 omega -= beta * gamma 195 lum.Data[j*lum.Stride+j] *= theta 196 for i := j + 1; i < n; i++ { 197 xs[i] -= lum.Data[i*lum.Stride+j] * xs[j] 198 tmp := ys[i] 199 ys[i] -= lum.Data[j*lum.Stride+i] * ys[j] 200 lum.Data[i*lum.Stride+j] += beta * xs[i] 201 lum.Data[j*lum.Stride+i] += gamma * tmp 202 } 203 } 204 lu.updateCond(-1) 205 } 206 207 // LFromLU extracts the lower triangular matrix from an LU factorization. 208 func (t *TriDense) LFromLU(lu *LU) { 209 _, n := lu.lu.Dims() 210 t.reuseAs(n, false) 211 // Extract the lower triangular elements. 212 for i := 0; i < n; i++ { 213 for j := 0; j < i; j++ { 214 t.mat.Data[i*t.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 215 } 216 } 217 // Set ones on the diagonal. 218 for i := 0; i < n; i++ { 219 t.mat.Data[i*t.mat.Stride+i] = 1 220 } 221 } 222 223 // UFromLU extracts the upper triangular matrix from an LU factorization. 224 func (t *TriDense) UFromLU(lu *LU) { 225 _, n := lu.lu.Dims() 226 t.reuseAs(n, true) 227 // Extract the upper triangular elements. 228 for i := 0; i < n; i++ { 229 for j := i; j < n; j++ { 230 t.mat.Data[i*t.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j] 231 } 232 } 233 } 234 235 // Permutation constructs an r×r permutation matrix with the given row swaps. 236 // A permutation matrix has exactly one element equal to one in each row and column 237 // and all other elements equal to zero. swaps[i] specifies the row with which 238 // i will be swapped, which is equivalent to the non-zero column of row i. 239 func (m *Dense) Permutation(r int, swaps []int) { 240 m.reuseAs(r, r) 241 for i := 0; i < r; i++ { 242 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r]) 243 v := swaps[i] 244 if v < 0 || v >= r { 245 panic(matrix.ErrRowAccess) 246 } 247 m.mat.Data[i*m.mat.Stride+v] = 1 248 } 249 } 250 251 // SolveLU solves a system of linear equations using the LU decomposition of a matrix. 252 // It computes 253 // A * x = b if trans == false 254 // A^T * x = b if trans == true 255 // In both cases, A is represented in LU factorized form, and the matrix x is 256 // stored into the receiver. 257 // 258 // If A is singular or near-singular a Condition error is returned. Please see 259 // the documentation for Condition for more information. 260 func (m *Dense) SolveLU(lu *LU, trans bool, b Matrix) error { 261 _, n := lu.lu.Dims() 262 br, bc := b.Dims() 263 if br != n { 264 panic(matrix.ErrShape) 265 } 266 // TODO(btracey): Should test the condition number instead of testing that 267 // the determinant is exactly zero. 268 if lu.Det() == 0 { 269 return matrix.Condition(math.Inf(1)) 270 } 271 272 m.reuseAs(n, bc) 273 bU, _ := untranspose(b) 274 var restore func() 275 if m == bU { 276 m, restore = m.isolatedWorkspace(bU) 277 defer restore() 278 } else if rm, ok := bU.(RawMatrixer); ok { 279 m.checkOverlap(rm.RawMatrix()) 280 } 281 282 m.Copy(b) 283 t := blas.NoTrans 284 if trans { 285 t = blas.Trans 286 } 287 lapack64.Getrs(t, lu.lu.mat, m.mat, lu.pivot) 288 if lu.cond > matrix.ConditionTolerance { 289 return matrix.Condition(lu.cond) 290 } 291 return nil 292 } 293 294 // SolveLUVec solves a system of linear equations using the LU decomposition of a matrix. 295 // It computes 296 // A * x = b if trans == false 297 // A^T * x = b if trans == true 298 // In both cases, A is represented in LU factorized form, and the matrix x is 299 // stored into the receiver. 300 // 301 // If A is singular or near-singular a Condition error is returned. Please see 302 // the documentation for Condition for more information. 303 func (v *Vector) SolveLUVec(lu *LU, trans bool, b *Vector) error { 304 _, n := lu.lu.Dims() 305 bn := b.Len() 306 if bn != n { 307 panic(matrix.ErrShape) 308 } 309 if v != b { 310 v.checkOverlap(b.mat) 311 } 312 // TODO(btracey): Should test the condition number instead of testing that 313 // the determinant is exactly zero. 314 if lu.Det() == 0 { 315 return matrix.Condition(math.Inf(1)) 316 } 317 318 v.reuseAs(n) 319 var restore func() 320 if v == b { 321 v, restore = v.isolatedWorkspace(b) 322 defer restore() 323 } 324 v.CopyVec(b) 325 vMat := blas64.General{ 326 Rows: n, 327 Cols: 1, 328 Stride: v.mat.Inc, 329 Data: v.mat.Data, 330 } 331 t := blas.NoTrans 332 if trans { 333 t = blas.Trans 334 } 335 lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot) 336 if lu.cond > matrix.ConditionTolerance { 337 return matrix.Condition(lu.cond) 338 } 339 return nil 340 }