github.com/gopherd/gonum@v0.0.4/mat/solve.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"github.com/gopherd/gonum/blas"
     9  	"github.com/gopherd/gonum/blas/blas64"
    10  	"github.com/gopherd/gonum/lapack/lapack64"
    11  )
    12  
    13  // Solve solves the linear least squares problem
    14  //  minimize over x |b - A*x|_2
    15  // where A is an m×n matrix A, b is a given m element vector and x is n element
    16  // solution vector. Solve assumes that A has full rank, that is
    17  //  rank(A) = min(m,n)
    18  //
    19  // If m >= n, Solve finds the unique least squares solution of an overdetermined
    20  // system.
    21  //
    22  // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
    23  // this case Solve finds the unique solution of an underdetermined system that
    24  // minimizes |x|_2.
    25  //
    26  // Several right-hand side vectors b and solution vectors x can be handled in a
    27  // single call. Vectors b are stored in the columns of the m×k matrix B. Vectors
    28  // x will be stored in-place into the n×k receiver.
    29  //
    30  // If A does not have full rank, a Condition error is returned. See the
    31  // documentation for Condition for more information.
    32  func (m *Dense) Solve(a, b Matrix) error {
    33  	ar, ac := a.Dims()
    34  	br, bc := b.Dims()
    35  	if ar != br {
    36  		panic(ErrShape)
    37  	}
    38  	m.reuseAsNonZeroed(ac, bc)
    39  
    40  	// TODO(btracey): Add special cases for SymDense, etc.
    41  	aU, aTrans := untranspose(a)
    42  	bU, bTrans := untranspose(b)
    43  	switch rma := aU.(type) {
    44  	case RawTriangular:
    45  		side := blas.Left
    46  		tA := blas.NoTrans
    47  		if aTrans {
    48  			tA = blas.Trans
    49  		}
    50  
    51  		switch rm := bU.(type) {
    52  		case RawMatrixer:
    53  			if m != bU || bTrans {
    54  				if m == bU || m.checkOverlap(rm.RawMatrix()) {
    55  					tmp := getDenseWorkspace(br, bc, false)
    56  					tmp.Copy(b)
    57  					m.Copy(tmp)
    58  					putDenseWorkspace(tmp)
    59  					break
    60  				}
    61  				m.Copy(b)
    62  			}
    63  		default:
    64  			if m != bU {
    65  				m.Copy(b)
    66  			} else if bTrans {
    67  				// m and b share data so Copy cannot be used directly.
    68  				tmp := getDenseWorkspace(br, bc, false)
    69  				tmp.Copy(b)
    70  				m.Copy(tmp)
    71  				putDenseWorkspace(tmp)
    72  			}
    73  		}
    74  
    75  		rm := rma.RawTriangular()
    76  		blas64.Trsm(side, tA, 1, rm, m.mat)
    77  		work := getFloat64s(3*rm.N, false)
    78  		iwork := getInts(rm.N, false)
    79  		cond := lapack64.Trcon(CondNorm, rm, work, iwork)
    80  		putFloat64s(work)
    81  		putInts(iwork)
    82  		if cond > ConditionTolerance {
    83  			return Condition(cond)
    84  		}
    85  		return nil
    86  	}
    87  
    88  	switch {
    89  	case ar == ac:
    90  		if a == b {
    91  			// x = I.
    92  			if ar == 1 {
    93  				m.mat.Data[0] = 1
    94  				return nil
    95  			}
    96  			for i := 0; i < ar; i++ {
    97  				v := m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+ac]
    98  				zero(v)
    99  				v[i] = 1
   100  			}
   101  			return nil
   102  		}
   103  		var lu LU
   104  		lu.Factorize(a)
   105  		return lu.SolveTo(m, false, b)
   106  	case ar > ac:
   107  		var qr QR
   108  		qr.Factorize(a)
   109  		return qr.SolveTo(m, false, b)
   110  	default:
   111  		var lq LQ
   112  		lq.Factorize(a)
   113  		return lq.SolveTo(m, false, b)
   114  	}
   115  }
   116  
   117  // SolveVec solves the linear least squares problem
   118  //  minimize over x |b - A*x|_2
   119  // where A is an m×n matrix A, b is a given m element vector and x is n element
   120  // solution vector. Solve assumes that A has full rank, that is
   121  //  rank(A) = min(m,n)
   122  //
   123  // If m >= n, Solve finds the unique least squares solution of an overdetermined
   124  // system.
   125  //
   126  // If m < n, there is an infinite number of solutions that satisfy b-A*x=0. In
   127  // this case Solve finds the unique solution of an underdetermined system that
   128  // minimizes |x|_2.
   129  //
   130  // The solution vector x will be stored in-place into the receiver.
   131  //
   132  // If A does not have full rank, a Condition error is returned. See the
   133  // documentation for Condition for more information.
   134  func (v *VecDense) SolveVec(a Matrix, b Vector) error {
   135  	if _, bc := b.Dims(); bc != 1 {
   136  		panic(ErrShape)
   137  	}
   138  	_, c := a.Dims()
   139  
   140  	// The Solve implementation is non-trivial, so rather than duplicate the code,
   141  	// instead recast the VecDenses as Dense and call the matrix code.
   142  
   143  	if rv, ok := b.(RawVectorer); ok {
   144  		bmat := rv.RawVector()
   145  		if v != b {
   146  			v.checkOverlap(bmat)
   147  		}
   148  		v.reuseAsNonZeroed(c)
   149  		m := v.asDense()
   150  		// We conditionally create bm as m when b and v are identical
   151  		// to prevent the overlap detection code from identifying m
   152  		// and bm as overlapping but not identical.
   153  		bm := m
   154  		if v != b {
   155  			b := VecDense{mat: bmat}
   156  			bm = b.asDense()
   157  		}
   158  		return m.Solve(a, bm)
   159  	}
   160  
   161  	v.reuseAsNonZeroed(c)
   162  	m := v.asDense()
   163  	return m.Solve(a, b)
   164  }