github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlatrs.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  )
    13  
    14  // Dlatrs solves a triangular system of equations scaled to prevent overflow. It
    15  // solves
    16  //  A * x = scale * b if trans == blas.NoTrans
    17  //  Aᵀ * x = scale * b if trans == blas.Trans
    18  // where the scale s is set for numeric stability.
    19  //
    20  // A is an n×n triangular matrix. On entry, the slice x contains the values of
    21  // b, and on exit it contains the solution vector x.
    22  //
    23  // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
    24  // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
    25  // than or equal to the infinity norm, and greater than or equal to the one-norm
    26  // otherwise. If normin == false, then cnorm is treated as an output, and is set
    27  // to contain the 1-norm of the off-diagonal part of the j^th column of A.
    28  //
    29  // Dlatrs is an internal routine. It is exported for testing purposes.
    30  func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
    31  	switch {
    32  	case uplo != blas.Upper && uplo != blas.Lower:
    33  		panic(badUplo)
    34  	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
    35  		panic(badTrans)
    36  	case diag != blas.Unit && diag != blas.NonUnit:
    37  		panic(badDiag)
    38  	case n < 0:
    39  		panic(nLT0)
    40  	case lda < max(1, n):
    41  		panic(badLdA)
    42  	}
    43  
    44  	// Quick return if possible.
    45  	if n == 0 {
    46  		return 0
    47  	}
    48  
    49  	switch {
    50  	case len(a) < (n-1)*lda+n:
    51  		panic(shortA)
    52  	case len(x) < n:
    53  		panic(shortX)
    54  	case len(cnorm) < n:
    55  		panic(shortCNorm)
    56  	}
    57  
    58  	upper := uplo == blas.Upper
    59  	nonUnit := diag == blas.NonUnit
    60  
    61  	smlnum := dlamchS / dlamchP
    62  	bignum := 1 / smlnum
    63  	scale = 1
    64  
    65  	bi := blas64.Implementation()
    66  
    67  	if !normin {
    68  		if upper {
    69  			cnorm[0] = 0
    70  			for j := 1; j < n; j++ {
    71  				cnorm[j] = bi.Dasum(j, a[j:], lda)
    72  			}
    73  		} else {
    74  			for j := 0; j < n-1; j++ {
    75  				cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
    76  			}
    77  			cnorm[n-1] = 0
    78  		}
    79  	}
    80  	// Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
    81  	imax := bi.Idamax(n, cnorm, 1)
    82  	tmax := cnorm[imax]
    83  	var tscal float64
    84  	if tmax <= bignum {
    85  		tscal = 1
    86  	} else {
    87  		tscal = 1 / (smlnum * tmax)
    88  		bi.Dscal(n, tscal, cnorm, 1)
    89  	}
    90  
    91  	// Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
    92  	j := bi.Idamax(n, x, 1)
    93  	xmax := math.Abs(x[j])
    94  	xbnd := xmax
    95  	var grow float64
    96  	var jfirst, jlast, jinc int
    97  	if trans == blas.NoTrans {
    98  		if upper {
    99  			jfirst = n - 1
   100  			jlast = -1
   101  			jinc = -1
   102  		} else {
   103  			jfirst = 0
   104  			jlast = n
   105  			jinc = 1
   106  		}
   107  		// Compute the growth in A * x = b.
   108  		if tscal != 1 {
   109  			grow = 0
   110  			goto Solve
   111  		}
   112  		if nonUnit {
   113  			grow = 1 / math.Max(xbnd, smlnum)
   114  			xbnd = grow
   115  			for j := jfirst; j != jlast; j += jinc {
   116  				if grow <= smlnum {
   117  					goto Solve
   118  				}
   119  				tjj := math.Abs(a[j*lda+j])
   120  				xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
   121  				if tjj+cnorm[j] >= smlnum {
   122  					grow *= tjj / (tjj + cnorm[j])
   123  				} else {
   124  					grow = 0
   125  				}
   126  			}
   127  			grow = xbnd
   128  		} else {
   129  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   130  			for j := jfirst; j != jlast; j += jinc {
   131  				if grow <= smlnum {
   132  					goto Solve
   133  				}
   134  				grow *= 1 / (1 + cnorm[j])
   135  			}
   136  		}
   137  	} else {
   138  		if upper {
   139  			jfirst = 0
   140  			jlast = n
   141  			jinc = 1
   142  		} else {
   143  			jfirst = n - 1
   144  			jlast = -1
   145  			jinc = -1
   146  		}
   147  		if tscal != 1 {
   148  			grow = 0
   149  			goto Solve
   150  		}
   151  		if nonUnit {
   152  			grow = 1 / (math.Max(xbnd, smlnum))
   153  			xbnd = grow
   154  			for j := jfirst; j != jlast; j += jinc {
   155  				if grow <= smlnum {
   156  					goto Solve
   157  				}
   158  				xj := 1 + cnorm[j]
   159  				grow = math.Min(grow, xbnd/xj)
   160  				tjj := math.Abs(a[j*lda+j])
   161  				if xj > tjj {
   162  					xbnd *= tjj / xj
   163  				}
   164  			}
   165  			grow = math.Min(grow, xbnd)
   166  		} else {
   167  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   168  			for j := jfirst; j != jlast; j += jinc {
   169  				if grow <= smlnum {
   170  					goto Solve
   171  				}
   172  				xj := 1 + cnorm[j]
   173  				grow /= xj
   174  			}
   175  		}
   176  	}
   177  
   178  Solve:
   179  	if grow*tscal > smlnum {
   180  		// Use the Level 2 BLAS solve if the reciprocal of the bound on
   181  		// elements of X is not too small.
   182  		bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
   183  		if tscal != 1 {
   184  			bi.Dscal(n, 1/tscal, cnorm, 1)
   185  		}
   186  		return scale
   187  	}
   188  
   189  	// Use a Level 1 BLAS solve, scaling intermediate results.
   190  	if xmax > bignum {
   191  		scale = bignum / xmax
   192  		bi.Dscal(n, scale, x, 1)
   193  		xmax = bignum
   194  	}
   195  	if trans == blas.NoTrans {
   196  		for j := jfirst; j != jlast; j += jinc {
   197  			xj := math.Abs(x[j])
   198  			var tjj, tjjs float64
   199  			if nonUnit {
   200  				tjjs = a[j*lda+j] * tscal
   201  			} else {
   202  				tjjs = tscal
   203  				if tscal == 1 {
   204  					goto Skip1
   205  				}
   206  			}
   207  			tjj = math.Abs(tjjs)
   208  			if tjj > smlnum {
   209  				if tjj < 1 {
   210  					if xj > tjj*bignum {
   211  						rec := 1 / xj
   212  						bi.Dscal(n, rec, x, 1)
   213  						scale *= rec
   214  						xmax *= rec
   215  					}
   216  				}
   217  				x[j] /= tjjs
   218  				xj = math.Abs(x[j])
   219  			} else if tjj > 0 {
   220  				if xj > tjj*bignum {
   221  					rec := (tjj * bignum) / xj
   222  					if cnorm[j] > 1 {
   223  						rec /= cnorm[j]
   224  					}
   225  					bi.Dscal(n, rec, x, 1)
   226  					scale *= rec
   227  					xmax *= rec
   228  				}
   229  				x[j] /= tjjs
   230  				xj = math.Abs(x[j])
   231  			} else {
   232  				for i := 0; i < n; i++ {
   233  					x[i] = 0
   234  				}
   235  				x[j] = 1
   236  				xj = 1
   237  				scale = 0
   238  				xmax = 0
   239  			}
   240  		Skip1:
   241  			if xj > 1 {
   242  				rec := 1 / xj
   243  				if cnorm[j] > (bignum-xmax)*rec {
   244  					rec *= 0.5
   245  					bi.Dscal(n, rec, x, 1)
   246  					scale *= rec
   247  				}
   248  			} else if xj*cnorm[j] > bignum-xmax {
   249  				bi.Dscal(n, 0.5, x, 1)
   250  				scale *= 0.5
   251  			}
   252  			if upper {
   253  				if j > 0 {
   254  					bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
   255  					i := bi.Idamax(j, x, 1)
   256  					xmax = math.Abs(x[i])
   257  				}
   258  			} else {
   259  				if j < n-1 {
   260  					bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   261  					i := j + bi.Idamax(n-j-1, x[j+1:], 1)
   262  					xmax = math.Abs(x[i])
   263  				}
   264  			}
   265  		}
   266  	} else {
   267  		for j := jfirst; j != jlast; j += jinc {
   268  			xj := math.Abs(x[j])
   269  			uscal := tscal
   270  			rec := 1 / math.Max(xmax, 1)
   271  			var tjjs float64
   272  			if cnorm[j] > (bignum-xj)*rec {
   273  				rec *= 0.5
   274  				if nonUnit {
   275  					tjjs = a[j*lda+j] * tscal
   276  				} else {
   277  					tjjs = tscal
   278  				}
   279  				tjj := math.Abs(tjjs)
   280  				if tjj > 1 {
   281  					rec = math.Min(1, rec*tjj)
   282  					uscal /= tjjs
   283  				}
   284  				if rec < 1 {
   285  					bi.Dscal(n, rec, x, 1)
   286  					scale *= rec
   287  					xmax *= rec
   288  				}
   289  			}
   290  			var sumj float64
   291  			if uscal == 1 {
   292  				if upper {
   293  					sumj = bi.Ddot(j, a[j:], lda, x, 1)
   294  				} else if j < n-1 {
   295  					sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   296  				}
   297  			} else {
   298  				if upper {
   299  					for i := 0; i < j; i++ {
   300  						sumj += (a[i*lda+j] * uscal) * x[i]
   301  					}
   302  				} else if j < n {
   303  					for i := j + 1; i < n; i++ {
   304  						sumj += (a[i*lda+j] * uscal) * x[i]
   305  					}
   306  				}
   307  			}
   308  			if uscal == tscal {
   309  				x[j] -= sumj
   310  				xj := math.Abs(x[j])
   311  				var tjjs float64
   312  				if nonUnit {
   313  					tjjs = a[j*lda+j] * tscal
   314  				} else {
   315  					tjjs = tscal
   316  					if tscal == 1 {
   317  						goto Skip2
   318  					}
   319  				}
   320  				tjj := math.Abs(tjjs)
   321  				if tjj > smlnum {
   322  					if tjj < 1 {
   323  						if xj > tjj*bignum {
   324  							rec = 1 / xj
   325  							bi.Dscal(n, rec, x, 1)
   326  							scale *= rec
   327  							xmax *= rec
   328  						}
   329  					}
   330  					x[j] /= tjjs
   331  				} else if tjj > 0 {
   332  					if xj > tjj*bignum {
   333  						rec = (tjj * bignum) / xj
   334  						bi.Dscal(n, rec, x, 1)
   335  						scale *= rec
   336  						xmax *= rec
   337  					}
   338  					x[j] /= tjjs
   339  				} else {
   340  					for i := 0; i < n; i++ {
   341  						x[i] = 0
   342  					}
   343  					x[j] = 1
   344  					scale = 0
   345  					xmax = 0
   346  				}
   347  			} else {
   348  				x[j] = x[j]/tjjs - sumj
   349  			}
   350  		Skip2:
   351  			xmax = math.Max(xmax, math.Abs(x[j]))
   352  		}
   353  	}
   354  	scale /= tscal
   355  	if tscal != 1 {
   356  		bi.Dscal(n, 1/tscal, cnorm, 1)
   357  	}
   358  	return scale
   359  }