github.com/gopherd/gonum@v0.0.4/blas/blas64/conv_symmetric_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package blas64
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/gopherd/gonum/blas"
    12  )
    13  
    14  func newSymmetricFrom(a SymmetricCols) Symmetric {
    15  	t := Symmetric{
    16  		N:      a.N,
    17  		Stride: a.N,
    18  		Data:   make([]float64, a.N*a.N),
    19  		Uplo:   a.Uplo,
    20  	}
    21  	t.From(a)
    22  	return t
    23  }
    24  
    25  func (m Symmetric) n() int { return m.N }
    26  func (m Symmetric) at(i, j int) float64 {
    27  	if m.Uplo == blas.Lower && i < j && j < m.N {
    28  		i, j = j, i
    29  	}
    30  	if m.Uplo == blas.Upper && i > j {
    31  		i, j = j, i
    32  	}
    33  	return m.Data[i*m.Stride+j]
    34  }
    35  func (m Symmetric) uplo() blas.Uplo { return m.Uplo }
    36  
    37  func newSymmetricColsFrom(a Symmetric) SymmetricCols {
    38  	t := SymmetricCols{
    39  		N:      a.N,
    40  		Stride: a.N,
    41  		Data:   make([]float64, a.N*a.N),
    42  		Uplo:   a.Uplo,
    43  	}
    44  	t.From(a)
    45  	return t
    46  }
    47  
    48  func (m SymmetricCols) n() int { return m.N }
    49  func (m SymmetricCols) at(i, j int) float64 {
    50  	if m.Uplo == blas.Lower && i < j {
    51  		i, j = j, i
    52  	}
    53  	if m.Uplo == blas.Upper && i > j && i < m.N {
    54  		i, j = j, i
    55  	}
    56  	return m.Data[i+j*m.Stride]
    57  }
    58  func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo }
    59  
    60  type symmetric interface {
    61  	n() int
    62  	at(i, j int) float64
    63  	uplo() blas.Uplo
    64  }
    65  
    66  func sameSymmetric(a, b symmetric) bool {
    67  	an := a.n()
    68  	bn := b.n()
    69  	if an != bn {
    70  		return false
    71  	}
    72  	if a.uplo() != b.uplo() {
    73  		return false
    74  	}
    75  	for i := 0; i < an; i++ {
    76  		for j := 0; j < an; j++ {
    77  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
    78  				return false
    79  			}
    80  		}
    81  	}
    82  	return true
    83  }
    84  
    85  var symmetricTests = []Symmetric{
    86  	{N: 3, Stride: 3, Data: []float64{
    87  		1, 2, 3,
    88  		4, 5, 6,
    89  		7, 8, 9,
    90  	}},
    91  	{N: 3, Stride: 5, Data: []float64{
    92  		1, 2, 3, 0, 0,
    93  		4, 5, 6, 0, 0,
    94  		7, 8, 9, 0, 0,
    95  	}},
    96  }
    97  
    98  func TestConvertSymmetric(t *testing.T) {
    99  	for _, test := range symmetricTests {
   100  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
   101  			test.Uplo = uplo
   102  			colmajor := newSymmetricColsFrom(test)
   103  			if !sameSymmetric(colmajor, test) {
   104  				t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   105  					colmajor, test)
   106  			}
   107  			rowmajor := newSymmetricFrom(colmajor)
   108  			if !sameSymmetric(rowmajor, test) {
   109  				t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   110  					rowmajor, test)
   111  			}
   112  		}
   113  	}
   114  }
   115  func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand {
   116  	t := SymmetricBand{
   117  		N:      a.N,
   118  		K:      a.K,
   119  		Stride: a.K + 1,
   120  		Data:   make([]float64, a.N*(a.K+1)),
   121  		Uplo:   a.Uplo,
   122  	}
   123  	for i := range t.Data {
   124  		t.Data[i] = math.NaN()
   125  	}
   126  	t.From(a)
   127  	return t
   128  }
   129  
   130  func (m SymmetricBand) n() (n int) { return m.N }
   131  func (m SymmetricBand) at(i, j int) float64 {
   132  	b := Band{
   133  		Rows: m.N, Cols: m.N,
   134  		Stride: m.Stride,
   135  		Data:   m.Data,
   136  	}
   137  	switch m.Uplo {
   138  	default:
   139  		panic("blas64: bad BLAS uplo")
   140  	case blas.Upper:
   141  		b.KU = m.K
   142  		if i > j {
   143  			i, j = j, i
   144  		}
   145  	case blas.Lower:
   146  		b.KL = m.K
   147  		if i < j {
   148  			i, j = j, i
   149  		}
   150  	}
   151  	return b.at(i, j)
   152  }
   153  func (m SymmetricBand) bandwidth() (k int) { return m.K }
   154  func (m SymmetricBand) uplo() blas.Uplo    { return m.Uplo }
   155  
   156  func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols {
   157  	t := SymmetricBandCols{
   158  		N:      a.N,
   159  		K:      a.K,
   160  		Stride: a.K + 1,
   161  		Data:   make([]float64, a.N*(a.K+1)),
   162  		Uplo:   a.Uplo,
   163  	}
   164  	for i := range t.Data {
   165  		t.Data[i] = math.NaN()
   166  	}
   167  	t.From(a)
   168  	return t
   169  }
   170  
   171  func (m SymmetricBandCols) n() (n int) { return m.N }
   172  func (m SymmetricBandCols) at(i, j int) float64 {
   173  	b := BandCols{
   174  		Rows: m.N, Cols: m.N,
   175  		Stride: m.Stride,
   176  		Data:   m.Data,
   177  	}
   178  	switch m.Uplo {
   179  	default:
   180  		panic("blas64: bad BLAS uplo")
   181  	case blas.Upper:
   182  		b.KU = m.K
   183  		if i > j {
   184  			i, j = j, i
   185  		}
   186  	case blas.Lower:
   187  		b.KL = m.K
   188  		if i < j {
   189  			i, j = j, i
   190  		}
   191  	}
   192  	return b.at(i, j)
   193  }
   194  func (m SymmetricBandCols) bandwidth() (k int) { return m.K }
   195  func (m SymmetricBandCols) uplo() blas.Uplo    { return m.Uplo }
   196  
   197  type symmetricBand interface {
   198  	n() (n int)
   199  	at(i, j int) float64
   200  	bandwidth() (k int)
   201  	uplo() blas.Uplo
   202  }
   203  
   204  func sameSymmetricBand(a, b symmetricBand) bool {
   205  	an := a.n()
   206  	bn := b.n()
   207  	if an != bn {
   208  		return false
   209  	}
   210  	if a.uplo() != b.uplo() {
   211  		return false
   212  	}
   213  	ak := a.bandwidth()
   214  	bk := b.bandwidth()
   215  	if ak != bk {
   216  		return false
   217  	}
   218  	for i := 0; i < an; i++ {
   219  		for j := 0; j < an; j++ {
   220  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   221  				return false
   222  			}
   223  		}
   224  	}
   225  	return true
   226  }
   227  
   228  var symmetricBandTests = []SymmetricBand{
   229  	{N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{
   230  		1,
   231  		2,
   232  		3,
   233  	}},
   234  	{N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{
   235  		1,
   236  		2,
   237  		3,
   238  	}},
   239  	{N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{
   240  		1, 2,
   241  		3, 4,
   242  		5, -1,
   243  	}},
   244  	{N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{
   245  		-1, 1,
   246  		2, 3,
   247  		4, 5,
   248  	}},
   249  	{N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{
   250  		1, 2, 3,
   251  		4, 5, -1,
   252  		6, -2, -3,
   253  	}},
   254  	{N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{
   255  		-2, -1, 1,
   256  		-3, 2, 4,
   257  		3, 5, 6,
   258  	}},
   259  
   260  	{N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{
   261  		1, 0, 0, 0, 0,
   262  		2, 0, 0, 0, 0,
   263  		3, 0, 0, 0, 0,
   264  	}},
   265  	{N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{
   266  		1, 0, 0, 0, 0,
   267  		2, 0, 0, 0, 0,
   268  		3, 0, 0, 0, 0,
   269  	}},
   270  	{N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{
   271  		1, 2, 0, 0, 0,
   272  		3, 4, 0, 0, 0,
   273  		5, -1, 0, 0, 0,
   274  	}},
   275  	{N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{
   276  		-1, 1, 0, 0, 0,
   277  		2, 3, 0, 0, 0,
   278  		4, 5, 0, 0, 0,
   279  	}},
   280  	{N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{
   281  		1, 2, 3, 0, 0,
   282  		4, 5, -1, 0, 0,
   283  		6, -2, -3, 0, 0,
   284  	}},
   285  	{N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{
   286  		-2, -1, 1, 0, 0,
   287  		-3, 2, 4, 0, 0,
   288  		3, 5, 6, 0, 0,
   289  	}},
   290  }
   291  
   292  func TestConvertSymBand(t *testing.T) {
   293  	for _, test := range symmetricBandTests {
   294  		colmajor := newSymmetricBandColsFrom(test)
   295  		if !sameSymmetricBand(colmajor, test) {
   296  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   297  				colmajor, test)
   298  		}
   299  		rowmajor := newSymmetricBandFrom(colmajor)
   300  		if !sameSymmetricBand(rowmajor, test) {
   301  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   302  				rowmajor, test)
   303  		}
   304  	}
   305  }