gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/qr.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/lapack" 13 "gonum.org/v1/gonum/lapack/lapack64" 14 ) 15 16 const badQR = "mat: invalid QR factorization" 17 18 // QR is a type for creating and using the QR factorization of a matrix. 19 type QR struct { 20 qr *Dense 21 q *Dense 22 tau []float64 23 cond float64 24 } 25 26 // Dims returns the dimensions of the matrix. 27 func (qr *QR) Dims() (r, c int) { 28 if qr.qr == nil { 29 return 0, 0 30 } 31 return qr.qr.Dims() 32 } 33 34 // At returns the element at row i, column j. 35 func (qr *QR) At(i, j int) float64 { 36 m, n := qr.Dims() 37 if uint(i) >= uint(m) { 38 panic(ErrRowAccess) 39 } 40 if uint(j) >= uint(n) { 41 panic(ErrColAccess) 42 } 43 44 var val float64 45 for k := 0; k <= j; k++ { 46 val += qr.q.at(i, k) * qr.qr.at(k, j) 47 } 48 return val 49 } 50 51 // T performs an implicit transpose by returning the receiver inside a 52 // Transpose. 53 func (qr *QR) T() Matrix { 54 return Transpose{qr} 55 } 56 57 func (qr *QR) updateCond(norm lapack.MatrixNorm) { 58 // Since A = Q*R, and Q is orthogonal, we get for the condition number κ 59 // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ| 60 // = |R| |R^-1| = κ(R), 61 // where we used that fact that Q^-1 = Qᵀ. However, this assumes that 62 // the matrix norm is invariant under orthogonal transformations which 63 // is not the case for CondNorm. Hopefully the error is negligible: κ 64 // is only a qualitative measure anyway. 65 n := qr.qr.mat.Cols 66 work := getFloat64s(3*n, false) 67 iwork := getInts(n, false) 68 r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper) 69 v := lapack64.Trcon(norm, r.mat, work, iwork) 70 putFloat64s(work) 71 putInts(iwork) 72 qr.cond = 1 / v 73 } 74 75 // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR 76 // factorization always exists even if A is singular. 77 // 78 // The QR decomposition is a factorization of the matrix A such that A = Q * R. 79 // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix. 80 // Q and R can be extracted using the QTo and RTo methods. 81 func (qr *QR) Factorize(a Matrix) { 82 qr.factorize(a, CondNorm) 83 } 84 85 func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) { 86 m, n := a.Dims() 87 if m < n { 88 panic(ErrShape) 89 } 90 if qr.qr == nil { 91 qr.qr = &Dense{} 92 } 93 qr.qr.CloneFrom(a) 94 work := []float64{0} 95 qr.tau = make([]float64, n) 96 lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1) 97 work = getFloat64s(int(work[0]), false) 98 lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work)) 99 putFloat64s(work) 100 qr.updateCond(norm) 101 qr.updateQ() 102 } 103 104 func (qr *QR) updateQ() { 105 m, _ := qr.Dims() 106 if qr.q == nil { 107 qr.q = NewDense(m, m, nil) 108 } else { 109 qr.q.reuseAsNonZeroed(m, m) 110 } 111 // Construct Q from the elementary reflectors. 112 qr.q.Copy(qr.qr) 113 work := []float64{0} 114 lapack64.Orgqr(qr.q.mat, qr.tau, work, -1) 115 work = getFloat64s(int(work[0]), false) 116 lapack64.Orgqr(qr.q.mat, qr.tau, work, len(work)) 117 putFloat64s(work) 118 } 119 120 // isValid returns whether the receiver contains a factorization. 121 func (qr *QR) isValid() bool { 122 return qr.qr != nil && !qr.qr.IsEmpty() 123 } 124 125 // Cond returns the condition number for the factorized matrix. 126 // Cond will panic if the receiver does not contain a factorization. 127 func (qr *QR) Cond() float64 { 128 if !qr.isValid() { 129 panic(badQR) 130 } 131 return qr.cond 132 } 133 134 // TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal 135 // and upper triangular matrices. 136 137 // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition. 138 // 139 // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty, 140 // RTo will panic if dst is not r×c. RTo will also panic if the receiver 141 // does not contain a successful factorization. 142 func (qr *QR) RTo(dst *Dense) { 143 if !qr.isValid() { 144 panic(badQR) 145 } 146 147 r, c := qr.qr.Dims() 148 if dst.IsEmpty() { 149 dst.ReuseAs(r, c) 150 } else { 151 r2, c2 := dst.Dims() 152 if c != r2 || c != c2 { 153 panic(ErrShape) 154 } 155 } 156 157 // Disguise the QR as an upper triangular 158 t := &TriDense{ 159 mat: blas64.Triangular{ 160 N: c, 161 Stride: qr.qr.mat.Stride, 162 Data: qr.qr.mat.Data, 163 Uplo: blas.Upper, 164 Diag: blas.NonUnit, 165 }, 166 cap: qr.qr.capCols, 167 } 168 dst.Copy(t) 169 170 // Zero below the triangular. 171 for i := r; i < c; i++ { 172 zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c]) 173 } 174 } 175 176 // QTo extracts the r×r orthonormal matrix Q from a QR decomposition. 177 // 178 // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty, 179 // QTo will panic if dst is not r×r. QTo will also panic if the receiver 180 // does not contain a successful factorization. 181 func (qr *QR) QTo(dst *Dense) { 182 if !qr.isValid() { 183 panic(badQR) 184 } 185 186 r, _ := qr.qr.Dims() 187 if dst.IsEmpty() { 188 dst.ReuseAs(r, r) 189 } else { 190 r2, c2 := dst.Dims() 191 if r != r2 || r != c2 { 192 panic(ErrShape) 193 } 194 } 195 dst.Copy(qr.q) 196 } 197 198 // SolveTo finds a minimum-norm solution to a system of linear equations defined 199 // by the matrices A and b, where A is an m×n matrix represented in its QR factorized 200 // form. If A is singular or near-singular a Condition error is returned. 201 // See the documentation for Condition for more information. 202 // 203 // The minimization problem solved depends on the input parameters. 204 // 205 // If trans == false, find X such that ||A*X - B||_2 is minimized. 206 // If trans == true, find the minimum norm solution of Aᵀ * X = B. 207 // 208 // The solution matrix, X, is stored in place into dst. 209 // SolveTo will panic if the receiver does not contain a factorization. 210 func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error { 211 if !qr.isValid() { 212 panic(badQR) 213 } 214 215 r, c := qr.qr.Dims() 216 br, bc := b.Dims() 217 218 // The QR solve algorithm stores the result in-place into the right hand side. 219 // The storage for the answer must be large enough to hold both b and x. 220 // However, this method's receiver must be the size of x. Copy b, and then 221 // copy the result into m at the end. 222 if trans { 223 if c != br { 224 panic(ErrShape) 225 } 226 dst.reuseAsNonZeroed(r, bc) 227 } else { 228 if r != br { 229 panic(ErrShape) 230 } 231 dst.reuseAsNonZeroed(c, bc) 232 } 233 // Do not need to worry about overlap between m and b because x has its own 234 // independent storage. 235 w := getDenseWorkspace(max(r, c), bc, false) 236 w.Copy(b) 237 t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat 238 if trans { 239 ok := lapack64.Trtrs(blas.Trans, t, w.mat) 240 if !ok { 241 return Condition(math.Inf(1)) 242 } 243 for i := c; i < r; i++ { 244 zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc]) 245 } 246 work := []float64{0} 247 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1) 248 work = getFloat64s(int(work[0]), false) 249 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 250 putFloat64s(work) 251 } else { 252 work := []float64{0} 253 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1) 254 work = getFloat64s(int(work[0]), false) 255 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 256 putFloat64s(work) 257 258 ok := lapack64.Trtrs(blas.NoTrans, t, w.mat) 259 if !ok { 260 return Condition(math.Inf(1)) 261 } 262 } 263 // X was set above to be the correct size for the result. 264 dst.Copy(w) 265 putDenseWorkspace(w) 266 if qr.cond > ConditionTolerance { 267 return Condition(qr.cond) 268 } 269 return nil 270 } 271 272 // SolveVecTo finds a minimum-norm solution to a system of linear equations, 273 // 274 // Ax = b. 275 // 276 // See QR.SolveTo for the full documentation. 277 // SolveVecTo will panic if the receiver does not contain a factorization. 278 func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 279 if !qr.isValid() { 280 panic(badQR) 281 } 282 283 r, c := qr.qr.Dims() 284 if _, bc := b.Dims(); bc != 1 { 285 panic(ErrShape) 286 } 287 288 // The Solve implementation is non-trivial, so rather than duplicate the code, 289 // instead recast the VecDenses as Dense and call the matrix code. 290 bm := Matrix(b) 291 if rv, ok := b.(RawVectorer); ok { 292 bmat := rv.RawVector() 293 if dst != b { 294 dst.checkOverlap(bmat) 295 } 296 b := VecDense{mat: bmat} 297 bm = b.asDense() 298 } 299 if trans { 300 dst.reuseAsNonZeroed(r) 301 } else { 302 dst.reuseAsNonZeroed(c) 303 } 304 return qr.SolveTo(dst.asDense(), trans, bm) 305 }