gonum.org/v1/gonum@v0.14.0/blas/blas32/conv_symmetric_test.go (about)

     1  // Code generated by "go generate gonum.org/v1/gonum/blas”; DO NOT EDIT.
     2  
     3  // Copyright ©2015 The Gonum Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package blas32
     8  
     9  import (
    10  	math "gonum.org/v1/gonum/internal/math32"
    11  	"testing"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  )
    15  
    16  func newSymmetricFrom(a SymmetricCols) Symmetric {
    17  	t := Symmetric{
    18  		N:      a.N,
    19  		Stride: a.N,
    20  		Data:   make([]float32, a.N*a.N),
    21  		Uplo:   a.Uplo,
    22  	}
    23  	t.From(a)
    24  	return t
    25  }
    26  
    27  func (m Symmetric) n() int { return m.N }
    28  func (m Symmetric) at(i, j int) float32 {
    29  	if m.Uplo == blas.Lower && i < j && j < m.N {
    30  		i, j = j, i
    31  	}
    32  	if m.Uplo == blas.Upper && i > j {
    33  		i, j = j, i
    34  	}
    35  	return m.Data[i*m.Stride+j]
    36  }
    37  func (m Symmetric) uplo() blas.Uplo { return m.Uplo }
    38  
    39  func newSymmetricColsFrom(a Symmetric) SymmetricCols {
    40  	t := SymmetricCols{
    41  		N:      a.N,
    42  		Stride: a.N,
    43  		Data:   make([]float32, a.N*a.N),
    44  		Uplo:   a.Uplo,
    45  	}
    46  	t.From(a)
    47  	return t
    48  }
    49  
    50  func (m SymmetricCols) n() int { return m.N }
    51  func (m SymmetricCols) at(i, j int) float32 {
    52  	if m.Uplo == blas.Lower && i < j {
    53  		i, j = j, i
    54  	}
    55  	if m.Uplo == blas.Upper && i > j && i < m.N {
    56  		i, j = j, i
    57  	}
    58  	return m.Data[i+j*m.Stride]
    59  }
    60  func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo }
    61  
    62  type symmetric interface {
    63  	n() int
    64  	at(i, j int) float32
    65  	uplo() blas.Uplo
    66  }
    67  
    68  func sameSymmetric(a, b symmetric) bool {
    69  	an := a.n()
    70  	bn := b.n()
    71  	if an != bn {
    72  		return false
    73  	}
    74  	if a.uplo() != b.uplo() {
    75  		return false
    76  	}
    77  	for i := 0; i < an; i++ {
    78  		for j := 0; j < an; j++ {
    79  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
    80  				return false
    81  			}
    82  		}
    83  	}
    84  	return true
    85  }
    86  
    87  var symmetricTests = []Symmetric{
    88  	{N: 3, Stride: 3, Data: []float32{
    89  		1, 2, 3,
    90  		4, 5, 6,
    91  		7, 8, 9,
    92  	}},
    93  	{N: 3, Stride: 5, Data: []float32{
    94  		1, 2, 3, 0, 0,
    95  		4, 5, 6, 0, 0,
    96  		7, 8, 9, 0, 0,
    97  	}},
    98  }
    99  
   100  func TestConvertSymmetric(t *testing.T) {
   101  	for _, test := range symmetricTests {
   102  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
   103  			test.Uplo = uplo
   104  			colmajor := newSymmetricColsFrom(test)
   105  			if !sameSymmetric(colmajor, test) {
   106  				t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   107  					colmajor, test)
   108  			}
   109  			rowmajor := newSymmetricFrom(colmajor)
   110  			if !sameSymmetric(rowmajor, test) {
   111  				t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   112  					rowmajor, test)
   113  			}
   114  		}
   115  	}
   116  }
   117  func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand {
   118  	t := SymmetricBand{
   119  		N:      a.N,
   120  		K:      a.K,
   121  		Stride: a.K + 1,
   122  		Data:   make([]float32, a.N*(a.K+1)),
   123  		Uplo:   a.Uplo,
   124  	}
   125  	for i := range t.Data {
   126  		t.Data[i] = math.NaN()
   127  	}
   128  	t.From(a)
   129  	return t
   130  }
   131  
   132  func (m SymmetricBand) n() (n int) { return m.N }
   133  func (m SymmetricBand) at(i, j int) float32 {
   134  	b := Band{
   135  		Rows: m.N, Cols: m.N,
   136  		Stride: m.Stride,
   137  		Data:   m.Data,
   138  	}
   139  	switch m.Uplo {
   140  	default:
   141  		panic("blas32: bad BLAS uplo")
   142  	case blas.Upper:
   143  		b.KU = m.K
   144  		if i > j {
   145  			i, j = j, i
   146  		}
   147  	case blas.Lower:
   148  		b.KL = m.K
   149  		if i < j {
   150  			i, j = j, i
   151  		}
   152  	}
   153  	return b.at(i, j)
   154  }
   155  func (m SymmetricBand) bandwidth() (k int) { return m.K }
   156  func (m SymmetricBand) uplo() blas.Uplo    { return m.Uplo }
   157  
   158  func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols {
   159  	t := SymmetricBandCols{
   160  		N:      a.N,
   161  		K:      a.K,
   162  		Stride: a.K + 1,
   163  		Data:   make([]float32, a.N*(a.K+1)),
   164  		Uplo:   a.Uplo,
   165  	}
   166  	for i := range t.Data {
   167  		t.Data[i] = math.NaN()
   168  	}
   169  	t.From(a)
   170  	return t
   171  }
   172  
   173  func (m SymmetricBandCols) n() (n int) { return m.N }
   174  func (m SymmetricBandCols) at(i, j int) float32 {
   175  	b := BandCols{
   176  		Rows: m.N, Cols: m.N,
   177  		Stride: m.Stride,
   178  		Data:   m.Data,
   179  	}
   180  	switch m.Uplo {
   181  	default:
   182  		panic("blas32: bad BLAS uplo")
   183  	case blas.Upper:
   184  		b.KU = m.K
   185  		if i > j {
   186  			i, j = j, i
   187  		}
   188  	case blas.Lower:
   189  		b.KL = m.K
   190  		if i < j {
   191  			i, j = j, i
   192  		}
   193  	}
   194  	return b.at(i, j)
   195  }
   196  func (m SymmetricBandCols) bandwidth() (k int) { return m.K }
   197  func (m SymmetricBandCols) uplo() blas.Uplo    { return m.Uplo }
   198  
   199  type symmetricBand interface {
   200  	n() (n int)
   201  	at(i, j int) float32
   202  	bandwidth() (k int)
   203  	uplo() blas.Uplo
   204  }
   205  
   206  func sameSymmetricBand(a, b symmetricBand) bool {
   207  	an := a.n()
   208  	bn := b.n()
   209  	if an != bn {
   210  		return false
   211  	}
   212  	if a.uplo() != b.uplo() {
   213  		return false
   214  	}
   215  	ak := a.bandwidth()
   216  	bk := b.bandwidth()
   217  	if ak != bk {
   218  		return false
   219  	}
   220  	for i := 0; i < an; i++ {
   221  		for j := 0; j < an; j++ {
   222  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   223  				return false
   224  			}
   225  		}
   226  	}
   227  	return true
   228  }
   229  
   230  var symmetricBandTests = []SymmetricBand{
   231  	{N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float32{
   232  		1,
   233  		2,
   234  		3,
   235  	}},
   236  	{N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float32{
   237  		1,
   238  		2,
   239  		3,
   240  	}},
   241  	{N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float32{
   242  		1, 2,
   243  		3, 4,
   244  		5, -1,
   245  	}},
   246  	{N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float32{
   247  		-1, 1,
   248  		2, 3,
   249  		4, 5,
   250  	}},
   251  	{N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float32{
   252  		1, 2, 3,
   253  		4, 5, -1,
   254  		6, -2, -3,
   255  	}},
   256  	{N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float32{
   257  		-2, -1, 1,
   258  		-3, 2, 4,
   259  		3, 5, 6,
   260  	}},
   261  
   262  	{N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float32{
   263  		1, 0, 0, 0, 0,
   264  		2, 0, 0, 0, 0,
   265  		3, 0, 0, 0, 0,
   266  	}},
   267  	{N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float32{
   268  		1, 0, 0, 0, 0,
   269  		2, 0, 0, 0, 0,
   270  		3, 0, 0, 0, 0,
   271  	}},
   272  	{N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float32{
   273  		1, 2, 0, 0, 0,
   274  		3, 4, 0, 0, 0,
   275  		5, -1, 0, 0, 0,
   276  	}},
   277  	{N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float32{
   278  		-1, 1, 0, 0, 0,
   279  		2, 3, 0, 0, 0,
   280  		4, 5, 0, 0, 0,
   281  	}},
   282  	{N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float32{
   283  		1, 2, 3, 0, 0,
   284  		4, 5, -1, 0, 0,
   285  		6, -2, -3, 0, 0,
   286  	}},
   287  	{N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float32{
   288  		-2, -1, 1, 0, 0,
   289  		-3, 2, 4, 0, 0,
   290  		3, 5, 6, 0, 0,
   291  	}},
   292  }
   293  
   294  func TestConvertSymBand(t *testing.T) {
   295  	for _, test := range symmetricBandTests {
   296  		colmajor := newSymmetricBandColsFrom(test)
   297  		if !sameSymmetricBand(colmajor, test) {
   298  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   299  				colmajor, test)
   300  		}
   301  		rowmajor := newSymmetricBandFrom(colmajor)
   302  		if !sameSymmetricBand(rowmajor, test) {
   303  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   304  				rowmajor, test)
   305  		}
   306  	}
   307  }