github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/qr.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "github.com/jingcheng-WU/gonum/blas" 11 "github.com/jingcheng-WU/gonum/blas/blas64" 12 "github.com/jingcheng-WU/gonum/lapack" 13 "github.com/jingcheng-WU/gonum/lapack/lapack64" 14 ) 15 16 const badQR = "mat: invalid QR factorization" 17 18 // QR is a type for creating and using the QR factorization of a matrix. 19 type QR struct { 20 qr *Dense 21 tau []float64 22 cond float64 23 } 24 25 func (qr *QR) updateCond(norm lapack.MatrixNorm) { 26 // Since A = Q*R, and Q is orthogonal, we get for the condition number κ 27 // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ| 28 // = |R| |R^-1| = κ(R), 29 // where we used that fact that Q^-1 = Qᵀ. However, this assumes that 30 // the matrix norm is invariant under orthogonal transformations which 31 // is not the case for CondNorm. Hopefully the error is negligible: κ 32 // is only a qualitative measure anyway. 33 n := qr.qr.mat.Cols 34 work := getFloats(3*n, false) 35 iwork := getInts(n, false) 36 r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper) 37 v := lapack64.Trcon(norm, r.mat, work, iwork) 38 putFloats(work) 39 putInts(iwork) 40 qr.cond = 1 / v 41 } 42 43 // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR 44 // factorization always exists even if A is singular. 45 // 46 // The QR decomposition is a factorization of the matrix A such that A = Q * R. 47 // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix. 48 // Q and R can be extracted using the QTo and RTo methods. 49 func (qr *QR) Factorize(a Matrix) { 50 qr.factorize(a, CondNorm) 51 } 52 53 func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) { 54 m, n := a.Dims() 55 if m < n { 56 panic(ErrShape) 57 } 58 k := min(m, n) 59 if qr.qr == nil { 60 qr.qr = &Dense{} 61 } 62 qr.qr.CloneFrom(a) 63 work := []float64{0} 64 qr.tau = make([]float64, k) 65 lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1) 66 work = getFloats(int(work[0]), false) 67 lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work)) 68 putFloats(work) 69 qr.updateCond(norm) 70 } 71 72 // isValid returns whether the receiver contains a factorization. 73 func (qr *QR) isValid() bool { 74 return qr.qr != nil && !qr.qr.IsEmpty() 75 } 76 77 // Cond returns the condition number for the factorized matrix. 78 // Cond will panic if the receiver does not contain a factorization. 79 func (qr *QR) Cond() float64 { 80 if !qr.isValid() { 81 panic(badQR) 82 } 83 return qr.cond 84 } 85 86 // TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal 87 // and upper triangular matrices. 88 89 // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition. 90 // 91 // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty, 92 // RTo will panic if dst is not r×c. RTo will also panic if the receiver 93 // does not contain a successful factorization. 94 func (qr *QR) RTo(dst *Dense) { 95 if !qr.isValid() { 96 panic(badQR) 97 } 98 99 r, c := qr.qr.Dims() 100 if dst.IsEmpty() { 101 dst.ReuseAs(r, c) 102 } else { 103 r2, c2 := dst.Dims() 104 if c != r2 || c != c2 { 105 panic(ErrShape) 106 } 107 } 108 109 // Disguise the QR as an upper triangular 110 t := &TriDense{ 111 mat: blas64.Triangular{ 112 N: c, 113 Stride: qr.qr.mat.Stride, 114 Data: qr.qr.mat.Data, 115 Uplo: blas.Upper, 116 Diag: blas.NonUnit, 117 }, 118 cap: qr.qr.capCols, 119 } 120 dst.Copy(t) 121 122 // Zero below the triangular. 123 for i := r; i < c; i++ { 124 zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c]) 125 } 126 } 127 128 // QTo extracts the r×r orthonormal matrix Q from a QR decomposition. 129 // 130 // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty, 131 // QTo will panic if dst is not r×r. QTo will also panic if the receiver 132 // does not contain a successful factorization. 133 func (qr *QR) QTo(dst *Dense) { 134 if !qr.isValid() { 135 panic(badQR) 136 } 137 138 r, _ := qr.qr.Dims() 139 if dst.IsEmpty() { 140 dst.ReuseAs(r, r) 141 } else { 142 r2, c2 := dst.Dims() 143 if r != r2 || r != c2 { 144 panic(ErrShape) 145 } 146 dst.Zero() 147 } 148 149 // Set Q = I. 150 for i := 0; i < r*r; i += r + 1 { 151 dst.mat.Data[i] = 1 152 } 153 154 // Construct Q from the elementary reflectors. 155 work := []float64{0} 156 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1) 157 work = getFloats(int(work[0]), false) 158 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work)) 159 putFloats(work) 160 } 161 162 // SolveTo finds a minimum-norm solution to a system of linear equations defined 163 // by the matrices A and b, where A is an m×n matrix represented in its QR factorized 164 // form. If A is singular or near-singular a Condition error is returned. 165 // See the documentation for Condition for more information. 166 // 167 // The minimization problem solved depends on the input parameters. 168 // If trans == false, find X such that ||A*X - B||_2 is minimized. 169 // If trans == true, find the minimum norm solution of Aᵀ * X = B. 170 // The solution matrix, X, is stored in place into dst. 171 // SolveTo will panic if the receiver does not contain a factorization. 172 func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error { 173 if !qr.isValid() { 174 panic(badQR) 175 } 176 177 r, c := qr.qr.Dims() 178 br, bc := b.Dims() 179 180 // The QR solve algorithm stores the result in-place into the right hand side. 181 // The storage for the answer must be large enough to hold both b and x. 182 // However, this method's receiver must be the size of x. Copy b, and then 183 // copy the result into m at the end. 184 if trans { 185 if c != br { 186 panic(ErrShape) 187 } 188 dst.reuseAsNonZeroed(r, bc) 189 } else { 190 if r != br { 191 panic(ErrShape) 192 } 193 dst.reuseAsNonZeroed(c, bc) 194 } 195 // Do not need to worry about overlap between m and b because x has its own 196 // independent storage. 197 w := getWorkspace(max(r, c), bc, false) 198 w.Copy(b) 199 t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat 200 if trans { 201 ok := lapack64.Trtrs(blas.Trans, t, w.mat) 202 if !ok { 203 return Condition(math.Inf(1)) 204 } 205 for i := c; i < r; i++ { 206 zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc]) 207 } 208 work := []float64{0} 209 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1) 210 work = getFloats(int(work[0]), false) 211 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 212 putFloats(work) 213 } else { 214 work := []float64{0} 215 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1) 216 work = getFloats(int(work[0]), false) 217 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 218 putFloats(work) 219 220 ok := lapack64.Trtrs(blas.NoTrans, t, w.mat) 221 if !ok { 222 return Condition(math.Inf(1)) 223 } 224 } 225 // X was set above to be the correct size for the result. 226 dst.Copy(w) 227 putWorkspace(w) 228 if qr.cond > ConditionTolerance { 229 return Condition(qr.cond) 230 } 231 return nil 232 } 233 234 // SolveVecTo finds a minimum-norm solution to a system of linear equations, 235 // Ax = b. 236 // See QR.SolveTo for the full documentation. 237 // SolveVecTo will panic if the receiver does not contain a factorization. 238 func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 239 if !qr.isValid() { 240 panic(badQR) 241 } 242 243 r, c := qr.qr.Dims() 244 if _, bc := b.Dims(); bc != 1 { 245 panic(ErrShape) 246 } 247 248 // The Solve implementation is non-trivial, so rather than duplicate the code, 249 // instead recast the VecDenses as Dense and call the matrix code. 250 bm := Matrix(b) 251 if rv, ok := b.(RawVectorer); ok { 252 bmat := rv.RawVector() 253 if dst != b { 254 dst.checkOverlap(bmat) 255 } 256 b := VecDense{mat: bmat} 257 bm = b.asDense() 258 } 259 if trans { 260 dst.reuseAsNonZeroed(r) 261 } else { 262 dst.reuseAsNonZeroed(c) 263 } 264 return qr.SolveTo(dst.asDense(), trans, bm) 265 266 }