github.com/gopherd/gonum@v0.0.4/blas/testblas/ztrsv.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"math/rand"
    12  	"github.com/gopherd/gonum/blas"
    13  )
    14  
    15  type Ztrsver interface {
    16  	Ztrsv(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n int, a []complex128, lda int, x []complex128, incX int)
    17  
    18  	Ztrmver
    19  }
    20  
    21  func ZtrsvTest(t *testing.T, impl Ztrsver) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    24  		for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    25  			for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
    26  				for _, n := range []int{0, 1, 2, 3, 4, 10} {
    27  					for _, lda := range []int{max(1, n), n + 11} {
    28  						for _, incX := range []int{-11, -3, -2, -1, 1, 2, 3, 7} {
    29  							ztrsvTest(t, impl, uplo, trans, diag, n, lda, incX, rnd)
    30  						}
    31  					}
    32  				}
    33  			}
    34  		}
    35  	}
    36  }
    37  
    38  // ztrsvTest tests Ztrsv by checking whether Ztrmv followed by Ztrsv
    39  // round-trip.
    40  func ztrsvTest(t *testing.T, impl Ztrsver, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, n, lda, incX int, rnd *rand.Rand) {
    41  	const tol = 1e-10
    42  
    43  	// Allocate a dense-storage triangular matrix A filled with NaNs.
    44  	a := makeZGeneral(nil, n, n, lda)
    45  	// Fill the referenced triangle of A with random data.
    46  	if uplo == blas.Upper {
    47  		for i := 0; i < n; i++ {
    48  			for j := i; j < n; j++ {
    49  				re := rnd.NormFloat64()
    50  				im := rnd.NormFloat64()
    51  				a[i*lda+j] = complex(re, im)
    52  			}
    53  		}
    54  	} else {
    55  		for i := 0; i < n; i++ {
    56  			for j := 0; j <= i; j++ {
    57  				re := rnd.NormFloat64()
    58  				im := rnd.NormFloat64()
    59  				a[i*lda+j] = complex(re, im)
    60  			}
    61  		}
    62  	}
    63  	if diag == blas.Unit {
    64  		// The diagonal should not be referenced by Ztrmv and Ztrsv, so
    65  		// invalidate it with NaNs.
    66  		for i := 0; i < n; i++ {
    67  			a[i*lda+i] = znan
    68  		}
    69  	}
    70  	aCopy := make([]complex128, len(a))
    71  	copy(aCopy, a)
    72  
    73  	// Generate a random complex vector x.
    74  	xtest := make([]complex128, n)
    75  	for i := range xtest {
    76  		re := rnd.NormFloat64()
    77  		im := rnd.NormFloat64()
    78  		xtest[i] = complex(re, im)
    79  	}
    80  	x := makeZVector(xtest, incX)
    81  
    82  	// Store a copy of x as the correct result that we want.
    83  	want := make([]complex128, len(x))
    84  	copy(want, x)
    85  
    86  	// Compute A*x, denoting the result by b and storing it in x.
    87  	impl.Ztrmv(uplo, trans, diag, n, a, lda, x, incX)
    88  	// Solve A*x = b, that is, x = A^{-1}*b = A^{-1}*A*x.
    89  	impl.Ztrsv(uplo, trans, diag, n, a, lda, x, incX)
    90  	// If Ztrsv is correct, A^{-1}*A = I and x contains again its original value.
    91  
    92  	name := fmt.Sprintf("uplo=%v,trans=%v,diag=%v,n=%v,lda=%v,incX=%v", uplo, trans, diag, n, lda, incX)
    93  	if !zsame(a, aCopy) {
    94  		t.Errorf("%v: unexpected modification of A", name)
    95  	}
    96  	if !zSameAtNonstrided(x, want, incX) {
    97  		t.Errorf("%v: unexpected modification of x\nwant %v\ngot  %v", name, want, x)
    98  	}
    99  	if !zEqualApproxAtStrided(x, want, incX, tol) {
   100  		t.Errorf("%v: unexpected result\nwant %v\ngot  %v", name, want, x)
   101  	}
   102  }