github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlatrs.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package native
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gonum/blas"
    11  	"github.com/gonum/blas/blas64"
    12  )
    13  
    14  // Dlatrs solves a triangular system of equations scaled to prevent overflow. It
    15  // solves
    16  //  A * x = scale * b if trans == blas.NoTrans
    17  //  A^T * x = scale * b if trans == blas.Trans
    18  // where the scale s is set for numeric stability.
    19  //
    20  // A is an n×n triangular matrix. On entry, the slice x contains the values of
    21  // of b, and on exit it contains the solution vector x.
    22  //
    23  // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
    24  // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
    25  // than or equal to the infinity norm, and greater than or equal to the one-norm
    26  // otherwise. If normin == false, then cnorm is treated as an output, and is set
    27  // to contain the 1-norm of the off-diagonal part of the j^th column of A.
    28  //
    29  // Dlatrs is an internal routine. It is exported for testing purposes.
    30  func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
    31  	if uplo != blas.Upper && uplo != blas.Lower {
    32  		panic(badUplo)
    33  	}
    34  	if trans != blas.Trans && trans != blas.NoTrans {
    35  		panic(badTrans)
    36  	}
    37  	if diag != blas.Unit && diag != blas.NonUnit {
    38  		panic(badDiag)
    39  	}
    40  	upper := uplo == blas.Upper
    41  	noTrans := trans == blas.NoTrans
    42  	nonUnit := diag == blas.NonUnit
    43  
    44  	if n < 0 {
    45  		panic(nLT0)
    46  	}
    47  	checkMatrix(n, n, a, lda)
    48  	checkVector(n, x, 1)
    49  	checkVector(n, cnorm, 1)
    50  
    51  	if n == 0 {
    52  		return 0
    53  	}
    54  	smlnum := dlamchS / dlamchP
    55  	bignum := 1 / smlnum
    56  	scale = 1
    57  	bi := blas64.Implementation()
    58  	if !normin {
    59  		if upper {
    60  			cnorm[0] = 0
    61  			for j := 1; j < n; j++ {
    62  				cnorm[j] = bi.Dasum(j, a[j:], lda)
    63  			}
    64  		} else {
    65  			for j := 0; j < n-1; j++ {
    66  				cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
    67  			}
    68  			cnorm[n-1] = 0
    69  		}
    70  	}
    71  	// Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
    72  	imax := bi.Idamax(n, cnorm, 1)
    73  	tmax := cnorm[imax]
    74  	var tscal float64
    75  	if tmax <= bignum {
    76  		tscal = 1
    77  	} else {
    78  		tscal = 1 / (smlnum * tmax)
    79  		bi.Dscal(n, tscal, cnorm, 1)
    80  	}
    81  
    82  	// Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
    83  	j := bi.Idamax(n, x, 1)
    84  	xmax := math.Abs(x[j])
    85  	xbnd := xmax
    86  	var grow float64
    87  	var jfirst, jlast, jinc int
    88  	if noTrans {
    89  		if upper {
    90  			jfirst = n - 1
    91  			jlast = -1
    92  			jinc = -1
    93  		} else {
    94  			jfirst = 0
    95  			jlast = n
    96  			jinc = 1
    97  		}
    98  		// Compute the growth in A * x = b.
    99  		if tscal != 1 {
   100  			grow = 0
   101  			goto Solve
   102  		}
   103  		if nonUnit {
   104  			grow = 1 / math.Max(xbnd, smlnum)
   105  			xbnd = grow
   106  			for j := jfirst; j != jlast; j += jinc {
   107  				if grow <= smlnum {
   108  					goto Solve
   109  				}
   110  				tjj := math.Abs(a[j*lda+j])
   111  				xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
   112  				if tjj+cnorm[j] >= smlnum {
   113  					grow *= tjj / (tjj + cnorm[j])
   114  				} else {
   115  					grow = 0
   116  				}
   117  			}
   118  			grow = xbnd
   119  		} else {
   120  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   121  			for j := jfirst; j != jlast; j += jinc {
   122  				if grow <= smlnum {
   123  					goto Solve
   124  				}
   125  				grow *= 1 / (1 + cnorm[j])
   126  			}
   127  		}
   128  	} else {
   129  		if upper {
   130  			jfirst = 0
   131  			jlast = n
   132  			jinc = 1
   133  		} else {
   134  			jfirst = n - 1
   135  			jlast = -1
   136  			jinc = -1
   137  		}
   138  		if tscal != 1 {
   139  			grow = 0
   140  			goto Solve
   141  		}
   142  		if nonUnit {
   143  			grow = 1 / (math.Max(xbnd, smlnum))
   144  			xbnd = grow
   145  			for j := jfirst; j != jlast; j += jinc {
   146  				if grow <= smlnum {
   147  					goto Solve
   148  				}
   149  				xj := 1 + cnorm[j]
   150  				grow = math.Min(grow, xbnd/xj)
   151  				tjj := math.Abs(a[j*lda+j])
   152  				if xj > tjj {
   153  					xbnd *= tjj / xj
   154  				}
   155  			}
   156  			grow = math.Min(grow, xbnd)
   157  		} else {
   158  			grow = math.Min(1, 1/math.Max(xbnd, smlnum))
   159  			for j := jfirst; j != jlast; j += jinc {
   160  				if grow <= smlnum {
   161  					goto Solve
   162  				}
   163  				xj := 1 + cnorm[j]
   164  				grow /= xj
   165  			}
   166  		}
   167  	}
   168  
   169  Solve:
   170  	if grow*tscal > smlnum {
   171  		// Use the Level 2 BLAS solve if the reciprocal of the bound on
   172  		// elements of X is not too small.
   173  		bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
   174  		if tscal != 1 {
   175  			bi.Dscal(n, 1/tscal, cnorm, 1)
   176  		}
   177  		return scale
   178  	}
   179  
   180  	// Use a Level 1 BLAS solve, scaling intermediate results.
   181  	if xmax > bignum {
   182  		scale = bignum / xmax
   183  		bi.Dscal(n, scale, x, 1)
   184  		xmax = bignum
   185  	}
   186  	if noTrans {
   187  		for j := jfirst; j != jlast; j += jinc {
   188  			xj := math.Abs(x[j])
   189  			var tjj, tjjs float64
   190  			if nonUnit {
   191  				tjjs = a[j*lda+j] * tscal
   192  			} else {
   193  				tjjs = tscal
   194  				if tscal == 1 {
   195  					goto Skip1
   196  				}
   197  			}
   198  			tjj = math.Abs(tjjs)
   199  			if tjj > smlnum {
   200  				if tjj < 1 {
   201  					if xj > tjj*bignum {
   202  						rec := 1 / xj
   203  						bi.Dscal(n, rec, x, 1)
   204  						scale *= rec
   205  						xmax *= rec
   206  					}
   207  				}
   208  				x[j] /= tjjs
   209  				xj = math.Abs(x[j])
   210  			} else if tjj > 0 {
   211  				if xj > tjj*bignum {
   212  					rec := (tjj * bignum) / xj
   213  					if cnorm[j] > 1 {
   214  						rec /= cnorm[j]
   215  					}
   216  					bi.Dscal(n, rec, x, 1)
   217  					scale *= rec
   218  					xmax *= rec
   219  				}
   220  				x[j] /= tjjs
   221  				xj = math.Abs(x[j])
   222  			} else {
   223  				for i := 0; i < n; i++ {
   224  					x[i] = 0
   225  				}
   226  				x[j] = 1
   227  				xj = 1
   228  				scale = 0
   229  				xmax = 0
   230  			}
   231  		Skip1:
   232  			if xj > 1 {
   233  				rec := 1 / xj
   234  				if cnorm[j] > (bignum-xmax)*rec {
   235  					rec *= 0.5
   236  					bi.Dscal(n, rec, x, 1)
   237  					scale *= rec
   238  				}
   239  			} else if xj*cnorm[j] > bignum-xmax {
   240  				bi.Dscal(n, 0.5, x, 1)
   241  				scale *= 0.5
   242  			}
   243  			if upper {
   244  				if j > 0 {
   245  					bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
   246  					i := bi.Idamax(j, x, 1)
   247  					xmax = math.Abs(x[i])
   248  				}
   249  			} else {
   250  				if j < n-1 {
   251  					bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   252  					i := j + bi.Idamax(n-j-1, x[j+1:], 1)
   253  					xmax = math.Abs(x[i])
   254  				}
   255  			}
   256  		}
   257  	} else {
   258  		for j := jfirst; j != jlast; j += jinc {
   259  			xj := math.Abs(x[j])
   260  			uscal := tscal
   261  			rec := 1 / math.Max(xmax, 1)
   262  			var tjjs float64
   263  			if cnorm[j] > (bignum-xj)*rec {
   264  				rec *= 0.5
   265  				if nonUnit {
   266  					tjjs = a[j*lda+j] * tscal
   267  				} else {
   268  					tjjs = tscal
   269  				}
   270  				tjj := math.Abs(tjjs)
   271  				if tjj > 1 {
   272  					rec = math.Min(1, rec*tjj)
   273  					uscal /= tjjs
   274  				}
   275  				if rec < 1 {
   276  					bi.Dscal(n, rec, x, 1)
   277  					scale *= rec
   278  					xmax *= rec
   279  				}
   280  			}
   281  			var sumj float64
   282  			if uscal == 1 {
   283  				if upper {
   284  					sumj = bi.Ddot(j, a[j:], lda, x, 1)
   285  				} else if j < n-1 {
   286  					sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
   287  				}
   288  			} else {
   289  				if upper {
   290  					for i := 0; i < j; i++ {
   291  						sumj += (a[i*lda+j] * uscal) * x[i]
   292  					}
   293  				} else if j < n {
   294  					for i := j + 1; i < n; i++ {
   295  						sumj += (a[i*lda+j] * uscal) * x[i]
   296  					}
   297  				}
   298  			}
   299  			if uscal == tscal {
   300  				x[j] -= sumj
   301  				xj := math.Abs(x[j])
   302  				var tjjs float64
   303  				if nonUnit {
   304  					tjjs = a[j*lda+j] * tscal
   305  				} else {
   306  					tjjs = tscal
   307  					if tscal == 1 {
   308  						goto Skip2
   309  					}
   310  				}
   311  				tjj := math.Abs(tjjs)
   312  				if tjj > smlnum {
   313  					if tjj < 1 {
   314  						if xj > tjj*bignum {
   315  							rec = 1 / xj
   316  							bi.Dscal(n, rec, x, 1)
   317  							scale *= rec
   318  							xmax *= rec
   319  						}
   320  					}
   321  					x[j] /= tjjs
   322  				} else if tjj > 0 {
   323  					if xj > tjj*bignum {
   324  						rec = (tjj * bignum) / xj
   325  						bi.Dscal(n, rec, x, 1)
   326  						scale *= rec
   327  						xmax *= rec
   328  					}
   329  					x[j] /= tjjs
   330  				} else {
   331  					for i := 0; i < n; i++ {
   332  						x[i] = 0
   333  					}
   334  					x[j] = 1
   335  					scale = 0
   336  					xmax = 0
   337  				}
   338  			} else {
   339  				x[j] = x[j]/tjjs - sumj
   340  			}
   341  		Skip2:
   342  			xmax = math.Max(xmax, math.Abs(x[j]))
   343  		}
   344  	}
   345  	scale /= tscal
   346  	if tscal != 1 {
   347  		bi.Dscal(n, 1/tscal, cnorm, 1)
   348  	}
   349  	return scale
   350  }