github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/lu.go (about)

     1  // Copyright ©2013 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat64
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gonum/blas"
    11  	"github.com/gonum/blas/blas64"
    12  	"github.com/gonum/floats"
    13  	"github.com/gonum/lapack/lapack64"
    14  	"github.com/gonum/matrix"
    15  )
    16  
    17  const badSliceLength = "mat64: improper slice length"
    18  
// LU is a type for creating and using the LU factorization of a matrix.
//
// The zero value holds no factorization; one is stored by Factorize or
// RankOne.
type LU struct {
	// lu holds L and U in combined storage: elements strictly below the
	// diagonal belong to L (whose unit diagonal is implied, not stored) and
	// elements on and above the diagonal belong to U.
	lu    *Dense
	// pivot records the row interchanges of the factorization: row i was
	// interchanged with row pivot[i]. A zero-length pivot means no
	// factorization is held.
	pivot []int
	// cond is the estimated condition number of the factorized matrix in the
	// matrix.CondNorm norm (the reciprocal of the estimate from Gecon).
	cond  float64
}
    25  
    26  // updateCond updates the stored condition number of the matrix. Norm is the
    27  // norm of the original matrix. If norm is negative it will be estimated.
    28  func (lu *LU) updateCond(norm float64) {
    29  	n := lu.lu.mat.Cols
    30  	work := make([]float64, 4*n)
    31  	iwork := make([]int, n)
    32  	if norm < 0 {
    33  		// This is an approximation. By the definition of a norm, ||AB|| <= ||A|| ||B||.
    34  		// The condition number is ||A|| || A^-1||, so this will underestimate
    35  		// the condition number somewhat.
    36  		// The norm of the original factorized matrix cannot be stored because of
    37  		// update possibilities, e.g. RankOne.
    38  		u := lu.lu.asTriDense(n, blas.NonUnit, blas.Upper)
    39  		l := lu.lu.asTriDense(n, blas.Unit, blas.Lower)
    40  		unorm := lapack64.Lantr(matrix.CondNorm, u.mat, work)
    41  		lnorm := lapack64.Lantr(matrix.CondNorm, l.mat, work)
    42  		norm = unorm * lnorm
    43  	}
    44  	v := lapack64.Gecon(matrix.CondNorm, lu.lu.mat, norm, work, iwork)
    45  	lu.cond = 1 / v
    46  }
    47  
    48  // Factorize computes the LU factorization of the square matrix a and stores the
    49  // result. The LU decomposition will complete regardless of the singularity of a.
    50  //
    51  // The LU factorization is computed with pivoting, and so really the decomposition
    52  // is a PLU decomposition where P is a permutation matrix. The individual matrix
    53  // factors can be extracted from the factorization using the Permutation method
    54  // on Dense, and the LFrom and UFrom methods on TriDense.
    55  func (lu *LU) Factorize(a Matrix) {
    56  	r, c := a.Dims()
    57  	if r != c {
    58  		panic(matrix.ErrSquare)
    59  	}
    60  	if lu.lu == nil {
    61  		lu.lu = NewDense(r, r, nil)
    62  	} else {
    63  		lu.lu.Reset()
    64  		lu.lu.reuseAs(r, r)
    65  	}
    66  	lu.lu.Copy(a)
    67  	if cap(lu.pivot) < r {
    68  		lu.pivot = make([]int, r)
    69  	}
    70  	lu.pivot = lu.pivot[:r]
    71  	work := make([]float64, r)
    72  	anorm := lapack64.Lange(matrix.CondNorm, lu.lu.mat, work)
    73  	lapack64.Getrf(lu.lu.mat, lu.pivot)
    74  	lu.updateCond(anorm)
    75  }
    76  
    77  // Reset resets the factorization so that it can be reused as the receiver of a
    78  // dimensionally restricted operation.
    79  func (lu *LU) Reset() {
    80  	if lu.lu != nil {
    81  		lu.lu.Reset()
    82  	}
    83  	lu.pivot = lu.pivot[:0]
    84  }
    85  
    86  func (lu *LU) isZero() bool {
    87  	return len(lu.pivot) == 0
    88  }
    89  
    90  // Det returns the determinant of the matrix that has been factorized. In many
    91  // expressions, using LogDet will be more numerically stable.
    92  func (lu *LU) Det() float64 {
    93  	det, sign := lu.LogDet()
    94  	return math.Exp(det) * sign
    95  }
    96  
    97  // LogDet returns the log of the determinant and the sign of the determinant
    98  // for the matrix that has been factorized. Numerical stability in product and
    99  // division expressions is generally improved by working in log space.
   100  func (lu *LU) LogDet() (det float64, sign float64) {
   101  	_, n := lu.lu.Dims()
   102  	logDiag := make([]float64, n)
   103  	sign = 1.0
   104  	for i := 0; i < n; i++ {
   105  		v := lu.lu.at(i, i)
   106  		if v < 0 {
   107  			sign *= -1
   108  		}
   109  		if lu.pivot[i] != i {
   110  			sign *= -1
   111  		}
   112  		logDiag[i] = math.Log(math.Abs(v))
   113  	}
   114  	return floats.Sum(logDiag), sign
   115  }
   116  
   117  // Pivot returns pivot indices that enable the construction of the permutation
   118  // matrix P (see Dense.Permutation). If swaps == nil, then new memory will be
   119  // allocated, otherwise the length of the input must be equal to the size of the
   120  // factorized matrix.
   121  func (lu *LU) Pivot(swaps []int) []int {
   122  	_, n := lu.lu.Dims()
   123  	if swaps == nil {
   124  		swaps = make([]int, n)
   125  	}
   126  	if len(swaps) != n {
   127  		panic(badSliceLength)
   128  	}
   129  	// Perform the inverse of the row swaps in order to find the final
   130  	// row swap position.
   131  	for i := range swaps {
   132  		swaps[i] = i
   133  	}
   134  	for i := n - 1; i >= 0; i-- {
   135  		v := lu.pivot[i]
   136  		swaps[i], swaps[v] = swaps[v], swaps[i]
   137  	}
   138  	return swaps
   139  }
   140  
