gonum.org/v1/gonum@v0.14.0/mat/solve.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/blas64" 10 "gonum.org/v1/gonum/lapack/lapack64" 11 ) 12 13 // Solve solves the linear least squares problem 14 // 15 // minimize over x |b - A*x|_2 16 // 17 // where A is an m×n matrix, b is a given m element vector and x is n element 18 // solution vector. Solve assumes that A has full rank, that is 19 // 20 // rank(A) = min(m,n) 21 // 22 // If m >= n, Solve finds the unique least squares solution of an overdetermined 23 // system. 24 // 25 // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In 26 // this case Solve finds the unique solution of an underdetermined system that 27 // minimizes |x|_2. 28 // 29 // Several right-hand side vectors b and solution vectors x can be handled in a 30 // single call. Vectors b are stored in the columns of the m×k matrix B. Vectors 31 // x will be stored in-place into the n×k receiver. 32 // 33 // If A does not have full rank, a Condition error is returned. See the 34 // documentation for Condition for more information. 35 func (m *Dense) Solve(a, b Matrix) error { 36 ar, ac := a.Dims() 37 br, bc := b.Dims() 38 if ar != br { 39 panic(ErrShape) 40 } 41 m.reuseAsNonZeroed(ac, bc) 42 43 // TODO(btracey): Add special cases for SymDense, etc. 44 aU, aTrans := untranspose(a) 45 bU, bTrans := untranspose(b) 46 switch rma := aU.(type) { 47 case RawTriangular: 48 side := blas.Left 49 tA := blas.NoTrans 50 if aTrans { 51 tA = blas.Trans 52 } 53 54 switch rm := bU.(type) { 55 case RawMatrixer: 56 if m != bU || bTrans { 57 if m == bU || m.checkOverlap(rm.RawMatrix()) { 58 tmp := getDenseWorkspace(br, bc, false) 59 tmp.Copy(b) 60 m.Copy(tmp) 61 putDenseWorkspace(tmp) 62 break 63 } 64 m.Copy(b) 65 } 66 default: 67 if m != bU { 68 m.Copy(b) 69 } else if bTrans { 70 // m and b share data so Copy cannot be used directly. 71 tmp := getDenseWorkspace(br, bc, false) 72 tmp.Copy(b) 73 m.Copy(tmp) 74 putDenseWorkspace(tmp) 75 } 76 } 77 78 rm := rma.RawTriangular() 79 blas64.Trsm(side, tA, 1, rm, m.mat) 80 work := getFloat64s(3*rm.N, false) 81 iwork := getInts(rm.N, false) 82 cond := lapack64.Trcon(CondNorm, rm, work, iwork) 83 putFloat64s(work) 84 putInts(iwork) 85 if cond > ConditionTolerance { 86 return Condition(cond) 87 } 88 return nil 89 } 90 91 switch { 92 case ar == ac: 93 if a == b { 94 // x = I. 95 if ar == 1 { 96 m.mat.Data[0] = 1 97 return nil 98 } 99 for i := 0; i < ar; i++ { 100 v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac] 101 zero(v) 102 v[i] = 1 103 } 104 return nil 105 } 106 var lu LU 107 lu.Factorize(a) 108 return lu.SolveTo(m, false, b) 109 case ar > ac: 110 var qr QR 111 qr.Factorize(a) 112 return qr.SolveTo(m, false, b) 113 default: 114 var lq LQ 115 lq.Factorize(a) 116 return lq.SolveTo(m, false, b) 117 } 118 } 119 120 // SolveVec solves the linear least squares problem 121 // 122 // minimize over x |b - A*x|_2 123 // 124 // where A is an m×n matrix, b is a given m element vector and x is n element 125 // solution vector. Solve assumes that A has full rank, that is 126 // 127 // rank(A) = min(m,n) 128 // 129 // If m >= n, Solve finds the unique least squares solution of an overdetermined 130 // system. 131 // 132 // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In 133 // this case Solve finds the unique solution of an underdetermined system that 134 // minimizes |x|_2. 135 // 136 // The solution vector x will be stored in-place into the receiver. 137 // 138 // If A does not have full rank, a Condition error is returned. See the 139 // documentation for Condition for more information. 140 func (v *VecDense) SolveVec(a Matrix, b Vector) error { 141 if _, bc := b.Dims(); bc != 1 { 142 panic(ErrShape) 143 } 144 _, c := a.Dims() 145 146 // The Solve implementation is non-trivial, so rather than duplicate the code, 147 // instead recast the VecDenses as Dense and call the matrix code. 148 149 if rv, ok := b.(RawVectorer); ok { 150 bmat := rv.RawVector() 151 if v != b { 152 v.checkOverlap(bmat) 153 } 154 v.reuseAsNonZeroed(c) 155 m := v.asDense() 156 // We conditionally create bm as m when b and v are identical 157 // to prevent the overlap detection code from identifying m 158 // and bm as overlapping but not identical. 159 bm := m 160 if v != b { 161 b := VecDense{mat: bmat} 162 bm = b.asDense() 163 } 164 return m.Solve(a, bm) 165 } 166 167 v.reuseAsNonZeroed(c) 168 m := v.asDense() 169 return m.Solve(a, b) 170 }