github.com/gopherd/gonum@v0.0.4/mat/qr.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  	"github.com/gopherd/gonum/lapack"
    13  	"github.com/gopherd/gonum/lapack/lapack64"
    14  )
    15  
    16  const badQR = "mat: invalid QR factorization"
    17  
    18  // QR is a type for creating and using the QR factorization of a matrix.
    19  type QR struct {
    20  	qr   *Dense
    21  	tau  []float64
    22  	cond float64
    23  }
    24  
    25  func (qr *QR) updateCond(norm lapack.MatrixNorm) {
    26  	// Since A = Q*R, and Q is orthogonal, we get for the condition number κ
    27  	//  κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
    28  	//        = |R| |R^-1| = κ(R),
    29  	// where we used that fact that Q^-1 = Qᵀ. However, this assumes that
    30  	// the matrix norm is invariant under orthogonal transformations which
    31  	// is not the case for CondNorm. Hopefully the error is negligible: κ
    32  	// is only a qualitative measure anyway.
    33  	n := qr.qr.mat.Cols
    34  	work := getFloat64s(3*n, false)
    35  	iwork := getInts(n, false)
    36  	r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper)
    37  	v := lapack64.Trcon(norm, r.mat, work, iwork)
    38  	putFloat64s(work)
    39  	putInts(iwork)
    40  	qr.cond = 1 / v
    41  }
    42  
    43  // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR
    44  // factorization always exists even if A is singular.
    45  //
    46  // The QR decomposition is a factorization of the matrix A such that A = Q * R.
    47  // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix.
    48  // Q and R can be extracted using the QTo and RTo methods.
    49  func (qr *QR) Factorize(a Matrix) {
    50  	qr.factorize(a, CondNorm)
    51  }
    52  
    53  func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
    54  	m, n := a.Dims()
    55  	if m < n {
    56  		panic(ErrShape)
    57  	}
    58  	k := min(m, n)
    59  	if qr.qr == nil {
    60  		qr.qr = &Dense{}
    61  	}
    62  	qr.qr.CloneFrom(a)
    63  	work := []float64{0}
    64  	qr.tau = make([]float64, k)
    65  	lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
    66  	work = getFloat64s(int(work[0]), false)
    67  	lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
    68  	putFloat64s(work)
    69  	qr.updateCond(norm)
    70  }
    71  
    72  // isValid returns whether the receiver contains a factorization.
    73  func (qr *QR) isValid() bool {
    74  	return qr.qr != nil && !qr.qr.IsEmpty()
    75  }
    76  
    77  // Cond returns the condition number for the factorized matrix.
    78  // Cond will panic if the receiver does not contain a factorization.
    79  func (qr *QR) Cond() float64 {
    80  	if !qr.isValid() {
    81  		panic(badQR)
    82  	}
    83  	return qr.cond
    84  }
    85  
