gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/qr.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/lapack"
    13  	"gonum.org/v1/gonum/lapack/lapack64"
    14  )
    15  
    // badQR is the panic message used by methods that require a factorization
    // when the receiver does not hold one.
    const badQR = "mat: invalid QR factorization"

    // QR is a type for creating and using the QR factorization of a matrix.
    // A QR whose qr field is nil or empty holds no factorization; Factorize
    // populates it.
    type QR struct {
    	// qr holds the in-place Geqrf output for the factorized m×n matrix: R
    	// on and above the diagonal, and the elementary reflectors from which Q
    	// is built (consumed by Orgqr and Ormqr).
    	qr *Dense
    	// q is the explicit m×m orthonormal factor, formed by updateQ.
    	q *Dense
    	// tau holds the n reflector coefficients produced by Geqrf.
    	tau []float64
    	// cond is the condition number estimate computed by updateCond.
    	cond float64
    }
    25  
    26  // Dims returns the dimensions of the matrix.
    27  func (qr *QR) Dims() (r, c int) {
    28  	if qr.qr == nil {
    29  		return 0, 0
    30  	}
    31  	return qr.qr.Dims()
    32  }
    33  
    34  // At returns the element at row i, column j.
    35  func (qr *QR) At(i, j int) float64 {
    36  	m, n := qr.Dims()
    37  	if uint(i) >= uint(m) {
    38  		panic(ErrRowAccess)
    39  	}
    40  	if uint(j) >= uint(n) {
    41  		panic(ErrColAccess)
    42  	}
    43  
    44  	var val float64
    45  	for k := 0; k <= j; k++ {
    46  		val += qr.q.at(i, k) * qr.qr.at(k, j)
    47  	}
    48  	return val
    49  }
    50  
    51  // T performs an implicit transpose by returning the receiver inside a
    52  // Transpose.
    53  func (qr *QR) T() Matrix {
    54  	return Transpose{qr}
    55  }
    56  
    57  func (qr *QR) updateCond(norm lapack.MatrixNorm) {
    58  	// Since A = Q*R, and Q is orthogonal, we get for the condition number κ
    59  	//  κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ|
    60  	//        = |R| |R^-1| = κ(R),
    61  	// where we used that fact that Q^-1 = Qᵀ. However, this assumes that
    62  	// the matrix norm is invariant under orthogonal transformations which
    63  	// is not the case for CondNorm. Hopefully the error is negligible: κ
    64  	// is only a qualitative measure anyway.
    65  	n := qr.qr.mat.Cols
    66  	work := getFloat64s(3*n, false)
    67  	iwork := getInts(n, false)
    68  	r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper)
    69  	v := lapack64.Trcon(norm, r.mat, work, iwork)
    70  	putFloat64s(work)
    71  	putInts(iwork)
    72  	qr.cond = 1 / v
    73  }
    74  
    75  // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR
    76  // factorization always exists even if A is singular.
    77  //
    78  // The QR decomposition is a factorization of the matrix A such that A = Q * R.
    79  // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix.
    80  // Q and R can be extracted using the QTo and RTo methods.
    81  func (qr *QR) Factorize(a Matrix) {
    82  	qr.factorize(a, CondNorm)
    83  }
    84  
    85  func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) {
    86  	m, n := a.Dims()
    87  	if m < n {
    88  		panic(ErrShape)
    89  	}
    90  	if qr.qr == nil {
    91  		qr.qr = &Dense{}
    92  	}
    93  	qr.qr.CloneFrom(a)
    94  	work := []float64{0}
    95  	qr.tau = make([]float64, n)
    96  	lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1)
    97  	work = getFloat64s(int(work[0]), false)
    98  	lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work))
    99  	putFloat64s(work)
   100  	qr.updateCond(norm)
   101  	qr.updateQ()
   102  }
   103  
   104  func (qr *QR) updateQ() {
   105  	m, _ := qr.Dims()
   106  	if qr.q == nil {
   107  		qr.q = NewDense(m, m, nil)
   108  	} else {
   109  		qr.q.reuseAsNonZeroed(m, m)
   110  	}
   111  	// Construct Q from the elementary reflectors.
   112  	qr.q.Copy(qr.qr)
   113  	work := []float64{0}
   114  	lapack64.Orgqr(qr.q.mat, qr.tau, work, -1)
   115  	work = getFloat64s(int(work[0]), false)
   116  	lapack64.Orgqr(qr.q.mat, qr.tau, work, len(work))
   117  	putFloat64s(work)
   118  }
   119  
   120  // isValid returns whether the receiver contains a factorization.
   121  func (qr *QR) isValid() bool {
   122  	return qr.qr != nil && !qr.qr.IsEmpty()
   123  }
   124  
   125  // Cond returns the condition number for the factorized matrix.
   126  // Cond will panic if the receiver does not contain a factorization.
   127  func (qr *QR) Cond() float64 {
   128  	if !qr.isValid() {
   129  		panic(badQR)
   130  	}
   131  	return qr.cond
   132  }
   133  
   134  // TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal
   135  // and upper triangular matrices.
   136  
   137  // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition.
   138  //
   139  // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty,
   140  // RTo will panic if dst is not r×c. RTo will also panic if the receiver
   141  // does not contain a successful factorization.
   142  func (qr *QR) RTo(dst *Dense) {
   143  	if !qr.isValid() {
   144  		panic(badQR)
   145  	}
   146  
   147  	r, c := qr.qr.Dims()
   148  	if dst.IsEmpty() {
   149  		dst.ReuseAs(r, c)
   150  	} else {
   151  		r2, c2 := dst.Dims()
   152  		if c != r2 || c != c2 {
   153  			panic(ErrShape)
   154  		}
   155  	}
   156  
   157  	// Disguise the QR as an upper triangular
   158  	t := &TriDense{
   159  		mat: blas64.Triangular{
   160  			N:      c,
   161  			Stride: qr.qr.mat.Stride,
   162  			Data:   qr.qr.mat.Data,
   163  			Uplo:   blas.Upper,
   164  			Diag:   blas.NonUnit,
   165  		},
   166  		cap: qr.qr.capCols,
   167  	}
   168  	dst.Copy(t)
   169  
   170  	// Zero below the triangular.
   171  	for i := r; i < c; i++ {
   172  		zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c])
   173  	}
   174  }
   175  
   176  // QTo extracts the r×r orthonormal matrix Q from a QR decomposition.
   177  //
   178  // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty,
   179  // QTo will panic if dst is not r×r. QTo will also panic if the receiver
   180  // does not contain a successful factorization.
   181  func (qr *QR) QTo(dst *Dense) {
   182  	if !qr.isValid() {
   183  		panic(badQR)
   184  	}
   185  
   186  	r, _ := qr.qr.Dims()
   187  	if dst.IsEmpty() {
   188  		dst.ReuseAs(r, r)
   189  	} else {
   190  		r2, c2 := dst.Dims()
   191  		if r != r2 || r != c2 {
   192  			panic(ErrShape)
   193  		}
   194  	}
   195  	dst.Copy(qr.q)
   196  }
   197  
   198  // SolveTo finds a minimum-norm solution to a system of linear equations defined
   199  // by the matrices A and b, where A is an m×n matrix represented in its QR factorized
   200  // form. If A is singular or near-singular a Condition error is returned.
   201  // See the documentation for Condition for more information.
   202  //
   203  // The minimization problem solved depends on the input parameters.
   204  //
   205  //	If trans == false, find X such that ||A*X - B||_2 is minimized.
   206  //	If trans == true, find the minimum norm solution of Aᵀ * X = B.
   207  //
   208  // The solution matrix, X, is stored in place into dst.
   209  // SolveTo will panic if the receiver does not contain a factorization.
   210  func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error {
   211  	if !qr.isValid() {
   212  		panic(badQR)
   213  	}
   214  
   215  	r, c := qr.qr.Dims()
   216  	br, bc := b.Dims()
   217  
   218  	// The QR solve algorithm stores the result in-place into the right hand side.
   219  	// The storage for the answer must be large enough to hold both b and x.
   220  	// However, this method's receiver must be the size of x. Copy b, and then
   221  	// copy the result into m at the end.
   222  	if trans {
   223  		if c != br {
   224  			panic(ErrShape)
   225  		}
   226  		dst.reuseAsNonZeroed(r, bc)
   227  	} else {
   228  		if r != br {
   229  			panic(ErrShape)
   230  		}
   231  		dst.reuseAsNonZeroed(c, bc)
   232  	}
   233  	// Do not need to worry about overlap between m and b because x has its own
   234  	// independent storage.
   235  	w := getDenseWorkspace(max(r, c), bc, false)
   236  	w.Copy(b)
   237  	t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat
   238  	if trans {
   239  		ok := lapack64.Trtrs(blas.Trans, t, w.mat)
   240  		if !ok {
   241  			return Condition(math.Inf(1))
   242  		}
   243  		for i := c; i < r; i++ {
   244  			zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc])
   245  		}
   246  		work := []float64{0}
   247  		lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1)
   248  		work = getFloat64s(int(work[0]), false)
   249  		lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work))
   250  		putFloat64s(work)
   251  	} else {
   252  		work := []float64{0}
   253  		lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1)
   254  		work = getFloat64s(int(work[0]), false)
   255  		lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work))
   256  		putFloat64s(work)
   257  
   258  		ok := lapack64.Trtrs(blas.NoTrans, t, w.mat)
   259  		if !ok {
   260  			return Condition(math.Inf(1))
   261  		}
   262  	}
   263  	// X was set above to be the correct size for the result.
   264  	dst.Copy(w)
   265  	putDenseWorkspace(w)
   266  	if qr.cond > ConditionTolerance {
   267  		return Condition(qr.cond)
   268  	}
   269  	return nil
   270  }
   271  
   272  // SolveVecTo finds a minimum-norm solution to a system of linear equations,
   273  //
   274  //	Ax = b.
   275  //
   276  // See QR.SolveTo for the full documentation.
   277  // SolveVecTo will panic if the receiver does not contain a factorization.
   278  func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
   279  	if !qr.isValid() {
   280  		panic(badQR)
   281  	}
   282  
   283  	r, c := qr.qr.Dims()
   284  	if _, bc := b.Dims(); bc != 1 {
   285  		panic(ErrShape)
   286  	}
   287  
   288  	// The Solve implementation is non-trivial, so rather than duplicate the code,
   289  	// instead recast the VecDenses as Dense and call the matrix code.
   290  	bm := Matrix(b)
   291  	if rv, ok := b.(RawVectorer); ok {
   292  		bmat := rv.RawVector()
   293  		if dst != b {
   294  			dst.checkOverlap(bmat)
   295  		}
   296  		b := VecDense{mat: bmat}
   297  		bm = b.asDense()
   298  	}
   299  	if trans {
   300  		dst.reuseAsNonZeroed(r)
   301  	} else {
   302  		dst.reuseAsNonZeroed(c)
   303  	}
   304  	return qr.SolveTo(dst.asDense(), trans, bm)
   305  }