gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlatrs.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/lapack"
    13  )
    14  
    15  // Dlatrs solves a triangular system of equations scaled to prevent overflow. It
    16  // solves
    17  //
    18  //	A * x = scale * b if trans == blas.NoTrans
    19  //	Aᵀ * x = scale * b if trans == blas.Trans
    20  //
    21  // where the scale s is set for numeric stability.
    22  //
    23  // A is an n×n triangular matrix. On entry, the slice x contains the values of
    24  // b, and on exit it contains the solution vector x.
    25  //
    26  // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
    27  // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
    28  // than or equal to the infinity norm, and greater than or equal to the one-norm
    29  // otherwise. If normin == false, then cnorm is treated as an output, and is set
    30  // to contain the 1-norm of the off-diagonal part of the j^th column of A.
    31  //
    32  // Dlatrs is an internal routine. It is exported for testing purposes.
    33  func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
    34  	switch {
    35  	case uplo != blas.Upper && uplo != blas.Lower:
    36  		panic(badUplo)
    37  	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
    38  		panic(badTrans)
    39  	case diag != blas.Unit && diag != blas.NonUnit:
    40  		panic(badDiag)
    41  	case n < 0:
    42  		panic(nLT0)
    43  	case lda < max(1, n):
    44  		panic(badLdA)
    45  	}
    46  
    47  	// Quick return if possible.
    48  	if n == 0 {
    49  		return 1
    50  	}
    51  
    52  	switch {
    53  	case len(a) < (n-1)*lda+n:
    54  		panic(shortA)
    55  	case len(x) < n:
    56  		panic(shortX)
    57  	case len(cnorm) < n:
    58  		panic(shortCNorm)
    59  	}
    60  
    61  	upper := uplo == blas.Upper
    62  	nonUnit := diag == blas.NonUnit
    63  
    64  	smlnum := dlamchS / dlamchP
    65  	bignum := 1 / smlnum
    66  	scale = 1
    67  
    68  	bi := blas64.Implementation()
    69  
    70  	if !normin {
    71  		if upper {
    72  			cnorm[0] = 0
    73  			for j := 1; j < n; j++ {
    74  				cnorm[j] = bi.Dasum(j, a[j:], lda)
    75  			}
    76  		} else {
    77  			for j := 0; j < n-1; j++ {
    78  				cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
    79  			}
    80  			cnorm[n-1] = 0
    81  		}
    82  	}
    83  	// Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
    84  	imax := bi.Idamax(n, cnorm, 1)
    85  	var tscal float64
    86  	if cnorm[imax] <= bignum {
    87  		tscal = 1
    88  	} else {
    89  		tmax := cnorm[imax]
    90  		// Avoid NaN generation if entries in cnorm exceed the overflow
    91  		// threshold.
    92  		if tmax <= math.MaxFloat64 {
    93  			// Case 1: All entries in cnorm are valid floating-point numbers.
    94  			tscal = 1 / (smlnum * tmax)
    95  			bi.Dscal(n, tscal, cnorm, 1)
    96  		} else {
    97  			// Case 2: At least one column norm of A cannot be represented as
    98  			// floating-point number. Find the offdiagonal entry A[i,j] with the
    99  			// largest absolute value. If this entry is not +/- Infinity, use
   100  			// this value as tscal.
   101  			tmax = 0
   102  			if upper {
   103  				// A is upper triangular.
   104  				for j := 1; j < n; j++ {
   105  					tmax = math.Max(impl.Dlange(lapack.MaxAbs, j, 1, a[j:], lda, nil), tmax)
   106  				}
   107  			} else {
   108  				// A is lower triangular.
   109  				for j := 0; j < n-1; j++ {
   110  					tmax = math.Max(impl.Dlange(lapack.MaxAbs, n-j-1, 1, a[(j+1)*lda+j:], lda, nil), tmax)
   111  				}
   112  			}
   113  			if tmax <= math.MaxFloat64 {
   114  				tscal = 1 / (smlnum * tmax)
   115  				for j := 0; j < n; j++ {
   116  					if cnorm[j] <= math.MaxFloat64 {
   117  						cnorm[j] *= tscal
   118  					} else {
   119  						// Recompute the 1-norm without introducing Infinity in
   120  						// the summation.
   121  						cnorm[j] = 0
   122  						if upper {
   123  							for i := 0; i < j; i++ {
   124  								cnorm[j] += tscal * math.Abs(a[i*lda+j])
   125  							}
   126  						} else {
   127  							for i := j + 1; i < n; i++ {
   128  								cnorm[j] += tscal * math.Abs(a[i*lda+j])
   129  							}
   130  						}
   131  					}
   132  				}
   133  			} else {
   134  				// At least one entry of A is not a valid floating-point entry.
   135  				// Rely on Dtrsv to propagate Inf and NaN.
   136  				bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
   137  				return
   138  			}
   139  		}
   140  	}
   141  
   142  	// Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
   143  	j := bi.Idamax(n, x, 1)
   144  	xmax := math.Abs(x[j])
   145  	xbnd := xmax
   146  	var grow float64
   147  	var jfirst, jlast, jinc int
   148  	if trans == blas.NoTrans {
   149  		if upper {
   150  			jfirst = n - 1
   151  			jlast = -1
   152  			jinc = -1
   153  		} else {
   154  			jfirst = 0
   155  			jlast = n
   156  			jinc = 1
   157  		}
   158  		// Compute the growth in A * x = b.
   159  		if tscal != 1 {
   160  			grow = 0
   161  			goto Solve
   162  		}
   163  		if nonUnit {
   164  			grow = 1 / math.Max(xbnd, smlnum)
   165  			xbnd = grow
   166  			for j := jfirst; j != jlast; j += jinc {
   167  				if grow <= smlnum {
   168  					goto Solve
   169  				}
   170  				tjj := math.Abs(a[j*lda+j])
   171  				xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
   172  				if tjj+cnorm[j] >= smlnum {
   173  					grow *= tjj / (tjj + cnorm[j])
   174  				} else {
   175  					grow = 0
   176  				}
   177  			}
   178  			grow = xbnd
   179  		} else {
   180  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   181  			for j := jfirst; j != jlast; j += jinc {
   182  				if grow <= smlnum {
   183  					goto Solve
   184  				}
   185  				grow *= 1 / (1 + cnorm[j])
   186  			}
   187  		}
   188  	} else {
   189  		if upper {
   190  			jfirst = 0
   191  			jlast = n
   192  			jinc = 1
   193  		} else {
   194  			jfirst = n - 1
   195  			jlast = -1
   196  			jinc = -1
   197  		}
   198  		if tscal != 1 {
   199  			grow = 0
   200  			goto Solve
   201  		}
   202  		if nonUnit {
   203  			grow = 1 / (math.Max(xbnd, smlnum))
   204  			xbnd = grow
   205  			for j := jfirst; j != jlast; j += jinc {
   206  				if grow <= smlnum {
   207  					goto Solve
   208  				}
   209  				xj := 1 + cnorm[j]
   210  				grow = math.Min(grow, xbnd/xj)
   211  				tjj := math.Abs(a[j*lda+j])
   212  				if xj > tjj {
   213  					xbnd *= tjj / xj
   214  				}
   215  			}
   216  			grow = math.Min(grow, xbnd)
   217  		} else {
   218  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   219  			for j := jfirst; j != jlast; j += jinc {
   220  				if grow <= smlnum {
   221  					goto Solve
   222  				}
   223  				xj := 1 + cnorm[j]
   224  				grow /= xj
   225  			}
   226  		}
   227  	}
   228  
   229  Solve:
   230  	if grow*tscal > smlnum {
   231  		// Use the Level 2 BLAS solve if the reciprocal of the bound on
   232  		// elements of X is not too small.
   233  		bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
   234  		if tscal != 1 {
   235  			bi.Dscal(n, 1/tscal, cnorm, 1)
   236  		}
   237  		return scale
   238  	}
   239  
   240  	// Use a Level 1 BLAS solve, scaling intermediate results.
   241  	if xmax > bignum {
   242  		scale = bignum / xmax
   243  		bi.Dscal(n, scale, x, 1)
   244  		xmax = bignum
   245  	}
   246  	if trans == blas.NoTrans {
   247  		for j := jfirst; j != jlast; j += jinc {
   248  			xj := math.Abs(x[j])
   249  			var tjj, tjjs float64
   250  			if nonUnit {
   251  				tjjs = a[j*lda+j] * tscal
   252  			} else {
   253  				tjjs = tscal
   254  				if tscal == 1 {
   255  					goto Skip1
   256  				}
   257  			}
   258  			tjj = math.Abs(tjjs)
   259  			if tjj > smlnum {
   260  				if tjj < 1 {
   261  					if xj > tjj*bignum {
   262  						rec := 1 / xj
   263  						bi.Dscal(n, rec, x, 1)
   264  						scale *= rec
   265  						xmax *= rec
   266  					}
   267  				}
   268  				x[j] /= tjjs
   269  				xj = math.Abs(x[j])
   270  			} else if tjj > 0 {
   271  				if xj > tjj*bignum {
   272  					rec := (tjj * bignum) / xj
   273  					if cnorm[j] > 1 {
   274  						rec /= cnorm[j]
   275  					}
   276  					bi.Dscal(n, rec, x, 1)
   277  					scale *= rec
   278  					xmax *= rec
   279  				}
   280  				x[j] /= tjjs
   281  				xj = math.Abs(x[j])
   282  			} else {
   283  				for i := 0; i < n; i++ {
   284  					x[i] = 0
   285  				}
   286  				x[j] = 1
   287  				xj = 1
   288  				scale = 0
   289  				xmax = 0
   290  			}
   291  		Skip1:
   292  			if xj > 1 {
   293  				rec := 1 / xj
   294  				if cnorm[j] > (bignum-xmax)*rec {
   295  					rec *= 0.5
   296  					bi.Dscal(n, rec, x, 1)
   297  					scale *= rec
   298  				}
   299  			} else if xj*cnorm[j] > bignum-xmax {
   300  				bi.Dscal(n, 0.5, x, 1)
   301  				scale *= 0.5
   302  			}
   303  			if upper {
   304  				if j > 0 {
   305  					bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
   306  					i := bi.Idamax(j, x, 1)
   307  					xmax = math.Abs(x[i])
   308  				}
   309  			} else {
   310  				if j < n-1 {
   311  					bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   312  					i := j + bi.Idamax(n-j-1, x[j+1:], 1)
   313  					xmax = math.Abs(x[i])
   314  				}
   315  			}
   316  		}
   317  	} else {
   318  		for j := jfirst; j != jlast; j += jinc {
   319  			xj := math.Abs(x[j])
   320  			uscal := tscal
   321  			rec := 1 / math.Max(xmax, 1)
   322  			var tjjs float64
   323  			if cnorm[j] > (bignum-xj)*rec {
   324  				rec *= 0.5
   325  				if nonUnit {
   326  					tjjs = a[j*lda+j] * tscal
   327  				} else {
   328  					tjjs = tscal
   329  				}
   330  				tjj := math.Abs(tjjs)
   331  				if tjj > 1 {
   332  					rec = math.Min(1, rec*tjj)
   333  					uscal /= tjjs
   334  				}
   335  				if rec < 1 {
   336  					bi.Dscal(n, rec, x, 1)
   337  					scale *= rec
   338  					xmax *= rec
   339  				}
   340  			}
   341  			var sumj float64
   342  			if uscal == 1 {
   343  				if upper {
   344  					sumj = bi.Ddot(j, a[j:], lda, x, 1)
   345  				} else if j < n-1 {
   346  					sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   347  				}
   348  			} else {
   349  				if upper {
   350  					for i := 0; i < j; i++ {
   351  						sumj += (a[i*lda+j] * uscal) * x[i]
   352  					}
   353  				} else if j < n {
   354  					for i := j + 1; i < n; i++ {
   355  						sumj += (a[i*lda+j] * uscal) * x[i]
   356  					}
   357  				}
   358  			}
   359  			if uscal == tscal {
   360  				x[j] -= sumj
   361  				xj := math.Abs(x[j])
   362  				var tjjs float64
   363  				if nonUnit {
   364  					tjjs = a[j*lda+j] * tscal
   365  				} else {
   366  					tjjs = tscal
   367  					if tscal == 1 {
   368  						goto Skip2
   369  					}
   370  				}
   371  				tjj := math.Abs(tjjs)
   372  				if tjj > smlnum {
   373  					if tjj < 1 {
   374  						if xj > tjj*bignum {
   375  							rec = 1 / xj
   376  							bi.Dscal(n, rec, x, 1)
   377  							scale *= rec
   378  							xmax *= rec
   379  						}
   380  					}
   381  					x[j] /= tjjs
   382  				} else if tjj > 0 {
   383  					if xj > tjj*bignum {
   384  						rec = (tjj * bignum) / xj
   385  						bi.Dscal(n, rec, x, 1)
   386  						scale *= rec
   387  						xmax *= rec
   388  					}
   389  					x[j] /= tjjs
   390  				} else {
   391  					for i := 0; i < n; i++ {
   392  						x[i] = 0
   393  					}
   394  					x[j] = 1
   395  					scale = 0
   396  					xmax = 0
   397  				}
   398  			} else {
   399  				x[j] = x[j]/tjjs - sumj
   400  			}
   401  		Skip2:
   402  			xmax = math.Max(xmax, math.Abs(x[j]))
   403  		}
   404  	}
   405  	scale /= tscal
   406  	if tscal != 1 {
   407  		bi.Dscal(n, 1/tscal, cnorm, 1)
   408  	}
   409  	return scale
   410  }