gonum.org/v1/gonum@v0.14.0/blas/testblas/dtbmv.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"testing"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  )
    12  
    13  type Dtbmver interface {
    14  	Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
    15  }
    16  
    17  func DtbmvTest(t *testing.T, blasser Dtbmver) {
    18  	for i, test := range []struct {
    19  		ul  blas.Uplo
    20  		tA  blas.Transpose
    21  		d   blas.Diag
    22  		n   int
    23  		k   int
    24  		a   [][]float64
    25  		x   []float64
    26  		ans []float64
    27  	}{
    28  		{
    29  			ul: blas.Upper,
    30  			tA: blas.NoTrans,
    31  			d:  blas.Unit,
    32  			n:  3,
    33  			k:  1,
    34  			a: [][]float64{
    35  				{1, 2, 0},
    36  				{0, 1, 4},
    37  				{0, 0, 1},
    38  			},
    39  			x:   []float64{2, 3, 4},
    40  			ans: []float64{8, 19, 4},
    41  		},
    42  		{
    43  			ul: blas.Upper,
    44  			tA: blas.NoTrans,
    45  			d:  blas.NonUnit,
    46  			n:  5,
    47  			k:  1,
    48  			a: [][]float64{
    49  				{1, 3, 0, 0, 0},
    50  				{0, 6, 7, 0, 0},
    51  				{0, 0, 2, 1, 0},
    52  				{0, 0, 0, 12, 3},
    53  				{0, 0, 0, 0, -1},
    54  			},
    55  			x:   []float64{1, 2, 3, 4, 5},
    56  			ans: []float64{7, 33, 10, 63, -5},
    57  		},
    58  		{
    59  			ul: blas.Lower,
    60  			tA: blas.NoTrans,
    61  			d:  blas.NonUnit,
    62  			n:  5,
    63  			k:  1,
    64  			a: [][]float64{
    65  				{7, 0, 0, 0, 0},
    66  				{3, 6, 0, 0, 0},
    67  				{0, 7, 2, 0, 0},
    68  				{0, 0, 1, 12, 0},
    69  				{0, 0, 0, 3, -1},
    70  			},
    71  			x:   []float64{1, 2, 3, 4, 5},
    72  			ans: []float64{7, 15, 20, 51, 7},
    73  		},
    74  		{
    75  			ul: blas.Upper,
    76  			tA: blas.Trans,
    77  			d:  blas.NonUnit,
    78  			n:  5,
    79  			k:  2,
    80  			a: [][]float64{
    81  				{7, 3, 9, 0, 0},
    82  				{0, 6, 7, 10, 0},
    83  				{0, 0, 2, 1, 11},
    84  				{0, 0, 0, 12, 3},
    85  				{0, 0, 0, 0, -1},
    86  			},
    87  			x:   []float64{1, 2, 3, 4, 5},
    88  			ans: []float64{7, 15, 29, 71, 40},
    89  		},
    90  		{
    91  			ul: blas.Lower,
    92  			tA: blas.Trans,
    93  			d:  blas.NonUnit,
    94  			n:  5,
    95  			k:  2,
    96  			a: [][]float64{
    97  				{7, 0, 0, 0, 0},
    98  				{3, 6, 0, 0, 0},
    99  				{9, 7, 2, 0, 0},
   100  				{0, 10, 1, 12, 0},
   101  				{0, 0, 11, 3, -1},
   102  			},
   103  			x:   []float64{1, 2, 3, 4, 5},
   104  			ans: []float64{40, 73, 65, 63, -5},
   105  		},
   106  	} {
   107  		extra := 0
   108  		var aFlat []float64
   109  		if test.ul == blas.Upper {
   110  			aFlat = flattenBanded(test.a, test.k, 0)
   111  		} else {
   112  			aFlat = flattenBanded(test.a, 0, test.k)
   113  		}
   114  		incTest := func(incX, extra int) {
   115  			xnew := makeIncremented(test.x, incX, extra)
   116  			ans := makeIncremented(test.ans, incX, extra)
   117  			lda := test.k + 1
   118  			blasser.Dtbmv(test.ul, test.tA, test.d, test.n, test.k, aFlat, lda, xnew, incX)
   119  			if !dSliceTolEqual(ans, xnew) {
   120  				t.Errorf("Case %v, Inc %v: Want %v, got %v", i, incX, ans, xnew)
   121  			}
   122  		}
   123  		incTest(1, extra)
   124  		incTest(3, extra)
   125  		incTest(-2, extra)
   126  	}
   127  }