github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlatbs.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  	"github.com/jingcheng-WU/gonum/blas"
    14  	"github.com/jingcheng-WU/gonum/blas/blas64"
    15  	"github.com/jingcheng-WU/gonum/floats"
    16  )
    17  
    18  type Dlatbser interface {
    19  	Dlatbs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n, kd int, ab []float64, ldab int, x []float64, cnorm []float64) float64
    20  }
    21  
    22  // DlatbsTest tests Dlatbs by generating a random triangular band system and
    23  // checking that a residual for the computed solution is small.
    24  func DlatbsTest(t *testing.T, impl Dlatbser) {
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 50} {
    27  		for _, kd := range []int{0, (n + 1) / 4, (3*n - 1) / 4, (5*n + 1) / 4} {
    28  			for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    29  				for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    30  					for _, ldab := range []int{kd + 1, kd + 1 + 7} {
    31  						for _, kind := range []int{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18} {
    32  							dlatbsTest(t, impl, rnd, kind, uplo, trans, n, kd, ldab)
    33  						}
    34  					}
    35  				}
    36  			}
    37  		}
    38  	}
    39  }
    40  
    41  func dlatbsTest(t *testing.T, impl Dlatbser, rnd *rand.Rand, kind int, uplo blas.Uplo, trans blas.Transpose, n, kd, ldab int) {
    42  	const eps = 1e-15
    43  
    44  	// Allocate a triangular band matrix.
    45  	var ab []float64
    46  	if n > 0 {
    47  		ab = make([]float64, (n-1)*ldab+kd+1)
    48  	}
    49  	for i := range ab {
    50  		ab[i] = rnd.NormFloat64()
    51  	}
    52  
    53  	// Generate a triangular test matrix and the right-hand side.
    54  	diag, b := dlattb(kind, uplo, trans, n, kd, ab, ldab, rnd)
    55  
    56  	// Make a copy of AB to make sure that it is not modified in Dlatbs.
    57  	abCopy := make([]float64, len(ab))
    58  	copy(abCopy, ab)
    59  
    60  	// Allocate cnorm and fill it with impossible result to make sure that it
    61  	// _is_ updated in the first Dlatbs call below.
    62  	cnorm := make([]float64, n)
    63  	for i := range cnorm {
    64  		cnorm[i] = -1
    65  	}
    66  
    67  	// Solve the system op(A)*x = b.
    68  	x := make([]float64, n)
    69  	copy(x, b)
    70  	scale := impl.Dlatbs(uplo, trans, diag, false, n, kd, ab, ldab, x, cnorm)
    71  
    72  	name := fmt.Sprintf("kind=%v,uplo=%v,trans=%v,diag=%v,n=%v,kd=%v,ldab=%v",
    73  		kind, string(uplo), string(trans), string(diag), n, kd, ldab)
    74  
    75  	if !floats.Equal(ab, abCopy) {
    76  		t.Errorf("%v: unexpected modification of ab", name)
    77  	}
    78  	if floats.Count(func(v float64) bool { return v == -1 }, cnorm) > 0 {
    79  		t.Errorf("%v: expected modification of cnorm", name)
    80  	}
    81  
    82  	resid := dlatbsResidual(uplo, trans, diag, n, kd, ab, ldab, scale, cnorm, b, x)
    83  	if resid >= eps {
    84  		t.Errorf("%v: unexpected result when normin=false. residual=%v", name, resid)
    85  	}
    86  
    87  	// Make a copy of cnorm to check that it is _not_ modified.
    88  	cnormCopy := make([]float64, len(cnorm))
    89  	copy(cnormCopy, cnorm)
    90  	// Restore x.
    91  	copy(x, b)
    92  	// Solve the system op(A)*x = b again with normin = true.
    93  	scale = impl.Dlatbs(uplo, trans, diag, true, n, kd, ab, ldab, x, cnorm)
    94  
    95  	// Cannot test for exact equality because Dlatbs may scale cnorm by s and
    96  	// then by 1/s before return.
    97  	if !floats.EqualApprox(cnorm, cnormCopy, 1e-15) {
    98  		t.Errorf("%v: unexpected modification of cnorm", name)
    99  	}
   100  
   101  	resid = dlatbsResidual(uplo, trans, diag, n, kd, ab, ldab, scale, cnorm, b, x)
   102  	if resid >= eps {
   103  		t.Errorf("%v: unexpected result when normin=true. residual=%v", name, resid)
   104  	}
   105  }
   106  
   107  // dlatbsResidual returns the residual for the solution to a scaled triangular
   108  // system of equations  A*x = s*b  or  Aᵀ*x = s*b  when A is an n×n triangular
   109  // band matrix with kd super- or sub-diagonals. The residual is computed as
   110  //  norm( op(A)*x - scale*b ) / ( norm(op(A)) * norm(x) ).
   111  //
   112  // This function corresponds to DTBT03 in Reference LAPACK.
   113  func dlatbsResidual(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, kd int, ab []float64, ldab int, scale float64, cnorm, b, x []float64) float64 {
   114  	if n == 0 {
   115  		return 0
   116  	}
   117  
   118  	// Compute the norm of the triangular matrix A using the columns norms
   119  	// already computed by Dlatbs.
   120  	var tnorm float64
   121  	if diag == blas.NonUnit {
   122  		if uplo == blas.Upper {
   123  			for j := 0; j < n; j++ {
   124  				tnorm = math.Max(tnorm, math.Abs(ab[j*ldab])+cnorm[j])
   125  			}
   126  		} else {
   127  			for j := 0; j < n; j++ {
   128  				tnorm = math.Max(tnorm, math.Abs(ab[j*ldab+kd])+cnorm[j])
   129  			}
   130  		}
   131  	} else {
   132  		for j := 0; j < n; j++ {
   133  			tnorm = math.Max(tnorm, 1+cnorm[j])
   134  		}
   135  	}
   136  
   137  	bi := blas64.Implementation()
   138  	eps := dlamchE
   139  	smlnum := dlamchS
   140  
   141  	ix := bi.Idamax(n, x, 1)
   142  	xNorm := math.Max(1, math.Abs(x[ix]))
   143  	xScal := (1 / xNorm) / float64(kd+1)
   144  
   145  	resid := make([]float64, len(x))
   146  	copy(resid, x)
   147  	bi.Dscal(n, xScal, resid, 1)
   148  	bi.Dtbmv(uplo, trans, diag, n, kd, ab, ldab, resid, 1)
   149  	bi.Daxpy(n, -scale*xScal, b, 1, resid, 1)
   150  
   151  	ix = bi.Idamax(n, resid, 1)
   152  	residNorm := math.Abs(resid[ix])
   153  	if residNorm*smlnum <= xNorm {
   154  		if xNorm > 0 {
   155  			residNorm /= xNorm
   156  		}
   157  	} else if residNorm > 0 {
   158  		residNorm = 1 / eps
   159  	}
   160  	if residNorm*smlnum <= tnorm {
   161  		if tnorm > 0 {
   162  			residNorm /= tnorm
   163  		}
   164  	} else if residNorm > 0 {
   165  		residNorm = 1 / eps
   166  	}
   167  
   168  	return residNorm
   169  }