// TODO(btracey): Add the "reduced" (thin) forms for extracting the m×n matrix Q
// with orthonormal columns and the n×n upper triangular matrix R.
    88  
    89  // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition.
    90  //
    91  // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty,
    92  // RTo will panic if dst is not r×c. RTo will also panic if the receiver
    93  // does not contain a successful factorization.
    94  func (qr *QR) RTo(dst *Dense) {
    95  	if !qr.isValid() {
    96  		panic(badQR)
    97  	}
    98  
    99  	r, c := qr.qr.Dims()
   100  	if dst.IsEmpty() {
   101  		dst.ReuseAs(r, c)
   102  	} else {
   103  		r2, c2 := dst.Dims()
   104  		if c != r2 || c != c2 {
   105  			panic(ErrShape)
   106  		}
   107  	}
   108  
   109  	// Disguise the QR as an upper triangular
   110  	t := &TriDense{
   111  		mat: blas64.Triangular{
   112  			N:      c,
   113  			Stride: qr.qr.mat.Stride,
   114  			Data:   qr.qr.mat.Data,
   115  			Uplo:   blas.Upper,
   116  			Diag:   blas.NonUnit,
   117  		},
   118  		cap: qr.qr.capCols,
   119  	}
   120  	dst.Copy(t)
   121  
   122  	// Zero below the triangular.
   123  	for i := r; i < c; i++ {
   124  		zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c])
   125  	}
   126  }
   127  
   128  // QTo extracts the r×r orthonormal matrix Q from a QR decomposition.
   129  //
   130  // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty,
   131  // QTo will panic if dst is not r×r. QTo will also panic if the receiver
   132  // does not contain a successful factorization.
   133  func (qr *QR) QTo(dst *Dense) {
   134  	if !qr.isValid() {
   135  		panic(badQR)
   136  	}
   137  
   138  	r, _ := qr.qr.Dims()
   139  	if dst.IsEmpty() {
   140  		dst.ReuseAs(r, r)
   141  	} else {
   142  		r2, c2 := dst.Dims()
   143  		if r != r2 || r != c2 {
   144  			panic(ErrShape)
   145  		}
   146  		dst.Zero()
   147  	}
   148  
   149  	// Set Q = I.
   150  	for i := 0; i < r*r; i += r + 1 {
   151  		dst.mat.Data[i] = 1
   152  	}
   153  
   154  	// Construct Q from the elementary reflectors.
   155  	work := []float64{0}
   156  	lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1)
   157  	work = getFloat64s(int(work[0]), false)
   158  	lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work))
   159  	putFloat64s(work)
   160  }
   161  
   162  // SolveTo finds a minimum-norm solution to a system of linear equations defined
   163  // by the matrices A and b, where A is an m×n matrix represented in its QR factorized
   164  // form. If A is singular or near-singular a Condition error is returned.
   165  // See the documentation for Condition for more information.
   166  //
   167  // The minimization problem solved depends on the input parameters.
   168  //  If trans == false, find X such that ||A*X - B||_2 is minimized.
   169  //  If trans == true, find the minimum norm solution of Aᵀ * X = B.
   170  // The solution matrix, X, is stored in place into dst.
   171  // SolveTo will panic if the receiver does not contain a factorization.
   172  func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
   173  	if !qr.isValid() {
   174  		panic(badQR)
   175  	}
   176  
   177  	r, c := qr.qr.Dims()
   178  	br, bc := b.Dims()
   179  
   180  	// The QR solve algorithm stores the result in-place into the right hand side.
   181  	// The storage for the answer must be large enough to hold both b and x.
   182  	// However, this method's receiver must be the size of x. Copy b, and then
   183  	// copy the result into m at the end.
   184  	if trans {
   185  		if c != br {
   186  			panic(ErrShape)
   187  		}
   188  		dst.reuseAsNonZeroed(r, bc)
   189  	} else {
   190  		if r != br {
   191  			panic(ErrShape)
   192  		}
   193  		dst.reuseAsNonZeroed(c, bc)
   194  	}
   195  	// Do not need to worry about overlap between m and b because x has its own
   196  	// independent storage.
   197  	w := getDenseWorkspace(max(r, c), bc, false)
   198  	w.Copy(b)
   199  	t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat
   200  	if trans {
   201  		ok := lapack64.Trtrs(blas.Trans, t, w.mat)
   202  		if !ok {
   203  			return Condition(math.Inf(1))
   204  		}
   205  		for i := c; i < r; i++ {
   206  			zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
   207  		}
   208  		work := []float64{0}
   209  		lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1)
   210  		work = getFloat64s(int(work[0]), false)
   211  		lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work))
   212  		putFloat64s(work)
   213  	} else {
   214  		work := []float64{0}
   215  		lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1)
   216  		work = getFloat64s(int(work[0]), false)
   217  		lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work))
   218  		putFloat64s(work)
   219  
   220  		ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
   221  		if !ok {
   222  			return Condition(math.Inf(1))
   223  		}
   224  	}
   225  	// X was set above to be the correct size for the result.
   226  	dst.Copy(w)
   227  	putDenseWorkspace(w)
   228  	if qr.cond > ConditionTolerance {
   229  		return Condition(qr.cond)
   230  	}
   231  	return nil
   232  }
   233  
   234  // SolveVecTo finds a minimum-norm solution to a system of linear equations,
   235  //  Ax = b.
   236  // See QR.SolveTo for the full documentation.
   237  // SolveVecTo will panic if the receiver does not contain a factorization.
   238  func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
   239  	if !qr.isValid() {
   240  		panic(badQR)
   241  	}
   242  
   243  	r, c := qr.qr.Dims()
   244  	if _, bc := b.Dims(); bc != 1 {
   245  		panic(ErrShape)
   246  	}
   247  
   248  	// The Solve implementation is non-trivial, so rather than duplicate the code,
   249  	// instead recast the VecDenses as Dense and call the matrix code.
   250  	bm := Matrix(b)
   251  	if rv, ok := b.(RawVectorer); ok {
   252  		bmat := rv.RawVector()
   253  		if dst != b {
   254  			dst.checkOverlap(bmat)
   255  		}
   256  		b := VecDense{mat: bmat}
   257  		bm = b.asDense()
   258  	}
   259  	if trans {
   260  		dst.reuseAsNonZeroed(r)
   261  	} else {
   262  		dst.reuseAsNonZeroed(c)
   263  	}
   264  	return qr.SolveTo(dst.asDense(), trans, bm)
   265  }