github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/testblas/ztpsv.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  )
    14  
    15  type Ztpsver interface {
    16  	Ztpsv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n int, ap []complex128, x []complex128, incX int)
    17  
    18  	Ztpmver
    19  }
    20  
    21  func ZtpsvTest(t *testing.T, impl Ztpsver) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    24  		for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    25  			for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
    26  				for _, n := range []int{0, 1, 2, 3, 4, 10} {
    27  					for _, incX := range []int{-11, -3, -2, -1, 1, 2, 3, 7} {
    28  						ztpsvTest(t, impl, uplo, trans, diag, n, incX, rnd)
    29  					}
    30  				}
    31  			}
    32  		}
    33  	}
    34  }
    35  
    36  // ztpsvTest tests Ztpsv by checking whether Ztpmv followed by Ztpsv
    37  // round-trip.
    38  func ztpsvTest(t *testing.T, impl Ztpsver, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, incX int, rnd *rand.Rand) {
    39  	const tol = 1e-10
    40  
    41  	// Allocate a dense-storage triangular matrix filled with NaNs that
    42  	// will be used as a for creating the actual triangular matrix in packed
    43  	// storage.
    44  	lda := n
    45  	a := makeZGeneral(nil, n, n, max(1, lda))
    46  	// Fill the referenced triangle of A with random data.
    47  	if uplo == blas.Upper {
    48  		for i := 0; i < n; i++ {
    49  			for j := i; j < n; j++ {
    50  				re := rnd.NormFloat64()
    51  				im := rnd.NormFloat64()
    52  				a[i*lda+j] = complex(re, im)
    53  			}
    54  		}
    55  	} else {
    56  		for i := 0; i < n; i++ {
    57  			for j := 0; j <= i; j++ {
    58  				re := rnd.NormFloat64()
    59  				im := rnd.NormFloat64()
    60  				a[i*lda+j] = complex(re, im)
    61  			}
    62  		}
    63  	}
    64  	if diag == blas.Unit {
    65  		// The diagonal should not be referenced by Ztpmv and Ztpsv, so
    66  		// invalidate it with NaNs.
    67  		for i := 0; i < n; i++ {
    68  			a[i*lda+i] = znan
    69  		}
    70  	}
    71  	// Create the triangular matrix in packed storage.
    72  	ap := zPack(uplo, n, a, n)
    73  	apCopy := make([]complex128, len(ap))
    74  	copy(apCopy, ap)
    75  
    76  	// Generate a random complex vector x.
    77  	xtest := make([]complex128, n)
    78  	for i := range xtest {
    79  		re := rnd.NormFloat64()
    80  		im := rnd.NormFloat64()
    81  		xtest[i] = complex(re, im)
    82  	}
    83  	x := makeZVector(xtest, incX)
    84  
    85  	// Store a copy of x as the correct result that we want.
    86  	want := make([]complex128, len(x))
    87  	copy(want, x)
    88  
    89  	// Compute A*x, denoting the result by b and storing it in x.
    90  	impl.Ztpmv(uplo, trans, diag, n, ap, x, incX)
    91  	// Solve A*x = b, that is, x = A^{-1}*b = A^{-1}*A*x.
    92  	impl.Ztpsv(uplo, trans, diag, n, ap, x, incX)
    93  	// If Ztpsv is correct, A^{-1}*A = I and x contains again its original value.
    94  
    95  	name := fmt.Sprintf("uplo=%v,trans=%v,diag=%v,n=%v,incX=%v", uplo, trans, diag, n, incX)
    96  	if !zsame(ap, apCopy) {
    97  		t.Errorf("%v: unexpected modification of ap", name)
    98  	}
    99  	if !zSameAtNonstrided(x, want, incX) {
   100  		t.Errorf("%v: unexpected modification of x\nwant %v\ngot  %v", name, want, x)
   101  	}
   102  	if !zEqualApproxAtStrided(x, want, incX, tol) {
   103  		t.Errorf("%v: unexpected result\nwant %v\ngot  %v", name, want, x)
   104  	}
   105  }