github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlatbs.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  )
    13  
    14  // Dlatbs solves a triangular banded system of equations
    15  //  A * x = s*b    if trans == blas.NoTrans
    16  //  Aᵀ * x = s*b  if trans == blas.Trans or blas.ConjTrans
    17  // where A is an upper or lower triangular band matrix, x and b are n-element
    18  // vectors, and s is a scaling factor chosen so that the components of x will be
    19  // less than the overflow threshold.
    20  //
    21  // On entry, x contains the right-hand side b of the triangular system.
    22  // On return, x is overwritten by the solution vector x.
    23  //
    24  // normin specifies whether the cnorm parameter contains the column norms of A on
    25  // entry. If it is true, cnorm[j] contains the norm of the off-diagonal part of
    26  // the j-th column of A. If it is false, the norms will be computed and stored
    27  // in cnorm.
    28  //
    29  // Dlatbs returns the scaling factor s for the triangular system. If the matrix
    30  // A is singular (A[j,j]==0 for some j), then scale is set to 0 and a
    31  // non-trivial solution to A*x = 0 is returned.
        //
        // A is passed in band storage in ab with row stride ldab. The diagonal
        // element A[j,j] is read from ab[j*ldab] when uplo == blas.Upper and from
        // ab[j*ldab+kd] when uplo == blas.Lower, and at most kd off-diagonal
        // elements are referenced in each column. When diag == blas.Unit this
        // function itself does not read the stored diagonal entries.
        //
        // Dlatbs panics if uplo, trans or diag is not one of the values named
        // above, or if n < 0, kd < 0 or ldab < kd+1. If n == 0 it then returns 0
        // without inspecting ab, x or cnorm; otherwise it also panics if
        // len(ab) < (n-1)*ldab+kd+1, len(x) < n or len(cnorm) < n.
        //
        // If the largest entry of cnorm exceeds the internal overflow threshold,
        // cnorm is multiplied by a factor tscal while the routine runs and by
        // 1/tscal again before it returns.
    32  //
    33  // Dlatbs is an internal routine. It is exported for testing purposes.
    34  func (Implementation) Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x, cnorm []float64) (scale float64) {
    35  	noTran := trans == blas.NoTrans
        	// Validate the scalar arguments first; the slice lengths are checked
        	// only after the n == 0 quick return below.
    36  	switch {
    37  	case uplo != blas.Upper && uplo != blas.Lower:
    38  		panic(badUplo)
    39  	case !noTran && trans != blas.Trans && trans != blas.ConjTrans:
    40  		panic(badTrans)
    41  	case diag != blas.NonUnit && diag != blas.Unit:
    42  		panic(badDiag)
    43  	case n < 0:
    44  		panic(nLT0)
    45  	case kd < 0:
    46  		panic(kdLT0)
    47  	case ldab < kd+1:
    48  		panic(badLdA)
    49  	}
    50  
    51  	// Quick return if possible.
    52  	if n == 0 {
    53  		return 0
    54  	}
    55  
        	// From here on n >= 1, so every slice must hold at least n entries
        	// (and ab must reach the diagonal element of the last row).
    56  	switch {
    57  	case len(ab) < (n-1)*ldab+kd+1:
    58  		panic(shortAB)
    59  	case len(x) < n:
    60  		panic(shortX)
    61  	case len(cnorm) < n:
    62  		panic(shortCNorm)
    63  	}
    64  
    65  	// Parameters to control overflow.
        	// NOTE(review): dlamchS and dlamchP are defined outside this file;
        	// presumably the safe minimum and the relative machine precision —
        	// confirm. bignum = 1/smlnum is the overflow threshold used in every
        	// comparison below.
    66  	smlnum := dlamchS / dlamchP
    67  	bignum := 1 / smlnum
    68  
    69  	bi := blas64.Implementation()
        	// kld is the stride in ab between vertically adjacent elements of one
        	// column of A: moving down one row advances by ldab and steps one
        	// position back within the band row. The max keeps the stride positive
        	// when ldab == 1; in that case kd == 0 and no off-diagonal element is
        	// ever accessed.
    70  	kld := max(1, ldab-1)
    71  	if !normin {
    72  		// Compute the 1-norm of each column, not including the diagonal.
    73  		if uplo == blas.Upper {
    74  			for j := 0; j < n; j++ {
        				// Column j has jlen stored entries above the diagonal,
        				// the first of them in row j-jlen.
    75  				jlen := min(j, kd)
    76  				if jlen > 0 {
    77  					cnorm[j] = bi.Dasum(jlen, ab[(j-jlen)*ldab+jlen:], kld)
    78  				} else {
    79  					cnorm[j] = 0
    80  				}
    81  			}
    82  		} else {
    83  			for j := 0; j < n; j++ {
        				// Column j has jlen stored entries below the diagonal,
        				// the first of them in row j+1.
    84  				jlen := min(n-j-1, kd)
    85  				if jlen > 0 {
    86  					cnorm[j] = bi.Dasum(jlen, ab[(j+1)*ldab+kd-1:], kld)
    87  				} else {
    88  					cnorm[j] = 0
    89  				}
    90  			}
    91  		}
    92  	}
    93  
    94  	// Set up indices and increments for loops below.
        	// The loops start at j == jFirst, step by jInc and stop as soon as
        	// j == jLast, so jLast is an exclusive sentinel one step past the final
        	// index. maind is the offset of the diagonal element within row j of ab.
    95  	var (
    96  		jFirst, jLast, jInc int
    97  		maind               int
    98  	)
    99  	if noTran {
   100  		if uplo == blas.Upper {
        			// A * x = b with upper triangular A: last unknown first.
   101  			jFirst = n - 1
   102  			jLast = -1
   103  			jInc = -1
   104  			maind = 0
   105  		} else {
        			// A * x = b with lower triangular A: first unknown first.
   106  			jFirst = 0
   107  			jLast = n
   108  			jInc = 1
   109  			maind = kd
   110  		}
   111  	} else {
   112  		if uplo == blas.Upper {
        			// Aᵀ * x = b with upper triangular A: first unknown first.
   113  			jFirst = 0
   114  			jLast = n
   115  			jInc = 1
   116  			maind = 0
   117  		} else {
        			// Aᵀ * x = b with lower triangular A: last unknown first.
   118  			jFirst = n - 1
   119  			jLast = -1
   120  			jInc = -1
   121  			maind = kd
   122  		}
   123  	}
   124  
   125  	// Scale the column norms by tscal if the maximum element in cnorm is
   126  	// greater than bignum.
        	// No absolute value is taken of the selected entry: cnorm holds sums of
        	// absolute values when computed above and is used as supplied by the
        	// caller when normin is true.
   127  	tmax := cnorm[bi.Idamax(n, cnorm, 1)]
   128  	tscal := 1.0
   129  	if tmax > bignum {
   130  		tscal = 1 / (smlnum * tmax)
   131  		bi.Dscal(n, tscal, cnorm, 1)
   132  	}
   133  
   134  	// Compute a bound on the computed solution vector to see if the Level 2
   135  	// BLAS routine Dtbsv can be used.
   136  
   137  	xMax := math.Abs(x[bi.Idamax(n, x, 1)])
   138  	xBnd := xMax
        	// grow keeps this initial value of 0 whenever the growth computation is
        	// skipped, which makes the test after skipComputeGrow fail and selects
        	// the Level 1 path (assuming smlnum is not negative).
   139  	grow := 0.0
   140  	// Compute the growth only if the maximum element in cnorm is NOT greater
   141  	// than bignum.
   142  	if tscal != 1 {
   143  		goto skipComputeGrow
   144  	}
   145  	if noTran {
   146  		// Compute the growth in A * x = b.
   147  		if diag == blas.NonUnit {
   148  			// A is non-unit triangular.
   149  			//
   150  			// Compute grow = 1/G_j and xBnd = 1/M_j.
   151  			// Initially, G_0 = max{x(i), i=1,...,n}.
   152  			grow = 1 / math.Max(xBnd, smlnum)
   153  			xBnd = grow
   154  			for j := jFirst; j != jLast; j += jInc {
   155  				if grow <= smlnum {
   156  					// Exit the loop because the growth factor is too small.
   157  					goto skipComputeGrow
   158  				}
   159  				// M_j = G_{j-1} / abs(A[j,j])
   160  				tjj := math.Abs(ab[j*ldab+maind])
   161  				xBnd = math.Min(xBnd, math.Min(1, tjj)*grow)
   162  				if tjj+cnorm[j] >= smlnum {
   163  					// G_j = G_{j-1}*( 1 + cnorm[j] / abs(A[j,j]) )
   164  					grow *= tjj / (tjj + cnorm[j])
   165  				} else {
   166  					// G_j could overflow, set grow to 0.
   167  					grow = 0
   168  				}
   169  			}
        			// The loop ran to completion, so the bound that decides between
        			// the Level 2 and Level 1 solves is xBnd.
   170  			grow = xBnd
   171  		} else {
   172  			// A is unit triangular.
   173  			//
   174  			// Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}.
   175  			grow = math.Min(1, 1/math.Max(xBnd, smlnum))
   176  			for j := jFirst; j != jLast; j += jInc {
   177  				if grow <= smlnum {
   178  					// Exit the loop because the growth factor is too small.
   179  					goto skipComputeGrow
   180  				}
   181  				// G_j = G_{j-1}*( 1 + cnorm[j] )
   182  				grow /= 1 + cnorm[j]
   183  			}
   184  		}
   185  	} else {
   186  		// Compute the growth in Aᵀ * x = b.
   187  		if diag == blas.NonUnit {
   188  			// A is non-unit triangular.
   189  			//
   190  			// Compute grow = 1/G_j and xBnd = 1/M_j.
   191  			// Initially, G_0 = max{x(i), i=1,...,n}.
   192  			grow = 1 / math.Max(xBnd, smlnum)
   193  			xBnd = grow
   194  			for j := jFirst; j != jLast; j += jInc {
   195  				if grow <= smlnum {
   196  					// Exit the loop because the growth factor is too small.
   197  					goto skipComputeGrow
   198  				}
   199  				// G_j = max( G_{j-1}, M_{j-1}*( 1 + cnorm[j] ) )
   200  				xj := 1 + cnorm[j]
   201  				grow = math.Min(grow, xBnd/xj)
   202  				// M_j = M_{j-1}*( 1 + cnorm[j] ) / abs(A[j,j])
   203  				tjj := math.Abs(ab[j*ldab+maind])
        				// xBnd is only ever reduced: the factor tjj/xj is applied
        				// just when it is smaller than 1.
   204  				if xj > tjj {
   205  					xBnd *= tjj / xj
   206  				}
   207  			}
   208  			grow = math.Min(grow, xBnd)
   209  		} else {
   210  			// A is unit triangular.
   211  			//
   212  			// Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}.
   213  			grow = math.Min(1, 1/math.Max(xBnd, smlnum))
   214  			for j := jFirst; j != jLast; j += jInc {
   215  				if grow <= smlnum {
   216  					// Exit the loop because the growth factor is too small.
   217  					goto skipComputeGrow
   218  				}
   219  				// G_j = G_{j-1}*( 1 + cnorm[j] )
   220  				grow /= 1 + cnorm[j]
   221  			}
   222  		}
   223  	}
   224  skipComputeGrow:
   225  
   226  	if grow*tscal > smlnum {
   227  		// The reciprocal of the bound on elements of X is not too small, use
   228  		// the Level 2 BLAS solve.
   229  		bi.Dtbsv(uplo, trans, diag, n, kd, ab, ldab, x, 1)
   230  		// Scale the column norms by 1/tscal for return.
        		// NOTE(review): grow is left at 0 whenever tscal != 1, so this branch
        		// looks reachable only with tscal == 1 and the rescale below appears
        		// to be dead code — confirm before relying on it.
   231  		if tscal != 1 {
   232  			bi.Dscal(n, 1/tscal, cnorm, 1)
   233  		}
        		// No scaling of x was needed, so the scale factor is 1.
   234  		return 1
   235  	}
   236  
   237  	// Use a Level 1 BLAS solve, scaling intermediate results.
   238  
        	// scale accumulates every factor that is applied to the whole of x from
        	// here on; xMax is updated alongside x so that it stays usable as a
        	// bound in the overflow tests.
   239  	scale = 1
   240  	if xMax > bignum {
   241  		// Scale x so that its components are less than or equal to bignum in
   242  		// absolute value.
   243  		scale = bignum / xMax
   244  		bi.Dscal(n, scale, x, 1)
   245  		xMax = bignum
   246  	}
   247  
   248  	if noTran {
   249  		// Solve A * x = b.
   250  		for j := jFirst; j != jLast; j += jInc {
   251  			// Compute x[j] = b[j] / A[j,j], scaling x if necessary.
   252  			xj := math.Abs(x[j])
        			// tjjs is the (tscal-scaled) diagonal element used as divisor;
        			// for a unit diagonal it is just tscal.
   253  			tjjs := tscal
   254  			if diag == blas.NonUnit {
   255  				tjjs *= ab[j*ldab+maind]
   256  			}
   257  			tjj := math.Abs(tjjs)
   258  			switch {
   259  			case tjj > smlnum:
   260  				// smlnum < abs(A[j,j])
   261  				if tjj < 1 && xj > tjj*bignum {
   262  					// Scale x by 1/abs(x[j]).
   263  					rec := 1 / xj
   264  					bi.Dscal(n, rec, x, 1)
   265  					scale *= rec
   266  					xMax *= rec
   267  				}
   268  				x[j] /= tjjs
   269  				xj = math.Abs(x[j])
   270  			case tjj > 0:
   271  				// 0 < abs(A[j,j]) <= smlnum
   272  				if xj > tjj*bignum {
   273  					// Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum to avoid
   274  					// overflow when dividing by A[j,j].
   275  					rec := tjj * bignum / xj
   276  					if cnorm[j] > 1 {
   277  						// Scale by 1/cnorm[j] to avoid overflow when
   278  						// multiplying x[j] times column j.
   279  						rec /= cnorm[j]
   280  					}
   281  					bi.Dscal(n, rec, x, 1)
   282  					scale *= rec
   283  					xMax *= rec
   284  				}
   285  				x[j] /= tjjs
   286  				xj = math.Abs(x[j])
   287  			default:
   288  				// A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and
   289  				// compute a solution to A*x = 0.
        				// The loop is not left here; the remaining iterations keep
        				// running on the reset vector.
   290  				for i := range x[:n] {
   291  					x[i] = 0
   292  				}
   293  				x[j] = 1
   294  				xj = 1
   295  				scale = 0
   296  				xMax = 0
   297  			}
   298  
   299  			// Scale x if necessary to avoid overflow when adding a multiple of
   300  			// column j of A.
   301  			switch {
   302  			case xj > 1:
   303  				rec := 1 / xj
   304  				if cnorm[j] > (bignum-xMax)*rec {
   305  					// Scale x by 1/(2*abs(x[j])).
   306  					rec *= 0.5
   307  					bi.Dscal(n, rec, x, 1)
   308  					scale *= rec
   309  				}
   310  			case xj*cnorm[j] > bignum-xMax:
   311  				// Scale x by 1/2.
   312  				bi.Dscal(n, 0.5, x, 1)
   313  				scale *= 0.5
   314  			}
   315  
   316  			if uplo == blas.Upper {
   317  				if j > 0 {
   318  					// Compute the update
   319  					//  x[max(0,j-kd):j] := x[max(0,j-kd):j] - x[j] * A[max(0,j-kd):j,j]
   320  					jlen := min(j, kd)
   321  					if jlen > 0 {
   322  						bi.Daxpy(jlen, -x[j]*tscal, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1)
   323  					}
        					// Refresh xMax from x[0:j], the entries that the
        					// remaining iterations (j-1, j-2, ...) still work on.
   324  					i := bi.Idamax(j, x, 1)
   325  					xMax = math.Abs(x[i])
   326  				}
   327  			} else if j < n-1 {
   328  				// Compute the update
   329  				//  x[j+1:min(j+kd,n)] := x[j+1:min(j+kd,n)] - x[j] * A[j+1:min(j+kd,n),j]
   330  				jlen := min(kd, n-j-1)
   331  				if jlen > 0 {
   332  					bi.Daxpy(jlen, -x[j]*tscal, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1)
   333  				}
        				// Refresh xMax from x[j+1:n], the entries that the remaining
        				// iterations (j+1, j+2, ...) still work on.
   334  				i := j + 1 + bi.Idamax(n-j-1, x[j+1:], 1)
   335  				xMax = math.Abs(x[i])
   336  			}
   337  		}
   338  	} else {
   339  		// Solve Aᵀ * x = b.
   340  		for j := jFirst; j != jLast; j += jInc {
   341  			// Compute x[j] = b[j] - sum A[k,j]*x[k].
   342  			//                       k!=j
   343  			xj := math.Abs(x[j])
   344  			tjjs := tscal
   345  			if diag == blas.NonUnit {
   346  				tjjs *= ab[j*ldab+maind]
   347  			}
   348  			tjj := math.Abs(tjjs)
   349  			rec := 1 / math.Max(1, xMax)
        			// uscal is the factor applied to column j of A inside the dot
        			// product. It starts as tscal and additionally absorbs 1/tjjs
        			// when the division by the diagonal is folded into the scaling.
   350  			uscal := tscal
   351  			if cnorm[j] > (bignum-xj)*rec {
   352  				// If x[j] could overflow, scale x by 1/(2*xMax).
   353  				rec *= 0.5
   354  				if tjj > 1 {
   355  					// Divide by A[j,j] when scaling x if abs(tjjs) > 1.
   356  					rec = math.Min(1, rec*tjj)
   357  					uscal /= tjjs
   358  				}
   359  				if rec < 1 {
   360  					bi.Dscal(n, rec, x, 1)
   361  					scale *= rec
   362  					xMax *= rec
   363  				}
   364  			}
   365  
        			// sumj stays 0 when column j has no stored off-diagonal entries
        			// (jlen == 0).
   366  			var sumj float64
   367  			if uscal == 1 {
   368  				// If the scaling needed for A in the dot product is 1, call
   369  				// Ddot to perform the dot product...
   370  				if uplo == blas.Upper {
   371  					jlen := min(j, kd)
   372  					if jlen > 0 {
   373  						sumj = bi.Ddot(jlen, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1)
   374  					}
   375  				} else {
   376  					jlen := min(n-j-1, kd)
   377  					if jlen > 0 {
   378  						sumj = bi.Ddot(jlen, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1)
   379  					}
   380  				}
   381  			} else {
   382  				// ...otherwise, use in-line code for the dot product.
        				// Each element of A is multiplied by uscal before it is
        				// multiplied by the corresponding entry of x.
   383  				if uplo == blas.Upper {
   384  					jlen := min(j, kd)
   385  					for i := 0; i < jlen; i++ {
   386  						sumj += (ab[(j-jlen+i)*ldab+jlen-i] * uscal) * x[j-jlen+i]
   387  					}
   388  				} else {
   389  					jlen := min(n-j-1, kd)
   390  					for i := 0; i < jlen; i++ {
   391  						sumj += (ab[(j+1+i)*ldab+kd-1-i] * uscal) * x[j+i+1]
   392  					}
   393  				}
   394  			}
   395  
   396  			if uscal == tscal {
   397  				// Compute x[j] := ( x[j] - sumj ) / A[j,j]
   398  				// if 1/A[j,j] was not used to scale the dot product.
   399  				x[j] -= sumj
   400  				xj = math.Abs(x[j])
   401  				// Compute x[j] = x[j] / A[j,j], scaling if necessary.
   402  				// Note: the reference implementation skips this step for blas.Unit matrices
   403  				// when tscal is equal to 1 but it complicates the logic and only saves
   404  				// the comparison and division in the first switch-case. Not skipping it
   405  				// is also consistent with the NoTrans case above.
   406  				switch {
   407  				case tjj > smlnum:
   408  					// smlnum < abs(A[j,j]):
   409  					if tjj < 1 && xj > tjj*bignum {
   410  						// Scale x by 1/abs(x[j]).
   411  						rec := 1 / xj
   412  						bi.Dscal(n, rec, x, 1)
   413  						scale *= rec
   414  						xMax *= rec
   415  					}
   416  					x[j] /= tjjs
   417  				case tjj > 0:
   418  					// 0 < abs(A[j,j]) <= smlnum:
   419  					if xj > tjj*bignum {
   420  						// Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum.
   421  						rec := (tjj * bignum) / xj
   422  						bi.Dscal(n, rec, x, 1)
   423  						scale *= rec
   424  						xMax *= rec
   425  					}
   426  					x[j] /= tjjs
   427  				default:
   428  					// A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and
   429  					// compute a solution Aᵀ * x = 0.
   430  					for i := range x[:n] {
   431  						x[i] = 0
   432  					}
   433  					x[j] = 1
   434  					scale = 0
   435  					xMax = 0
   436  				}
   437  			} else {
   438  				// Compute x[j] := x[j] / A[j,j] - sumj
   439  				// if the dot product has already been scaled by 1/A[j,j].
   440  				x[j] = x[j]/tjjs - sumj
   441  			}
        			// Unlike the NoTrans branch, xMax here is a running maximum that
        			// also covers the entry just computed.
   442  			xMax = math.Max(xMax, math.Abs(x[j]))
   443  		}
        		// Fold the column-norm scaling factor tscal into the returned scale.
        		// Only this transposed branch does so.
   444  		scale /= tscal
   445  	}
   446  
   447  	// Scale the column norms by 1/tscal for return.
   448  	if tscal != 1 {
   449  		bi.Dscal(n, 1/tscal, cnorm, 1)
   450  	}
   451  	return scale
   452  }