gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlatbs.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  )
    13  
    14  // Dlatbs solves a triangular banded system of equations
    15  //
    16  //	A * x = s*b    if trans == blas.NoTrans
    17  //	Aᵀ * x = s*b  if trans == blas.Trans or blas.ConjTrans
    18  //
    19  // where A is an upper or lower triangular band matrix, x and b are n-element
    20  // vectors, and s is a scaling factor chosen so that the components of x will be
    21  // less than the overflow threshold.
    22  //
    23  // A has kd off-diagonals and is stored in ab in band format with row stride
    24  // ldab. The diagonal element A[j,j] is read from ab[j*ldab] if uplo is
    25  // blas.Upper and from ab[j*ldab+kd] if uplo is blas.Lower. If diag is
    26  // blas.Unit, the scaled solve does not read the stored diagonal and uses 1.
    27  //
    28  // On entry, x contains the right-hand side b of the triangular system.
    29  // On return, x is overwritten by the solution vector x.
    30  //
    31  // normin specifies whether the cnorm parameter contains the column norms of A on
    32  // entry. If it is true, cnorm[j] contains the norm of the off-diagonal part of
    33  // the j-th column of A. If it is false, the norms will be computed and stored
    34  // in cnorm.
    35  //
    36  // Dlatbs returns the scaling factor s for the triangular system. If the matrix
    37  // A is singular (A[j,j]==0 for some j), then scale is set to 0 and a
    38  // non-trivial solution to A*x = 0 is returned.
    39  //
    40  // Dlatbs panics if uplo, trans or diag is not a valid value, if n < 0, if
    41  // kd < 0 or if ldab < kd+1. If n == 0, it returns 1 without accessing ab, x or
    42  // cnorm. Otherwise it panics if len(ab) < (n-1)*ldab+kd+1, len(x) < n or
    43  // len(cnorm) < n.
    44  //
    45  // Dlatbs is an internal routine. It is exported for testing purposes.
    46  func (Implementation) Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x, cnorm []float64) (scale float64) {
    47  	noTran := trans == blas.NoTrans
    48  	switch {
    49  	case uplo != blas.Upper && uplo != blas.Lower:
    50  		panic(badUplo)
    51  	case !noTran && trans != blas.Trans && trans != blas.ConjTrans:
    52  		panic(badTrans)
    53  	case diag != blas.NonUnit && diag != blas.Unit:
    54  		panic(badDiag)
    55  	case n < 0:
    56  		panic(nLT0)
    57  	case kd < 0:
    58  		panic(kdLT0)
    59  	case ldab < kd+1:
    60  		panic(badLdA)
    61  	}
    62  
    63  	// Quick return if possible.
    64  	if n == 0 {
    65  		return 1
    66  	}
    67  
    68  	switch {
    69  	case len(ab) < (n-1)*ldab+kd+1:
    70  		panic(shortAB)
    71  	case len(x) < n:
    72  		panic(shortX)
    73  	case len(cnorm) < n:
    74  		panic(shortCNorm)
    75  	}
    76  
    77  	// Parameters to control overflow.
    78  	smlnum := dlamchS / dlamchP
    79  	bignum := 1 / smlnum
    80  
    81  	bi := blas64.Implementation()
    82  	kld := max(1, ldab-1) // Stride between consecutive column elements in ab.
    83  	if !normin {
    84  		// Compute the 1-norm of each column, not including the diagonal.
    85  		if uplo == blas.Upper {
    86  			for j := 0; j < n; j++ {
    87  				jlen := min(j, kd)
    88  				if jlen > 0 {
    89  					cnorm[j] = bi.Dasum(jlen, ab[(j-jlen)*ldab+jlen:], kld)
    90  				} else {
    91  					cnorm[j] = 0
    92  				}
    93  			}
    94  		} else {
    95  			for j := 0; j < n; j++ {
    96  				jlen := min(n-j-1, kd)
    97  				if jlen > 0 {
    98  					cnorm[j] = bi.Dasum(jlen, ab[(j+1)*ldab+kd-1:], kld)
    99  				} else {
   100  					cnorm[j] = 0
   101  				}
   102  			}
   103  		}
   104  	}
   105  
   106  	// Set up indices and increments for loops below.
   107  	var (
   108  		jFirst, jLast, jInc int // Bounds and step of the substitution loops.
   109  		maind               int // Offset of the diagonal element within a row of ab.
   110  	)
   111  	if noTran {
   112  		if uplo == blas.Upper {
   113  			jFirst = n - 1
   114  			jLast = -1
   115  			jInc = -1
   116  			maind = 0
   117  		} else {
   118  			jFirst = 0
   119  			jLast = n
   120  			jInc = 1
   121  			maind = kd
   122  		}
   123  	} else {
   124  		if uplo == blas.Upper {
   125  			jFirst = 0
   126  			jLast = n
   127  			jInc = 1
   128  			maind = 0
   129  		} else {
   130  			jFirst = n - 1
   131  			jLast = -1
   132  			jInc = -1
   133  			maind = kd
   134  		}
   135  	}
   136  
   137  	// Scale the column norms by tscal if the maximum element in cnorm is
   138  	// greater than bignum.
   139  	tmax := cnorm[bi.Idamax(n, cnorm, 1)]
   140  	tscal := 1.0
   141  	if tmax > bignum {
   142  		tscal = 1 / (smlnum * tmax)
   143  		bi.Dscal(n, tscal, cnorm, 1)
   144  	}
   145  
   146  	// Compute a bound on the computed solution vector to see if the Level 2
   147  	// BLAS routine Dtbsv can be used.
   148  
   149  	xMax := math.Abs(x[bi.Idamax(n, x, 1)])
   150  	xBnd := xMax
   151  	grow := 0.0
   152  	// Compute the growth only if the maximum element in cnorm is NOT greater
   153  	// than bignum.
   154  	if tscal != 1 {
   155  		goto skipComputeGrow
   156  	}
   157  	if noTran {
   158  		// Compute the growth in A * x = b.
   159  		if diag == blas.NonUnit {
   160  			// A is non-unit triangular.
   161  			//
   162  			// Compute grow = 1/G_j and xBnd = 1/M_j.
   163  			// Initially, G_0 = max{x(i), i=1,...,n}.
   164  			grow = 1 / math.Max(xBnd, smlnum)
   165  			xBnd = grow
   166  			for j := jFirst; j != jLast; j += jInc {
   167  				if grow <= smlnum {
   168  					// Exit the loop because the growth factor is too small.
   169  					goto skipComputeGrow
   170  				}
   171  				// M_j = G_{j-1} / abs(A[j,j])
   172  				tjj := math.Abs(ab[j*ldab+maind])
   173  				xBnd = math.Min(xBnd, math.Min(1, tjj)*grow)
   174  				if tjj+cnorm[j] >= smlnum {
   175  					// G_j = G_{j-1}*( 1 + cnorm[j] / abs(A[j,j]) )
   176  					grow *= tjj / (tjj + cnorm[j])
   177  				} else {
   178  					// G_j could overflow, set grow to 0.
   179  					grow = 0
   180  				}
   181  			}
   182  			grow = xBnd
   183  		} else {
   184  			// A is unit triangular.
   185  			//
   186  			// Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}.
   187  			grow = math.Min(1, 1/math.Max(xBnd, smlnum))
   188  			for j := jFirst; j != jLast; j += jInc {
   189  				if grow <= smlnum {
   190  					// Exit the loop because the growth factor is too small.
   191  					goto skipComputeGrow
   192  				}
   193  				// G_j = G_{j-1}*( 1 + cnorm[j] )
   194  				grow /= 1 + cnorm[j]
   195  			}
   196  		}
   197  	} else {
   198  		// Compute the growth in Aᵀ * x = b.
   199  		if diag == blas.NonUnit {
   200  			// A is non-unit triangular.
   201  			//
   202  			// Compute grow = 1/G_j and xBnd = 1/M_j.
   203  			// Initially, G_0 = max{x(i), i=1,...,n}.
   204  			grow = 1 / math.Max(xBnd, smlnum)
   205  			xBnd = grow
   206  			for j := jFirst; j != jLast; j += jInc {
   207  				if grow <= smlnum {
   208  					// Exit the loop because the growth factor is too small.
   209  					goto skipComputeGrow
   210  				}
   211  				// G_j = max( G_{j-1}, M_{j-1}*( 1 + cnorm[j] ) )
   212  				xj := 1 + cnorm[j]
   213  				grow = math.Min(grow, xBnd/xj)
   214  				// M_j = M_{j-1}*( 1 + cnorm[j] ) / abs(A[j,j])
   215  				tjj := math.Abs(ab[j*ldab+maind])
   216  				if xj > tjj {
   217  					xBnd *= tjj / xj
   218  				}
   219  			}
   220  			grow = math.Min(grow, xBnd)
   221  		} else {
   222  			// A is unit triangular.
   223  			//
   224  			// Compute grow = 1/G_j, where G_0 = max{x(i), i=1,...,n}.
   225  			grow = math.Min(1, 1/math.Max(xBnd, smlnum))
   226  			for j := jFirst; j != jLast; j += jInc {
   227  				if grow <= smlnum {
   228  					// Exit the loop because the growth factor is too small.
   229  					goto skipComputeGrow
   230  				}
   231  				// G_j = G_{j-1}*( 1 + cnorm[j] )
   232  				grow /= 1 + cnorm[j]
   233  			}
   234  		}
   235  	}
   236  skipComputeGrow:
   237  
   238  	if grow*tscal > smlnum {
   239  		// The reciprocal of the bound on elements of X is not too small, use
   240  		// the Level 2 BLAS solve.
   241  		bi.Dtbsv(uplo, trans, diag, n, kd, ab, ldab, x, 1)
   242  		// Scale the column norms by 1/tscal for return.
   243  		if tscal != 1 {
   244  			bi.Dscal(n, 1/tscal, cnorm, 1)
   245  		}
   246  		return 1
   247  	}
   248  
   249  	// Use a Level 1 BLAS solve, scaling intermediate results.
   250  
   251  	scale = 1
   252  	if xMax > bignum {
   253  		// Scale x so that its components are less than or equal to bignum in
   254  		// absolute value.
   255  		scale = bignum / xMax
   256  		bi.Dscal(n, scale, x, 1)
   257  		xMax = bignum
   258  	}
   259  
   260  	if noTran {
   261  		// Solve A * x = b.
   262  		for j := jFirst; j != jLast; j += jInc {
   263  			// Compute x[j] = b[j] / A[j,j], scaling x if necessary.
   264  			xj := math.Abs(x[j])
   265  			tjjs := tscal
   266  			if diag == blas.NonUnit {
   267  				tjjs *= ab[j*ldab+maind]
   268  			}
   269  			tjj := math.Abs(tjjs)
   270  			switch {
   271  			case tjj > smlnum:
   272  				// smlnum < abs(A[j,j])
   273  				if tjj < 1 && xj > tjj*bignum {
   274  					// Scale x by 1/b[j].
   275  					rec := 1 / xj
   276  					bi.Dscal(n, rec, x, 1)
   277  					scale *= rec
   278  					xMax *= rec
   279  				}
   280  				x[j] /= tjjs
   281  				xj = math.Abs(x[j])
   282  			case tjj > 0:
   283  				// 0 < abs(A[j,j]) <= smlnum
   284  				if xj > tjj*bignum {
   285  					// Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum to avoid
   286  					// overflow when dividing by A[j,j].
   287  					rec := tjj * bignum / xj
   288  					if cnorm[j] > 1 {
   289  						// Scale by 1/cnorm[j] to avoid overflow when
   290  						// multiplying x[j] times column j.
   291  						rec /= cnorm[j]
   292  					}
   293  					bi.Dscal(n, rec, x, 1)
   294  					scale *= rec
   295  					xMax *= rec
   296  				}
   297  				x[j] /= tjjs
   298  				xj = math.Abs(x[j])
   299  			default:
   300  				// A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and
   301  				// compute a solution to A*x = 0.
   302  				for i := range x[:n] {
   303  					x[i] = 0
   304  				}
   305  				x[j] = 1
   306  				xj = 1
   307  				scale = 0
   308  				xMax = 0
   309  			}
   310  
   311  			// Scale x if necessary to avoid overflow when adding a multiple of
   312  			// column j of A.
   313  			switch {
   314  			case xj > 1:
   315  				rec := 1 / xj
   316  				if cnorm[j] > (bignum-xMax)*rec {
   317  					// Scale x by 1/(2*abs(x[j])).
   318  					rec *= 0.5
   319  					bi.Dscal(n, rec, x, 1)
   320  					scale *= rec
   321  				}
   322  			case xj*cnorm[j] > bignum-xMax:
   323  				// Scale x by 1/2.
   324  				bi.Dscal(n, 0.5, x, 1)
   325  				scale *= 0.5
   326  			}
   327  
   328  			if uplo == blas.Upper {
   329  				if j > 0 {
   330  					// Compute the update
   331  					//  x[max(0,j-kd):j] := x[max(0,j-kd):j] - x[j] * A[max(0,j-kd):j,j]
   332  					jlen := min(j, kd)
   333  					if jlen > 0 {
   334  						bi.Daxpy(jlen, -x[j]*tscal, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1)
   335  					}
   336  					i := bi.Idamax(j, x, 1)
   337  					xMax = math.Abs(x[i])
   338  				}
   339  			} else if j < n-1 {
   340  				// Compute the update
   341  				//  x[j+1:min(j+kd,n)] := x[j+1:min(j+kd,n)] - x[j] * A[j+1:min(j+kd,n),j]
   342  				jlen := min(kd, n-j-1)
   343  				if jlen > 0 {
   344  					bi.Daxpy(jlen, -x[j]*tscal, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1)
   345  				}
   346  				i := j + 1 + bi.Idamax(n-j-1, x[j+1:], 1)
   347  				xMax = math.Abs(x[i])
   348  			}
   349  		}
   350  	} else {
   351  		// Solve Aᵀ * x = b.
   352  		for j := jFirst; j != jLast; j += jInc {
   353  			// Compute x[j] = b[j] - sum A[k,j]*x[k].
   354  			//                       k!=j
   355  			xj := math.Abs(x[j])
   356  			tjjs := tscal
   357  			if diag == blas.NonUnit {
   358  				tjjs *= ab[j*ldab+maind]
   359  			}
   360  			tjj := math.Abs(tjjs)
   361  			rec := 1 / math.Max(1, xMax)
   362  			uscal := tscal
   363  			if cnorm[j] > (bignum-xj)*rec {
   364  				// If x[j] could overflow, scale x by 1/(2*xMax).
   365  				rec *= 0.5
   366  				if tjj > 1 {
   367  					// Divide by A[j,j] when scaling x if A[j,j] > 1.
   368  					rec = math.Min(1, rec*tjj)
   369  					uscal /= tjjs
   370  				}
   371  				if rec < 1 {
   372  					bi.Dscal(n, rec, x, 1)
   373  					scale *= rec
   374  					xMax *= rec
   375  				}
   376  			}
   377  
   378  			var sumj float64
   379  			if uscal == 1 {
   380  				// If the scaling needed for A in the dot product is 1, call
   381  				// Ddot to perform the dot product...
   382  				if uplo == blas.Upper {
   383  					jlen := min(j, kd)
   384  					if jlen > 0 {
   385  						sumj = bi.Ddot(jlen, ab[(j-jlen)*ldab+jlen:], kld, x[j-jlen:], 1)
   386  					}
   387  				} else {
   388  					jlen := min(n-j-1, kd)
   389  					if jlen > 0 {
   390  						sumj = bi.Ddot(jlen, ab[(j+1)*ldab+kd-1:], kld, x[j+1:], 1)
   391  					}
   392  				}
   393  			} else {
   394  				// ...otherwise, use in-line code for the dot product.
   395  				if uplo == blas.Upper {
   396  					jlen := min(j, kd)
   397  					for i := 0; i < jlen; i++ {
   398  						sumj += (ab[(j-jlen+i)*ldab+jlen-i] * uscal) * x[j-jlen+i]
   399  					}
   400  				} else {
   401  					jlen := min(n-j-1, kd)
   402  					for i := 0; i < jlen; i++ {
   403  						sumj += (ab[(j+1+i)*ldab+kd-1-i] * uscal) * x[j+i+1]
   404  					}
   405  				}
   406  			}
   407  
   408  			if uscal == tscal {
   409  				// Compute x[j] := ( x[j] - sumj ) / A[j,j]
   410  				// if 1/A[j,j] was not used to scale the dot product.
   411  				x[j] -= sumj
   412  				xj = math.Abs(x[j])
   413  				// Compute x[j] = x[j] / A[j,j], scaling if necessary.
   414  				// Note: the reference implementation skips this step for blas.Unit matrices
   415  				// when tscal is equal to 1 but it complicates the logic and only saves
   416  				// the comparison and division in the first switch-case. Not skipping it
   417  				// is also consistent with the NoTrans case above.
   418  				switch {
   419  				case tjj > smlnum:
   420  					// smlnum < abs(A[j,j]):
   421  					if tjj < 1 && xj > tjj*bignum {
   422  						// Scale x by 1/abs(x[j]).
   423  						rec := 1 / xj
   424  						bi.Dscal(n, rec, x, 1)
   425  						scale *= rec
   426  						xMax *= rec
   427  					}
   428  					x[j] /= tjjs
   429  				case tjj > 0:
   430  					// 0 < abs(A[j,j]) <= smlnum:
   431  					if xj > tjj*bignum {
   432  						// Scale x by (1/abs(x[j]))*abs(A[j,j])*bignum.
   433  						rec := (tjj * bignum) / xj
   434  						bi.Dscal(n, rec, x, 1)
   435  						scale *= rec
   436  						xMax *= rec
   437  					}
   438  					x[j] /= tjjs
   439  				default:
   440  					// A[j,j] == 0: Set x[0:n] = 0, x[j] = 1, and scale = 0, and
   441  					// compute a solution Aᵀ * x = 0.
   442  					for i := range x[:n] {
   443  						x[i] = 0
   444  					}
   445  					x[j] = 1
   446  					scale = 0
   447  					xMax = 0
   448  				}
   449  			} else {
   450  				// Compute x[j] := x[j] / A[j,j] - sumj
   451  				// if the dot product has already been multiplied by 1/A[j,j].
   452  				x[j] = x[j]/tjjs - sumj
   453  			}
   454  			xMax = math.Max(xMax, math.Abs(x[j]))
   455  		}
   456  	}
   457  	// The solve above was carried out with tscal*A, so the scaling factor
   458  	// for A itself is scale/tscal in both the NoTrans and the Trans case.
   459  	scale /= tscal
   460  
   461  	// Scale the column norms by 1/tscal for return.
   462  	if tscal != 1 {
   463  		bi.Dscal(n, 1/tscal, cnorm, 1)
   464  	}
   465  	return scale
   466  }