gonum.org/v1/gonum@v0.14.0/blas/testblas/common_test.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"math"
     9  	"math/cmplx"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/floats"
    16  )
    17  
    18  func TestFlattenBanded(t *testing.T) {
    19  	for i, test := range []struct {
    20  		dense     [][]float64
    21  		ku        int
    22  		kl        int
    23  		condensed [][]float64
    24  	}{
    25  		{
    26  			dense:     [][]float64{{3}},
    27  			ku:        0,
    28  			kl:        0,
    29  			condensed: [][]float64{{3}},
    30  		},
    31  		{
    32  			dense: [][]float64{
    33  				{3, 4, 0},
    34  			},
    35  			ku: 1,
    36  			kl: 0,
    37  			condensed: [][]float64{
    38  				{3, 4},
    39  			},
    40  		},
    41  		{
    42  			dense: [][]float64{
    43  				{3, 4, 0, 0, 0},
    44  			},
    45  			ku: 1,
    46  			kl: 0,
    47  			condensed: [][]float64{
    48  				{3, 4},
    49  			},
    50  		},
    51  		{
    52  			dense: [][]float64{
    53  				{3, 4, 0},
    54  				{0, 5, 8},
    55  				{0, 0, 2},
    56  				{0, 0, 0},
    57  				{0, 0, 0},
    58  			},
    59  			ku: 1,
    60  			kl: 0,
    61  			condensed: [][]float64{
    62  				{3, 4},
    63  				{5, 8},
    64  				{2, math.NaN()},
    65  				{math.NaN(), math.NaN()},
    66  				{math.NaN(), math.NaN()},
    67  			},
    68  		},
    69  		{
    70  			dense: [][]float64{
    71  				{3, 4, 6},
    72  				{0, 5, 8},
    73  				{0, 0, 2},
    74  				{0, 0, 0},
    75  				{0, 0, 0},
    76  			},
    77  			ku: 2,
    78  			kl: 0,
    79  			condensed: [][]float64{
    80  				{3, 4, 6},
    81  				{5, 8, math.NaN()},
    82  				{2, math.NaN(), math.NaN()},
    83  				{math.NaN(), math.NaN(), math.NaN()},
    84  				{math.NaN(), math.NaN(), math.NaN()},
    85  			},
    86  		},
    87  		{
    88  			dense: [][]float64{
    89  				{3, 4, 6},
    90  				{1, 5, 8},
    91  				{0, 6, 2},
    92  				{0, 0, 7},
    93  				{0, 0, 0},
    94  			},
    95  			ku: 2,
    96  			kl: 1,
    97  			condensed: [][]float64{
    98  				{math.NaN(), 3, 4, 6},
    99  				{1, 5, 8, math.NaN()},
   100  				{6, 2, math.NaN(), math.NaN()},
   101  				{7, math.NaN(), math.NaN(), math.NaN()},
   102  				{math.NaN(), math.NaN(), math.NaN(), math.NaN()},
   103  			},
   104  		},
   105  		{
   106  			dense: [][]float64{
   107  				{1, 2, 0},
   108  				{3, 4, 5},
   109  				{6, 7, 8},
   110  				{0, 9, 10},
   111  				{0, 0, 11},
   112  			},
   113  			ku: 1,
   114  			kl: 2,
   115  			condensed: [][]float64{
   116  				{math.NaN(), math.NaN(), 1, 2},
   117  				{math.NaN(), 3, 4, 5},
   118  				{6, 7, 8, math.NaN()},
   119  				{9, 10, math.NaN(), math.NaN()},
   120  				{11, math.NaN(), math.NaN(), math.NaN()},
   121  			},
   122  		},
   123  		{
   124  			dense: [][]float64{
   125  				{1, 0, 0},
   126  				{3, 4, 0},
   127  				{6, 7, 8},
   128  				{0, 9, 10},
   129  				{0, 0, 11},
   130  			},
   131  			ku: 0,
   132  			kl: 2,
   133  			condensed: [][]float64{
   134  				{math.NaN(), math.NaN(), 1},
   135  				{math.NaN(), 3, 4},
   136  				{6, 7, 8},
   137  				{9, 10, math.NaN()},
   138  				{11, math.NaN(), math.NaN()},
   139  			},
   140  		},
   141  		{
   142  			dense: [][]float64{
   143  				{1, 0, 0, 0, 0},
   144  				{3, 4, 0, 0, 0},
   145  				{1, 3, 5, 0, 0},
   146  			},
   147  			ku: 0,
   148  			kl: 2,
   149  			condensed: [][]float64{
   150  				{math.NaN(), math.NaN(), 1},
   151  				{math.NaN(), 3, 4},
   152  				{1, 3, 5},
   153  			},
   154  		},
   155  	} {
   156  		condensed := flattenBanded(test.dense, test.ku, test.kl)
   157  		correct := flatten(test.condensed)
   158  		if !floats.Same(condensed, correct) {
   159  			t.Errorf("Case %v mismatch. Want %v, got %v.", i, correct, condensed)
   160  		}
   161  	}
   162  }
   163  
   164  func TestFlattenTriangular(t *testing.T) {
   165  	for i, test := range []struct {
   166  		a   [][]float64
   167  		ans []float64
   168  		ul  blas.Uplo
   169  	}{
   170  		{
   171  			a: [][]float64{
   172  				{1, 2, 3},
   173  				{0, 4, 5},
   174  				{0, 0, 6},
   175  			},
   176  			ul:  blas.Upper,
   177  			ans: []float64{1, 2, 3, 4, 5, 6},
   178  		},
   179  		{
   180  			a: [][]float64{
   181  				{1, 0, 0},
   182  				{2, 3, 0},
   183  				{4, 5, 6},
   184  			},
   185  			ul:  blas.Lower,
   186  			ans: []float64{1, 2, 3, 4, 5, 6},
   187  		},
   188  	} {
   189  		a := flattenTriangular(test.a, test.ul)
   190  		if !floats.Equal(a, test.ans) {
   191  			t.Errorf("Case %v. Want %v, got %v.", i, test.ans, a)
   192  		}
   193  	}
   194  }
   195  
   196  func TestPackUnpackAsHermitian(t *testing.T) {
   197  	rnd := rand.New(rand.NewSource(1))
   198  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
   199  		for _, n := range []int{1, 2, 5, 50} {
   200  			for _, lda := range []int{max(1, n), n + 11} {
   201  				a := makeZGeneral(nil, n, n, lda)
   202  				for i := 0; i < n; i++ {
   203  					for j := i; j < n; j++ {
   204  						a[i*lda+j] = complex(rnd.NormFloat64(), rnd.NormFloat64())
   205  						if i != j {
   206  							a[j*lda+i] = cmplx.Conj(a[i*lda+j])
   207  						}
   208  					}
   209  				}
   210  				aCopy := make([]complex128, len(a))
   211  				copy(aCopy, a)
   212  
   213  				ap := zPack(uplo, n, a, lda)
   214  				if !zsame(a, aCopy) {
   215  					t.Errorf("Case uplo=%v,n=%v,lda=%v: zPack modified a", uplo, n, lda)
   216  				}
   217  
   218  				apCopy := make([]complex128, len(ap))
   219  				copy(apCopy, ap)
   220  
   221  				art := zUnpackAsHermitian(uplo, n, ap)
   222  				if !zsame(ap, apCopy) {
   223  					t.Errorf("Case uplo=%v,n=%v,lda=%v: zUnpackAsHermitian modified ap", uplo, n, lda)
   224  				}
   225  
   226  				// Copy the round-tripped A into a matrix with the same stride
   227  				// as the original.
   228  				got := makeZGeneral(nil, n, n, lda)
   229  				for i := 0; i < n; i++ {
   230  					copy(got[i*lda:i*lda+n], art[i*n:i*n+n])
   231  				}
   232  				if !zsame(got, a) {
   233  					t.Errorf("Case uplo=%v,n=%v,lda=%v: zPack and zUnpackAsHermitian do not roundtrip", uplo, n, lda)
   234  				}
   235  			}
   236  		}
   237  	}
   238  }