github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlatdf.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas/blas64"
    11  	"github.com/gopherd/gonum/lapack"
    12  )
    13  
    14  // Dlatdf computes a contribution to the reciprocal Dif-estimate by solving
    15  //  Z * x = h - f
    16  // and choosing the vector h such that the norm of x is as large as possible.
    17  //
    18  // The n×n matrix Z is represented by its LU factorization as computed by Dgetc2
    19  // and has the form
    20  //  Z = P * L * U * Q
    21  // where P and Q are permutation matrices, L is lower triangular with unit
    22  // diagonal elements and U is upper triangular.
    23  //
    24  // job specifies the heuristic method for computing the contribution.
    25  //
    26  // If job is lapack.LocalLookAhead, all entries of h are chosen as either +1 or
    27  // -1.
    28  //
    29  // If job is lapack.NormalizedNullVector, an approximate null-vector e of Z is
    30  // computed using Dgecon and normalized. h is chosen as ±e with the sign giving
    31  // the greater value of 2-norm(x). This strategy is about 5 times as expensive
    32  // as LocalLookAhead.
    33  //
    34  // On entry, rhs holds the contribution f from earlier solved sub-systems. On
    35  // return, rhs holds the solution x.
    36  //
    37  // ipiv and jpiv contain the pivot indices as returned by Dgetc2: row i of the
    38  // matrix has been interchanged with row ipiv[i] and column j of the matrix has
    39  // been interchanged with column jpiv[j].
    40  //
    41  // n must be at most 8, ipiv and jpiv must have length n, and rhs must have
    42  // length at least n, otherwise Dlatdf will panic.
    43  //
    44  // rdsum and rdscal represent the sum of squares of computed contributions to
    45  // the Dif-estimate from earlier solved sub-systems. rdscal is the scaling
    46  // factor used to prevent overflow in rdsum. Dlatdf returns this sum of squares
    47  // updated with the contributions from the current sub-system.
    48  //
    49  // Dlatdf is an internal routine. It is exported for testing purposes.
    50  func (impl Implementation) Dlatdf(job lapack.MaximizeNormXJob, n int, z []float64, ldz int, rhs []float64, rdsum, rdscal float64, ipiv, jpiv []int) (scale, sum float64) {
    51  	switch {
    52  	case job != lapack.LocalLookAhead && job != lapack.NormalizedNullVector:
    53  		panic(badMaximizeNormXJob)
    54  	case n < 0:
    55  		panic(nLT0)
    56  	case n > 8:
    57  		panic("lapack: n > 8")
    58  	case ldz < max(1, n):
    59  		panic(badLdZ)
    60  	}
    61  
    62  	// Quick return if possible.
    63  	if n == 0 {
    64  		return
    65  	}
    66  
    67  	switch {
    68  	case len(z) < (n-1)*ldz+n:
    69  		panic(shortZ)
    70  	case len(rhs) < n:
    71  		panic(shortRHS)
    72  	case len(ipiv) != n:
    73  		panic(badLenIpiv)
    74  	case len(jpiv) != n:
    75  		panic(badLenJpiv)
    76  	}
    77  
    78  	const maxdim = 8
    79  	var (
    80  		xps   [maxdim]float64
    81  		xms   [maxdim]float64
    82  		work  [4 * maxdim]float64
    83  		iwork [maxdim]int
    84  	)
    85  	bi := blas64.Implementation()
    86  	xp := xps[:n]
    87  	xm := xms[:n]
    88  	if job == lapack.NormalizedNullVector {
    89  		// Compute approximate nullvector xm of Z.
    90  		_ = impl.Dgecon(lapack.MaxRowSum, n, z, ldz, 1, work[:], iwork[:])
    91  		// This relies on undocumented content in work[n:2*n] stored by Dgecon.
    92  		bi.Dcopy(n, work[n:], 1, xm, 1)
    93  
    94  		// Compute rhs.
    95  		impl.Dlaswp(1, xm, 1, 0, n-2, ipiv[:n-1], -1)
    96  		tmp := 1 / bi.Dnrm2(n, xm, 1)
    97  		bi.Dscal(n, tmp, xm, 1)
    98  		bi.Dcopy(n, xm, 1, xp, 1)
    99  		bi.Daxpy(n, 1, rhs, 1, xp, 1)
   100  		bi.Daxpy(n, -1.0, xm, 1, rhs, 1)
   101  		_ = impl.Dgesc2(n, z, ldz, rhs, ipiv, jpiv)
   102  		_ = impl.Dgesc2(n, z, ldz, xp, ipiv, jpiv)
   103  		if bi.Dasum(n, xp, 1) > bi.Dasum(n, rhs, 1) {
   104  			bi.Dcopy(n, xp, 1, rhs, 1)
   105  		}
   106  
   107  		// Compute and return the updated sum of squares.
   108  		return impl.Dlassq(n, rhs, 1, rdscal, rdsum)
   109  	}
   110  
   111  	// Apply permutations ipiv to rhs
   112  	impl.Dlaswp(1, rhs, 1, 0, n-2, ipiv[:n-1], 1)
   113  
   114  	// Solve for L-part choosing rhs either to +1 or -1.
   115  	pmone := -1.0
   116  	for j := 0; j < n-2; j++ {
   117  		bp := rhs[j] + 1
   118  		bm := rhs[j] - 1
   119  
   120  		// Look-ahead for L-part rhs[0:n-2] = +1 or -1, splus and sminu computed
   121  		// more efficiently than in https://doi.org/10.1109/9.29404.
   122  		splus := 1 + bi.Ddot(n-j-1, z[(j+1)*ldz+j:], ldz, z[(j+1)*ldz+j:], ldz)
   123  		sminu := bi.Ddot(n-j-1, z[(j+1)*ldz+j:], ldz, rhs[j+1:], 1)
   124  		splus *= rhs[j]
   125  		switch {
   126  		case splus > sminu:
   127  			rhs[j] = bp
   128  		case sminu > splus:
   129  			rhs[j] = bm
   130  		default:
   131  			// In this case the updating sums are equal and we can choose rsh[j]
   132  			// +1 or -1. The first time this happens we choose -1, thereafter
   133  			// +1. This is a simple way to get good estimates of matrices like
   134  			// Byers well-known example (see https://doi.org/10.1109/9.29404).
   135  			rhs[j] += pmone
   136  			pmone = 1
   137  		}
   138  
   139  		// Compute remaining rhs.
   140  		bi.Daxpy(n-j-1, -rhs[j], z[(j+1)*ldz+j:], ldz, rhs[j+1:], 1)
   141  	}
   142  
   143  	// Solve for U-part, look-ahead for rhs[n-1] = ±1. This is not done in
   144  	// Bsolve and will hopefully give us a better estimate because any
   145  	// ill-conditioning of the original matrix is transferred to U and not to L.
   146  	// U[n-1,n-1] is an approximation to sigma_min(LU).
   147  	bi.Dcopy(n-1, rhs, 1, xp, 1)
   148  	xp[n-1] = rhs[n-1] + 1
   149  	rhs[n-1] -= 1
   150  	var splus, sminu float64
   151  	for i := n - 1; i >= 0; i-- {
   152  		tmp := 1 / z[i*ldz+i]
   153  		xp[i] *= tmp
   154  		rhs[i] *= tmp
   155  		for k := i + 1; k < n; k++ {
   156  			xp[i] -= xp[k] * (z[i*ldz+k] * tmp)
   157  			rhs[i] -= rhs[k] * (z[i*ldz+k] * tmp)
   158  		}
   159  		splus += math.Abs(xp[i])
   160  		sminu += math.Abs(rhs[i])
   161  	}
   162  	if splus > sminu {
   163  		bi.Dcopy(n, xp, 1, rhs, 1)
   164  	}
   165  
   166  	// Apply the permutations jpiv to the computed solution (rhs).
   167  	impl.Dlaswp(1, rhs, 1, 0, n-2, jpiv[:n-1], -1)
   168  
   169  	// Compute and return the updated sum of squares.
   170  	return impl.Dlassq(n, rhs, 1, rdscal, rdsum)
   171  }