gonum.org/v1/gonum@v0.14.0/mat/qr.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/blas/blas64" 12 "gonum.org/v1/gonum/lapack" 13 "gonum.org/v1/gonum/lapack/lapack64" 14 ) 15 16 const badQR = "mat: invalid QR factorization" 17 18 // QR is a type for creating and using the QR factorization of a matrix. 19 type QR struct { 20 qr *Dense 21 tau []float64 22 cond float64 23 } 24 25 func (qr *QR) updateCond(norm lapack.MatrixNorm) { 26 // Since A = Q*R, and Q is orthogonal, we get for the condition number κ 27 // κ(A) := |A| |A^-1| = |Q*R| |(Q*R)^-1| = |R| |R^-1 * Qᵀ| 28 // = |R| |R^-1| = κ(R), 29 // where we used that fact that Q^-1 = Qᵀ. However, this assumes that 30 // the matrix norm is invariant under orthogonal transformations which 31 // is not the case for CondNorm. Hopefully the error is negligible: κ 32 // is only a qualitative measure anyway. 33 n := qr.qr.mat.Cols 34 work := getFloat64s(3*n, false) 35 iwork := getInts(n, false) 36 r := qr.qr.asTriDense(n, blas.NonUnit, blas.Upper) 37 v := lapack64.Trcon(norm, r.mat, work, iwork) 38 putFloat64s(work) 39 putInts(iwork) 40 qr.cond = 1 / v 41 } 42 43 // Factorize computes the QR factorization of an m×n matrix a where m >= n. The QR 44 // factorization always exists even if A is singular. 45 // 46 // The QR decomposition is a factorization of the matrix A such that A = Q * R. 47 // The matrix Q is an orthonormal m×m matrix, and R is an m×n upper triangular matrix. 48 // Q and R can be extracted using the QTo and RTo methods. 49 func (qr *QR) Factorize(a Matrix) { 50 qr.factorize(a, CondNorm) 51 } 52 53 func (qr *QR) factorize(a Matrix, norm lapack.MatrixNorm) { 54 m, n := a.Dims() 55 if m < n { 56 panic(ErrShape) 57 } 58 k := min(m, n) 59 if qr.qr == nil { 60 qr.qr = &Dense{} 61 } 62 qr.qr.CloneFrom(a) 63 work := []float64{0} 64 qr.tau = make([]float64, k) 65 lapack64.Geqrf(qr.qr.mat, qr.tau, work, -1) 66 work = getFloat64s(int(work[0]), false) 67 lapack64.Geqrf(qr.qr.mat, qr.tau, work, len(work)) 68 putFloat64s(work) 69 qr.updateCond(norm) 70 } 71 72 // isValid returns whether the receiver contains a factorization. 73 func (qr *QR) isValid() bool { 74 return qr.qr != nil && !qr.qr.IsEmpty() 75 } 76 77 // Cond returns the condition number for the factorized matrix. 78 // Cond will panic if the receiver does not contain a factorization. 79 func (qr *QR) Cond() float64 { 80 if !qr.isValid() { 81 panic(badQR) 82 } 83 return qr.cond 84 } 85 86 // TODO(btracey): Add in the "Reduced" forms for extracting the n×n orthogonal 87 // and upper triangular matrices. 88 89 // RTo extracts the m×n upper trapezoidal matrix from a QR decomposition. 90 // 91 // If dst is empty, RTo will resize dst to be r×c. When dst is non-empty, 92 // RTo will panic if dst is not r×c. RTo will also panic if the receiver 93 // does not contain a successful factorization. 94 func (qr *QR) RTo(dst *Dense) { 95 if !qr.isValid() { 96 panic(badQR) 97 } 98 99 r, c := qr.qr.Dims() 100 if dst.IsEmpty() { 101 dst.ReuseAs(r, c) 102 } else { 103 r2, c2 := dst.Dims() 104 if c != r2 || c != c2 { 105 panic(ErrShape) 106 } 107 } 108 109 // Disguise the QR as an upper triangular 110 t := &TriDense{ 111 mat: blas64.Triangular{ 112 N: c, 113 Stride: qr.qr.mat.Stride, 114 Data: qr.qr.mat.Data, 115 Uplo: blas.Upper, 116 Diag: blas.NonUnit, 117 }, 118 cap: qr.qr.capCols, 119 } 120 dst.Copy(t) 121 122 // Zero below the triangular. 123 for i := r; i < c; i++ { 124 zero(dst.mat.Data[i*dst.mat.Stride : i*dst.mat.Stride+c]) 125 } 126 } 127 128 // QTo extracts the r×r orthonormal matrix Q from a QR decomposition. 129 // 130 // If dst is empty, QTo will resize dst to be r×r. When dst is non-empty, 131 // QTo will panic if dst is not r×r. QTo will also panic if the receiver 132 // does not contain a successful factorization. 133 func (qr *QR) QTo(dst *Dense) { 134 if !qr.isValid() { 135 panic(badQR) 136 } 137 138 r, _ := qr.qr.Dims() 139 if dst.IsEmpty() { 140 dst.ReuseAs(r, r) 141 } else { 142 r2, c2 := dst.Dims() 143 if r != r2 || r != c2 { 144 panic(ErrShape) 145 } 146 dst.Zero() 147 } 148 149 // Set Q = I. 150 for i := 0; i < r*r; i += r + 1 { 151 dst.mat.Data[i] = 1 152 } 153 154 // Construct Q from the elementary reflectors. 155 work := []float64{0} 156 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, -1) 157 work = getFloat64s(int(work[0]), false) 158 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, dst.mat, work, len(work)) 159 putFloat64s(work) 160 } 161 162 // SolveTo finds a minimum-norm solution to a system of linear equations defined 163 // by the matrices A and b, where A is an m×n matrix represented in its QR factorized 164 // form. If A is singular or near-singular a Condition error is returned. 165 // See the documentation for Condition for more information. 166 // 167 // The minimization problem solved depends on the input parameters. 168 // 169 // If trans == false, find X such that ||A*X - B||_2 is minimized. 170 // If trans == true, find the minimum norm solution of Aᵀ * X = B. 171 // 172 // The solution matrix, X, is stored in place into dst. 173 // SolveTo will panic if the receiver does not contain a factorization. 174 func (qr *QR) SolveTo(dst *Dense, trans bool, b Matrix) error { 175 if !qr.isValid() { 176 panic(badQR) 177 } 178 179 r, c := qr.qr.Dims() 180 br, bc := b.Dims() 181 182 // The QR solve algorithm stores the result in-place into the right hand side. 183 // The storage for the answer must be large enough to hold both b and x. 184 // However, this method's receiver must be the size of x. Copy b, and then 185 // copy the result into m at the end. 186 if trans { 187 if c != br { 188 panic(ErrShape) 189 } 190 dst.reuseAsNonZeroed(r, bc) 191 } else { 192 if r != br { 193 panic(ErrShape) 194 } 195 dst.reuseAsNonZeroed(c, bc) 196 } 197 // Do not need to worry about overlap between m and b because x has its own 198 // independent storage. 199 w := getDenseWorkspace(max(r, c), bc, false) 200 w.Copy(b) 201 t := qr.qr.asTriDense(qr.qr.mat.Cols, blas.NonUnit, blas.Upper).mat 202 if trans { 203 ok := lapack64.Trtrs(blas.Trans, t, w.mat) 204 if !ok { 205 return Condition(math.Inf(1)) 206 } 207 for i := c; i < r; i++ { 208 zero(w.mat.Data[i*w.mat.Stride : i*w.mat.Stride+bc]) 209 } 210 work := []float64{0} 211 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, -1) 212 work = getFloat64s(int(work[0]), false) 213 lapack64.Ormqr(blas.Left, blas.NoTrans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 214 putFloat64s(work) 215 } else { 216 work := []float64{0} 217 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, -1) 218 work = getFloat64s(int(work[0]), false) 219 lapack64.Ormqr(blas.Left, blas.Trans, qr.qr.mat, qr.tau, w.mat, work, len(work)) 220 putFloat64s(work) 221 222 ok := lapack64.Trtrs(blas.NoTrans, t, w.mat) 223 if !ok { 224 return Condition(math.Inf(1)) 225 } 226 } 227 // X was set above to be the correct size for the result. 228 dst.Copy(w) 229 putDenseWorkspace(w) 230 if qr.cond > ConditionTolerance { 231 return Condition(qr.cond) 232 } 233 return nil 234 } 235 236 // SolveVecTo finds a minimum-norm solution to a system of linear equations, 237 // 238 // Ax = b. 239 // 240 // See QR.SolveTo for the full documentation. 241 // SolveVecTo will panic if the receiver does not contain a factorization. 242 func (qr *QR) SolveVecTo(dst *VecDense, trans bool, b Vector) error { 243 if !qr.isValid() { 244 panic(badQR) 245 } 246 247 r, c := qr.qr.Dims() 248 if _, bc := b.Dims(); bc != 1 { 249 panic(ErrShape) 250 } 251 252 // The Solve implementation is non-trivial, so rather than duplicate the code, 253 // instead recast the VecDenses as Dense and call the matrix code. 254 bm := Matrix(b) 255 if rv, ok := b.(RawVectorer); ok { 256 bmat := rv.RawVector() 257 if dst != b { 258 dst.checkOverlap(bmat) 259 } 260 b := VecDense{mat: bmat} 261 bm = b.asDense() 262 } 263 if trans { 264 dst.reuseAsNonZeroed(r) 265 } else { 266 dst.reuseAsNonZeroed(c) 267 } 268 return qr.SolveTo(dst.asDense(), trans, bm) 269 }