github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/solve.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "github.com/jingcheng-WU/gonum/blas" 9 "github.com/jingcheng-WU/gonum/blas/blas64" 10 "github.com/jingcheng-WU/gonum/lapack/lapack64" 11 ) 12 13 // Solve solves the linear least squares problem 14 // minimize over x |b - A*x|_2 15 // where A is an m×n matrix A, b is a given m element vector and x is n element 16 // solution vector. Solve assumes that A has full rank, that is 17 // rank(A) = min(m,n) 18 // 19 // If m >= n, Solve finds the unique least squares solution of an overdetermined 20 // system. 21 // 22 // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In 23 // this case Solve finds the unique solution of an underdetermined system that 24 // minimizes |x|_2. 25 // 26 // Several right-hand side vectors b and solution vectors x can be handled in a 27 // single call. Vectors b are stored in the columns of the m×k matrix B. Vectors 28 // x will be stored in-place into the n×k receiver. 29 // 30 // If A does not have full rank, a Condition error is returned. See the 31 // documentation for Condition for more information. 32 func (m *Dense) Solve(a, b Matrix) error { 33 ar, ac := a.Dims() 34 br, bc := b.Dims() 35 if ar != br { 36 panic(ErrShape) 37 } 38 m.reuseAsNonZeroed(ac, bc) 39 40 // TODO(btracey): Add special cases for SymDense, etc. 41 aU, aTrans := untranspose(a) 42 bU, bTrans := untranspose(b) 43 switch rma := aU.(type) { 44 case RawTriangular: 45 side := blas.Left 46 tA := blas.NoTrans 47 if aTrans { 48 tA = blas.Trans 49 } 50 51 switch rm := bU.(type) { 52 case RawMatrixer: 53 if m != bU || bTrans { 54 if m == bU || m.checkOverlap(rm.RawMatrix()) { 55 tmp := getWorkspace(br, bc, false) 56 tmp.Copy(b) 57 m.Copy(tmp) 58 putWorkspace(tmp) 59 break 60 } 61 m.Copy(b) 62 } 63 default: 64 if m != bU { 65 m.Copy(b) 66 } else if bTrans { 67 // m and b share data so Copy cannot be used directly. 68 tmp := getWorkspace(br, bc, false) 69 tmp.Copy(b) 70 m.Copy(tmp) 71 putWorkspace(tmp) 72 } 73 } 74 75 rm := rma.RawTriangular() 76 blas64.Trsm(side, tA, 1, rm, m.mat) 77 work := getFloats(3*rm.N, false) 78 iwork := getInts(rm.N, false) 79 cond := lapack64.Trcon(CondNorm, rm, work, iwork) 80 putFloats(work) 81 putInts(iwork) 82 if cond > ConditionTolerance { 83 return Condition(cond) 84 } 85 return nil 86 } 87 88 switch { 89 case ar == ac: 90 if a == b { 91 // x = I. 92 if ar == 1 { 93 m.mat.Data[0] = 1 94 return nil 95 } 96 for i := 0; i < ar; i++ { 97 v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac] 98 zero(v) 99 v[i] = 1 100 } 101 return nil 102 } 103 var lu LU 104 lu.Factorize(a) 105 return lu.SolveTo(m, false, b) 106 case ar > ac: 107 var qr QR 108 qr.Factorize(a) 109 return qr.SolveTo(m, false, b) 110 default: 111 var lq LQ 112 lq.Factorize(a) 113 return lq.SolveTo(m, false, b) 114 } 115 } 116 117 // SolveVec solves the linear least squares problem 118 // minimize over x |b - A*x|_2 119 // where A is an m×n matrix A, b is a given m element vector and x is n element 120 // solution vector. Solve assumes that A has full rank, that is 121 // rank(A) = min(m,n) 122 // 123 // If m >= n, Solve finds the unique least squares solution of an overdetermined 124 // system. 125 // 126 // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In 127 // this case Solve finds the unique solution of an underdetermined system that 128 // minimizes |x|_2. 129 // 130 // The solution vector x will be stored in-place into the receiver. 131 // 132 // If A does not have full rank, a Condition error is returned. See the 133 // documentation for Condition for more information. 134 func (v *VecDense) SolveVec(a Matrix, b Vector) error { 135 if _, bc := b.Dims(); bc != 1 { 136 panic(ErrShape) 137 } 138 _, c := a.Dims() 139 140 // The Solve implementation is non-trivial, so rather than duplicate the code, 141 // instead recast the VecDenses as Dense and call the matrix code. 142 143 if rv, ok := b.(RawVectorer); ok { 144 bmat := rv.RawVector() 145 if v != b { 146 v.checkOverlap(bmat) 147 } 148 v.reuseAsNonZeroed(c) 149 m := v.asDense() 150 // We conditionally create bm as m when b and v are identical 151 // to prevent the overlap detection code from identifying m 152 // and bm as overlapping but not identical. 153 bm := m 154 if v != b { 155 b := VecDense{mat: bmat} 156 bm = b.asDense() 157 } 158 return m.Solve(a, bm) 159 } 160 161 v.reuseAsNonZeroed(c) 162 m := v.asDense() 163 return m.Solve(a, b) 164 }