github.com/gopherd/gonum@v0.0.4/blas/testblas/dtbsv.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  )
    12  
    13  type Dtbsver interface {
    14  	Dtbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
    15  	Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
    16  }
    17  
    18  func DtbsvTest(t *testing.T, blasser Dtbsver) {
    19  	for i, test := range []struct {
    20  		ul   blas.Uplo
    21  		tA   blas.Transpose
    22  		d    blas.Diag
    23  		n, k int
    24  		a    [][]float64
    25  		x    []float64
    26  		incX int
    27  		ans  []float64
    28  	}{
    29  		{
    30  			ul: blas.Upper,
    31  			tA: blas.NoTrans,
    32  			d:  blas.NonUnit,
    33  			n:  5,
    34  			k:  1,
    35  			a: [][]float64{
    36  				{1, 3, 0, 0, 0},
    37  				{0, 6, 7, 0, 0},
    38  				{0, 0, 2, 1, 0},
    39  				{0, 0, 0, 12, 3},
    40  				{0, 0, 0, 0, -1},
    41  			},
    42  			x:    []float64{1, 2, 3, 4, 5},
    43  			incX: 1,
    44  			ans:  []float64{2.479166666666667, -0.493055555555556, 0.708333333333333, 1.583333333333333, -5.000000000000000},
    45  		},
    46  		{
    47  			ul: blas.Upper,
    48  			tA: blas.NoTrans,
    49  			d:  blas.NonUnit,
    50  			n:  5,
    51  			k:  2,
    52  			a: [][]float64{
    53  				{1, 3, 5, 0, 0},
    54  				{0, 6, 7, 5, 0},
    55  				{0, 0, 2, 1, 5},
    56  				{0, 0, 0, 12, 3},
    57  				{0, 0, 0, 0, -1},
    58  			},
    59  			x:    []float64{1, 2, 3, 4, 5},
    60  			incX: 1,
    61  			ans:  []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
    62  		},
    63  		{
    64  			ul: blas.Upper,
    65  			tA: blas.NoTrans,
    66  			d:  blas.NonUnit,
    67  			n:  5,
    68  			k:  1,
    69  			a: [][]float64{
    70  				{1, 3, 0, 0, 0},
    71  				{0, 6, 7, 0, 0},
    72  				{0, 0, 2, 1, 0},
    73  				{0, 0, 0, 12, 3},
    74  				{0, 0, 0, 0, -1},
    75  			},
    76  			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
    77  			incX: 2,
    78  			ans:  []float64{2.479166666666667, -101, -0.493055555555556, -201, 0.708333333333333, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
    79  		},
    80  		{
    81  			ul: blas.Upper,
    82  			tA: blas.NoTrans,
    83  			d:  blas.NonUnit,
    84  			n:  5,
    85  			k:  2,
    86  			a: [][]float64{
    87  				{1, 3, 5, 0, 0},
    88  				{0, 6, 7, 5, 0},
    89  				{0, 0, 2, 1, 5},
    90  				{0, 0, 0, 12, 3},
    91  				{0, 0, 0, 0, -1},
    92  			},
    93  			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
    94  			incX: 2,
    95  			ans:  []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
    96  		},
    97  		{
    98  			ul: blas.Lower,
    99  			tA: blas.NoTrans,
   100  			d:  blas.NonUnit,
   101  			n:  5,
   102  			k:  2,
   103  			a: [][]float64{
   104  				{1, 0, 0, 0, 0},
   105  				{3, 6, 0, 0, 0},
   106  				{5, 7, 2, 0, 0},
   107  				{0, 5, 1, 12, 0},
   108  				{0, 0, 5, 3, -1},
   109  			},
   110  			x:    []float64{1, 2, 3, 4, 5},
   111  			incX: 1,
   112  			ans:  []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
   113  		},
   114  		{
   115  			ul: blas.Lower,
   116  			tA: blas.NoTrans,
   117  			d:  blas.NonUnit,
   118  			n:  5,
   119  			k:  2,
   120  			a: [][]float64{
   121  				{1, 0, 0, 0, 0},
   122  				{3, 6, 0, 0, 0},
   123  				{5, 7, 2, 0, 0},
   124  				{0, 5, 1, 12, 0},
   125  				{0, 0, 5, 3, -1},
   126  			},
   127  			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
   128  			incX: 2,
   129  			ans:  []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
   130  		},
   131  		{
   132  			ul: blas.Upper,
   133  			tA: blas.Trans,
   134  			d:  blas.NonUnit,
   135  			n:  5,
   136  			k:  2,
   137  			a: [][]float64{
   138  				{1, 3, 5, 0, 0},
   139  				{0, 6, 7, 5, 0},
   140  				{0, 0, 2, 1, 5},
   141  				{0, 0, 0, 12, 3},
   142  				{0, 0, 0, 0, -1},
   143  			},
   144  			x:    []float64{1, 2, 3, 4, 5},
   145  			incX: 1,
   146  			ans:  []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
   147  		},
   148  		{
   149  			ul: blas.Upper,
   150  			tA: blas.Trans,
   151  			d:  blas.NonUnit,
   152  			n:  5,
   153  			k:  2,
   154  			a: [][]float64{
   155  				{1, 3, 5, 0, 0},
   156  				{0, 6, 7, 5, 0},
   157  				{0, 0, 2, 1, 5},
   158  				{0, 0, 0, 12, 3},
   159  				{0, 0, 0, 0, -1},
   160  			},
   161  			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
   162  			incX: 2,
   163  			ans:  []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
   164  		},
   165  		{
   166  			ul: blas.Lower,
   167  			tA: blas.Trans,
   168  			d:  blas.NonUnit,
   169  			n:  5,
   170  			k:  2,
   171  			a: [][]float64{
   172  				{1, 0, 0, 0, 0},
   173  				{3, 6, 0, 0, 0},
   174  				{5, 7, 2, 0, 0},
   175  				{0, 5, 1, 12, 0},
   176  				{0, 0, 5, 3, -1},
   177  			},
   178  			x:    []float64{1, 2, 3, 4, 5},
   179  			incX: 1,
   180  			ans:  []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
   181  		},
   182  		{
   183  			ul: blas.Lower,
   184  			tA: blas.Trans,
   185  			d:  blas.NonUnit,
   186  			n:  5,
   187  			k:  2,
   188  			a: [][]float64{
   189  				{1, 0, 0, 0, 0},
   190  				{3, 6, 0, 0, 0},
   191  				{5, 7, 2, 0, 0},
   192  				{0, 5, 1, 12, 0},
   193  				{0, 0, 5, 3, -1},
   194  			},
   195  			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
   196  			incX: 2,
   197  			ans:  []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
   198  		},
   199  	} {
   200  		var aFlat []float64
   201  		if test.ul == blas.Upper {
   202  			aFlat = flattenBanded(test.a, test.k, 0)
   203  		} else {
   204  			aFlat = flattenBanded(test.a, 0, test.k)
   205  		}
   206  		xCopy := sliceCopy(test.x)
   207  		// TODO: Have tests where the banded matrix is constructed explicitly
   208  		// to allow testing for lda =! k+1
   209  		blasser.Dtbsv(test.ul, test.tA, test.d, test.n, test.k, aFlat, test.k+1, xCopy, test.incX)
   210  		if !dSliceTolEqual(test.ans, xCopy) {
   211  			t.Errorf("Case %v: Want %v, got %v", i, test.ans, xCopy)
   212  		}
   213  	}
   214  
   215  	/*
   216  		// TODO: Uncomment when Dtrsv is fixed
   217  		// Compare with dense for larger matrices
   218  		for _, ul := range [...]blas.Uplo{blas.Upper, blas.Lower} {
   219  			for _, tA := range [...]blas.Transpose{blas.NoTrans, blas.Trans} {
   220  				for _, n := range [...]int{7, 8, 11} {
   221  					for _, d := range [...]blas.Diag{blas.NonUnit, blas.Unit} {
   222  						for _, k := range [...]int{0, 1, 3} {
   223  							for _, incX := range [...]int{1, 3} {
   224  								a := make([][]float64, n)
   225  								for i := range a {
   226  									a[i] = make([]float64, n)
   227  									for j := range a[i] {
   228  										a[i][j] = rand.Float64()
   229  									}
   230  								}
   231  								x := make([]float64, n)
   232  								for i := range x {
   233  									x[i] = rand.Float64()
   234  								}
   235  								extra := 3
   236  								xinc := makeIncremented(x, incX, extra)
   237  								bandX := sliceCopy(xinc)
   238  								var aFlatBand []float64
   239  								if ul == blas.Upper {
   240  									aFlatBand = flattenBanded(a, k, 0)
   241  								} else {
   242  									aFlatBand = flattenBanded(a, 0, k)
   243  								}
   244  								blasser.Dtbsv(ul, tA, d, n, k, aFlatBand, k+1, bandX, incX)
   245  
   246  								aFlatDense := flatten(a)
   247  								denseX := sliceCopy(xinc)
   248  								blasser.Dtrsv(ul, tA, d, n, aFlatDense, n, denseX, incX)
   249  								if !dSliceTolEqual(denseX, bandX) {
   250  									t.Errorf("Case %v: dense banded mismatch")
   251  								}
   252  							}
   253  						}
   254  					}
   255  				}
   256  			}
   257  		}
   258  	*/
   259  }