// RankOne updates an LU factorization as if a rank-one update had been applied to
// the original matrix A, storing the result into the receiver. That is, if in
// the original LU decomposition P * L * U = A, in the updated decomposition
// P * L * U = A + alpha * x * y^T.
//
// orig may be the receiver itself, in which case the update is done in place.
// Otherwise the receiver must either be empty (a zero value or after Reset),
// in which case its storage is allocated or reused, or already hold a
// factorization of the same order as orig. RankOne panics with matrix.ErrShape
// if x or y does not have length n, or if a non-empty receiver does not have
// order n, where n is the order of orig.
//
// The row interchanges of orig are kept unchanged; the condition number
// estimate is recomputed from the updated factors.
func (lu *LU) RankOne(orig *LU, alpha float64, x, y *Vector) {
	// RankOne uses algorithm a1 on page 28 of "Multiple-Rank Updates to Matrix
	// Factorizations for Nonlinear Analysis and Circuit Design" by Linzhong Deng.
	// http://web.stanford.edu/group/SOL/dissertations/Linzhong-Deng-thesis.pdf
	_, n := orig.lu.Dims()
	if x.Len() != n {
		panic(matrix.ErrShape)
	}
	if y.Len() != n {
		panic(matrix.ErrShape)
	}
	// When not updating in place, copy orig's factors and pivots into the
	// receiver first, sizing its storage if it is empty.
	if orig != lu {
		if lu.isZero() {
			if cap(lu.pivot) < n {
				lu.pivot = make([]int, n)
			}
			lu.pivot = lu.pivot[:n]
			if lu.lu == nil {
				lu.lu = NewDense(n, n, nil)
			} else {
				lu.lu.reuseAs(n, n)
			}
		} else if len(lu.pivot) != n {
			panic(matrix.ErrShape)
		}
		copy(lu.pivot, orig.pivot)
		lu.lu.Copy(orig.lu)
	}

	// Scratch copies of x and y; the sweep below overwrites them, leaving the
	// caller's vectors untouched.
	xs := make([]float64, n)
	ys := make([]float64, n)
	for i := 0; i < n; i++ {
		xs[i] = x.at(i)
		ys[i] = y.at(i)
	}

	// Adjust for the pivoting in the LU factorization: apply the recorded row
	// interchanges to x, in order, so it matches the row order of L and U.
	for i, v := range lu.pivot {
		xs[i], xs[v] = xs[v], xs[i]
	}

	// lum holds L (strictly below the diagonal, unit diagonal implied) and U
	// (on and above the diagonal) in combined storage, updated in place.
	lum := lu.lu.mat
	// omega starts at alpha and is adjusted at every step of the sweep.
	omega := alpha
	for j := 0; j < n; j++ {
		// Step j updates U[j,j], column j of L and row j of U.
		ujj := lum.Data[j*lum.Stride+j]
		ys[j] /= ujj
		theta := 1 + xs[j]*ys[j]*omega
		beta := omega * ys[j] / theta
		gamma := omega * xs[j]
		omega -= beta * gamma
		lum.Data[j*lum.Stride+j] *= theta
		for i := j + 1; i < n; i++ {
			xs[i] -= lum.Data[i*lum.Stride+j] * xs[j]
			// Keep the pre-update ys[i]: the U update below needs the old
			// value, while the L update uses the already updated xs[i].
			tmp := ys[i]
			ys[i] -= lum.Data[j*lum.Stride+i] * ys[j]
			lum.Data[i*lum.Stride+j] += beta * xs[i]
			lum.Data[j*lum.Stride+i] += gamma * tmp
		}
	}
	// The norm of the updated original matrix is not available, so a
	// negative norm asks updateCond to bound it from the new factors.
	lu.updateCond(-1)
}
   206  
   207  // LFromLU extracts the lower triangular matrix from an LU factorization.
   208  func (t *TriDense) LFromLU(lu *LU) {
   209  	_, n := lu.lu.Dims()
   210  	t.reuseAs(n, false)
   211  	// Extract the lower triangular elements.
   212  	for i := 0; i < n; i++ {
   213  		for j := 0; j < i; j++ {
   214  			t.mat.Data[i*t.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
   215  		}
   216  	}
   217  	// Set ones on the diagonal.
   218  	for i := 0; i < n; i++ {
   219  		t.mat.Data[i*t.mat.Stride+i] = 1
   220  	}
   221  }
   222  
   223  // UFromLU extracts the upper triangular matrix from an LU factorization.
   224  func (t *TriDense) UFromLU(lu *LU) {
   225  	_, n := lu.lu.Dims()
   226  	t.reuseAs(n, true)
   227  	// Extract the upper triangular elements.
   228  	for i := 0; i < n; i++ {
   229  		for j := i; j < n; j++ {
   230  			t.mat.Data[i*t.mat.Stride+j] = lu.lu.mat.Data[i*lu.lu.mat.Stride+j]
   231  		}
   232  	}
   233  }
   234  
   235  // Permutation constructs an r×r permutation matrix with the given row swaps.
   236  // A permutation matrix has exactly one element equal to one in each row and column
   237  // and all other elements equal to zero. swaps[i] specifies the row with which
   238  // i will be swapped, which is equivalent to the non-zero column of row i.
   239  func (m *Dense) Permutation(r int, swaps []int) {
   240  	m.reuseAs(r, r)
   241  	for i := 0; i < r; i++ {
   242  		zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+r])
   243  		v := swaps[i]
   244  		if v < 0 || v >= r {
   245  			panic(matrix.ErrRowAccess)
   246  		}
   247  		m.mat.Data[i*m.mat.Stride+v] = 1
   248  	}
   249  }
   250  
   251  // SolveLU solves a system of linear equations using the LU decomposition of a matrix.
   252  // It computes
   253  //  A * x = b if trans == false
   254  //  A^T * x = b if trans == true
   255  // In both cases, A is represented in LU factorized form, and the matrix x is
   256  // stored into the receiver.
   257  //
   258  // If A is singular or near-singular a Condition error is returned. Please see
   259  // the documentation for Condition for more information.
   260  func (m *Dense) SolveLU(lu *LU, trans bool, b Matrix) error {
   261  	_, n := lu.lu.Dims()
   262  	br, bc := b.Dims()
   263  	if br != n {
   264  		panic(matrix.ErrShape)
   265  	}
   266  	// TODO(btracey): Should test the condition number instead of testing that
   267  	// the determinant is exactly zero.
   268  	if lu.Det() == 0 {
   269  		return matrix.Condition(math.Inf(1))
   270  	}
   271  
   272  	m.reuseAs(n, bc)
   273  	bU, _ := untranspose(b)
   274  	var restore func()
   275  	if m == bU {
   276  		m, restore = m.isolatedWorkspace(bU)
   277  		defer restore()
   278  	} else if rm, ok := bU.(RawMatrixer); ok {
   279  		m.checkOverlap(rm.RawMatrix())
   280  	}
   281  
   282  	m.Copy(b)
   283  	t := blas.NoTrans
   284  	if trans {
   285  		t = blas.Trans
   286  	}
   287  	lapack64.Getrs(t, lu.lu.mat, m.mat, lu.pivot)
   288  	if lu.cond > matrix.ConditionTolerance {
   289  		return matrix.Condition(lu.cond)
   290  	}
   291  	return nil
   292  }
   293  
   294  // SolveLUVec solves a system of linear equations using the LU decomposition of a matrix.
   295  // It computes
   296  //  A * x = b if trans == false
   297  //  A^T * x = b if trans == true
   298  // In both cases, A is represented in LU factorized form, and the matrix x is
   299  // stored into the receiver.
   300  //
   301  // If A is singular or near-singular a Condition error is returned. Please see
   302  // the documentation for Condition for more information.
   303  func (v *Vector) SolveLUVec(lu *LU, trans bool, b *Vector) error {
   304  	_, n := lu.lu.Dims()
   305  	bn := b.Len()
   306  	if bn != n {
   307  		panic(matrix.ErrShape)
   308  	}
   309  	if v != b {
   310  		v.checkOverlap(b.mat)
   311  	}
   312  	// TODO(btracey): Should test the condition number instead of testing that
   313  	// the determinant is exactly zero.
   314  	if lu.Det() == 0 {
   315  		return matrix.Condition(math.Inf(1))
   316  	}
   317  
   318  	v.reuseAs(n)
   319  	var restore func()
   320  	if v == b {
   321  		v, restore = v.isolatedWorkspace(b)
   322  		defer restore()
   323  	}
   324  	v.CopyVec(b)
   325  	vMat := blas64.General{
   326  		Rows:   n,
   327  		Cols:   1,
   328  		Stride: v.mat.Inc,
   329  		Data:   v.mat.Data,
   330  	}
   331  	t := blas.NoTrans
   332  	if trans {
   333  		t = blas.Trans
   334  	}
   335  	lapack64.Getrs(t, lu.lu.mat, vMat, lu.pivot)
   336  	if lu.cond > matrix.ConditionTolerance {
   337  		return matrix.Condition(lu.cond)
   338  	}
   339  	return nil
   340  }