gonum.org/v1/gonum@v0.14.0/mat/solve.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  	"gonum.org/v1/gonum/lapack/lapack64"
    11  )
    12  
    13  // Solve solves the linear least squares problem
    14  //
    15  //	minimize over x |b - A*x|_2
    16  //
    17  // where A is an m×n matrix, b is a given m element vector and x is n element
    18  // solution vector. Solve assumes that A has full rank, that is
    19  //
    20  //	rank(A) = min(m,n)
    21  //
    22  // If m >= n, Solve finds the unique least squares solution of an overdetermined
    23  // system.
    24  //
    25  // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
    26  // this case Solve finds the unique solution of an underdetermined system that
    27  // minimizes |x|_2.
    28  //
    29  // Several right-hand side vectors b and solution vectors x can be handled in a
    30  // single call. Vectors b are stored in the columns of the m×k matrix B. Vectors
    31  // x will be stored in-place into the n×k receiver.
    32  //
    33  // If A does not have full rank, a Condition error is returned. See the
    34  // documentation for Condition for more information.
    35  func (m *Dense) Solve(a, b Matrix) error {
    36  	ar, ac := a.Dims()
    37  	br, bc := b.Dims()
    38  	if ar != br {
    39  		panic(ErrShape)
    40  	}
    41  	m.reuseAsNonZeroed(ac, bc)
    42  
    43  	// TODO(btracey): Add special cases for SymDense, etc.
    44  	aU, aTrans := untranspose(a)
    45  	bU, bTrans := untranspose(b)
    46  	switch rma := aU.(type) {
    47  	case RawTriangular:
    48  		side := blas.Left
    49  		tA := blas.NoTrans
    50  		if aTrans {
    51  			tA = blas.Trans
    52  		}
    53  
    54  		switch rm := bU.(type) {
    55  		case RawMatrixer:
    56  			if m != bU || bTrans {
    57  				if m == bU || m.checkOverlap(rm.RawMatrix()) {
    58  					tmp := getDenseWorkspace(br, bc, false)
    59  					tmp.Copy(b)
    60  					m.Copy(tmp)
    61  					putDenseWorkspace(tmp)
    62  					break
    63  				}
    64  				m.Copy(b)
    65  			}
    66  		default:
    67  			if m != bU {
    68  				m.Copy(b)
    69  			} else if bTrans {
    70  				// m and b share data so Copy cannot be used directly.
    71  				tmp := getDenseWorkspace(br, bc, false)
    72  				tmp.Copy(b)
    73  				m.Copy(tmp)
    74  				putDenseWorkspace(tmp)
    75  			}
    76  		}
    77  
    78  		rm := rma.RawTriangular()
    79  		blas64.Trsm(side, tA, 1, rm, m.mat)
    80  		work := getFloat64s(3*rm.N, false)
    81  		iwork := getInts(rm.N, false)
    82  		cond := lapack64.Trcon(CondNorm, rm, work, iwork)
    83  		putFloat64s(work)
    84  		putInts(iwork)
    85  		if cond > ConditionTolerance {
    86  			return Condition(cond)
    87  		}
    88  		return nil
    89  	}
    90  
    91  	switch {
    92  	case ar == ac:
    93  		if a == b {
    94  			// x = I.
    95  			if ar == 1 {
    96  				m.mat.Data[0] = 1
    97  				return nil
    98  			}
    99  			for i := 0; i < ar; i++ {
   100  				v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
   101  				zero(v)
   102  				v[i] = 1
   103  			}
   104  			return nil
   105  		}
   106  		var lu LU
   107  		lu.Factorize(a)
   108  		return lu.SolveTo(m, false, b)
   109  	case ar > ac:
   110  		var qr QR
   111  		qr.Factorize(a)
   112  		return qr.SolveTo(m, false, b)
   113  	default:
   114  		var lq LQ
   115  		lq.Factorize(a)
   116  		return lq.SolveTo(m, false, b)
   117  	}
   118  }
   119  
   120  // SolveVec solves the linear least squares problem
   121  //
   122  //	minimize over x |b - A*x|_2
   123  //
   124  // where A is an m×n matrix, b is a given m element vector and x is n element
   125  // solution vector. Solve assumes that A has full rank, that is
   126  //
   127  //	rank(A) = min(m,n)
   128  //
   129  // If m >= n, Solve finds the unique least squares solution of an overdetermined
   130  // system.
   131  //
   132  // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
   133  // this case Solve finds the unique solution of an underdetermined system that
   134  // minimizes |x|_2.
   135  //
   136  // The solution vector x will be stored in-place into the receiver.
   137  //
   138  // If A does not have full rank, a Condition error is returned. See the
   139  // documentation for Condition for more information.
   140  func (v *VecDense) SolveVec(a Matrix, b Vector) error {
   141  	if _, bc := b.Dims(); bc != 1 {
   142  		panic(ErrShape)
   143  	}
   144  	_, c := a.Dims()
   145  
   146  	// The Solve implementation is non-trivial, so rather than duplicate the code,
   147  	// instead recast the VecDenses as Dense and call the matrix code.
   148  
   149  	if rv, ok := b.(RawVectorer); ok {
   150  		bmat := rv.RawVector()
   151  		if v != b {
   152  			v.checkOverlap(bmat)
   153  		}
   154  		v.reuseAsNonZeroed(c)
   155  		m := v.asDense()
   156  		// We conditionally create bm as m when b and v are identical
   157  		// to prevent the overlap detection code from identifying m
   158  		// and bm as overlapping but not identical.
   159  		bm := m
   160  		if v != b {
   161  			b := VecDense{mat: bmat}
   162  			bm = b.asDense()
   163  		}
   164  		return m.Solve(a, bm)
   165  	}
   166  
   167  	v.reuseAsNonZeroed(c)
   168  	m := v.asDense()
   169  	return m.Solve(a, b)
   170  }