github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/blas64/conv_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package blas64
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"github.com/jingcheng-WU/gonum/blas"
    12  )
    13  
    14  func newGeneralFrom(a GeneralCols) General {
    15  	t := General{
    16  		Rows:   a.Rows,
    17  		Cols:   a.Cols,
    18  		Stride: a.Cols,
    19  		Data:   make([]float64, a.Rows*a.Cols),
    20  	}
    21  	t.From(a)
    22  	return t
    23  }
    24  
    25  func (m General) dims() (r, c int)    { return m.Rows, m.Cols }
    26  func (m General) at(i, j int) float64 { return m.Data[i*m.Stride+j] }
    27  
    28  func newGeneralColsFrom(a General) GeneralCols {
    29  	t := GeneralCols{
    30  		Rows:   a.Rows,
    31  		Cols:   a.Cols,
    32  		Stride: a.Rows,
    33  		Data:   make([]float64, a.Rows*a.Cols),
    34  	}
    35  	t.From(a)
    36  	return t
    37  }
    38  
    39  func (m GeneralCols) dims() (r, c int)    { return m.Rows, m.Cols }
    40  func (m GeneralCols) at(i, j int) float64 { return m.Data[i+j*m.Stride] }
    41  
    42  type general interface {
    43  	dims() (r, c int)
    44  	at(i, j int) float64
    45  }
    46  
    47  func sameGeneral(a, b general) bool {
    48  	ar, ac := a.dims()
    49  	br, bc := b.dims()
    50  	if ar != br || ac != bc {
    51  		return false
    52  	}
    53  	for i := 0; i < ar; i++ {
    54  		for j := 0; j < ac; j++ {
    55  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
    56  				return false
    57  			}
    58  		}
    59  	}
    60  	return true
    61  }
    62  
    63  var generalTests = []General{
    64  	{Rows: 2, Cols: 3, Stride: 3, Data: []float64{
    65  		1, 2, 3,
    66  		4, 5, 6,
    67  	}},
    68  	{Rows: 3, Cols: 2, Stride: 2, Data: []float64{
    69  		1, 2,
    70  		3, 4,
    71  		5, 6,
    72  	}},
    73  	{Rows: 3, Cols: 3, Stride: 3, Data: []float64{
    74  		1, 2, 3,
    75  		4, 5, 6,
    76  		7, 8, 9,
    77  	}},
    78  	{Rows: 2, Cols: 3, Stride: 5, Data: []float64{
    79  		1, 2, 3, 0, 0,
    80  		4, 5, 6, 0, 0,
    81  	}},
    82  	{Rows: 3, Cols: 2, Stride: 5, Data: []float64{
    83  		1, 2, 0, 0, 0,
    84  		3, 4, 0, 0, 0,
    85  		5, 6, 0, 0, 0,
    86  	}},
    87  	{Rows: 3, Cols: 3, Stride: 5, Data: []float64{
    88  		1, 2, 3, 0, 0,
    89  		4, 5, 6, 0, 0,
    90  		7, 8, 9, 0, 0,
    91  	}},
    92  }
    93  
    94  func TestConvertGeneral(t *testing.T) {
    95  	for _, test := range generalTests {
    96  		colmajor := newGeneralColsFrom(test)
    97  		if !sameGeneral(colmajor, test) {
    98  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
    99  				colmajor, test)
   100  		}
   101  		rowmajor := newGeneralFrom(colmajor)
   102  		if !sameGeneral(rowmajor, test) {
   103  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   104  				rowmajor, test)
   105  		}
   106  	}
   107  }
   108  
   109  func newTriangularFrom(a TriangularCols) Triangular {
   110  	t := Triangular{
   111  		N:      a.N,
   112  		Stride: a.N,
   113  		Data:   make([]float64, a.N*a.N),
   114  		Diag:   a.Diag,
   115  		Uplo:   a.Uplo,
   116  	}
   117  	t.From(a)
   118  	return t
   119  }
   120  
   121  func (m Triangular) n() int { return m.N }
   122  func (m Triangular) at(i, j int) float64 {
   123  	if m.Diag == blas.Unit && i == j {
   124  		return 1
   125  	}
   126  	if m.Uplo == blas.Lower && i < j && j < m.N {
   127  		return 0
   128  	}
   129  	if m.Uplo == blas.Upper && i > j {
   130  		return 0
   131  	}
   132  	return m.Data[i*m.Stride+j]
   133  }
   134  func (m Triangular) uplo() blas.Uplo { return m.Uplo }
   135  func (m Triangular) diag() blas.Diag { return m.Diag }
   136  
   137  func newTriangularColsFrom(a Triangular) TriangularCols {
   138  	t := TriangularCols{
   139  		N:      a.N,
   140  		Stride: a.N,
   141  		Data:   make([]float64, a.N*a.N),
   142  		Diag:   a.Diag,
   143  		Uplo:   a.Uplo,
   144  	}
   145  	t.From(a)
   146  	return t
   147  }
   148  
   149  func (m TriangularCols) n() int { return m.N }
   150  func (m TriangularCols) at(i, j int) float64 {
   151  	if m.Diag == blas.Unit && i == j {
   152  		return 1
   153  	}
   154  	if m.Uplo == blas.Lower && i < j {
   155  		return 0
   156  	}
   157  	if m.Uplo == blas.Upper && i > j && i < m.N {
   158  		return 0
   159  	}
   160  	return m.Data[i+j*m.Stride]
   161  }
   162  func (m TriangularCols) uplo() blas.Uplo { return m.Uplo }
   163  func (m TriangularCols) diag() blas.Diag { return m.Diag }
   164  
   165  type triangular interface {
   166  	n() int
   167  	at(i, j int) float64
   168  	uplo() blas.Uplo
   169  	diag() blas.Diag
   170  }
   171  
   172  func sameTriangular(a, b triangular) bool {
   173  	an := a.n()
   174  	bn := b.n()
   175  	if an != bn {
   176  		return false
   177  	}
   178  	for i := 0; i < an; i++ {
   179  		for j := 0; j < an; j++ {
   180  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   181  				return false
   182  			}
   183  		}
   184  	}
   185  	return true
   186  }
   187  
   188  var triangularTests = []Triangular{
   189  	{N: 3, Stride: 3, Data: []float64{
   190  		1, 2, 3,
   191  		4, 5, 6,
   192  		7, 8, 9,
   193  	}},
   194  	{N: 3, Stride: 5, Data: []float64{
   195  		1, 2, 3, 0, 0,
   196  		4, 5, 6, 0, 0,
   197  		7, 8, 9, 0, 0,
   198  	}},
   199  }
   200  
   201  func TestConvertTriangular(t *testing.T) {
   202  	for _, test := range triangularTests {
   203  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower, blas.All} {
   204  			for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} {
   205  				test.Uplo = uplo
   206  				test.Diag = diag
   207  				colmajor := newTriangularColsFrom(test)
   208  				if !sameTriangular(colmajor, test) {
   209  					t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   210  						colmajor, test)
   211  				}
   212  				rowmajor := newTriangularFrom(colmajor)
   213  				if !sameTriangular(rowmajor, test) {
   214  					t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   215  						rowmajor, test)
   216  				}
   217  			}
   218  		}
   219  	}
   220  }
   221  
   222  func newBandFrom(a BandCols) Band {
   223  	t := Band{
   224  		Rows:   a.Rows,
   225  		Cols:   a.Cols,
   226  		KL:     a.KL,
   227  		KU:     a.KU,
   228  		Stride: a.KL + a.KU + 1,
   229  		Data:   make([]float64, a.Rows*(a.KL+a.KU+1)),
   230  	}
   231  	for i := range t.Data {
   232  		t.Data[i] = math.NaN()
   233  	}
   234  	t.From(a)
   235  	return t
   236  }
   237  
   238  func (m Band) dims() (r, c int) { return m.Rows, m.Cols }
   239  func (m Band) at(i, j int) float64 {
   240  	pj := j + m.KL - i
   241  	if pj < 0 || m.KL+m.KU+1 <= pj {
   242  		return 0
   243  	}
   244  	return m.Data[i*m.Stride+pj]
   245  }
   246  func (m Band) bandwidth() (kl, ku int) { return m.KL, m.KU }
   247  
   248  func newBandColsFrom(a Band) BandCols {
   249  	t := BandCols{
   250  		Rows:   a.Rows,
   251  		Cols:   a.Cols,
   252  		KL:     a.KL,
   253  		KU:     a.KU,
   254  		Stride: a.KL + a.KU + 1,
   255  		Data:   make([]float64, a.Cols*(a.KL+a.KU+1)),
   256  	}
   257  	for i := range t.Data {
   258  		t.Data[i] = math.NaN()
   259  	}
   260  	t.From(a)
   261  	return t
   262  }
   263  
   264  func (m BandCols) dims() (r, c int) { return m.Rows, m.Cols }
   265  func (m BandCols) at(i, j int) float64 {
   266  	pj := i + m.KU - j
   267  	if pj < 0 || m.KL+m.KU+1 <= pj {
   268  		return 0
   269  	}
   270  	return m.Data[j*m.Stride+pj]
   271  }
   272  func (m BandCols) bandwidth() (kl, ku int) { return m.KL, m.KU }
   273  
   274  type band interface {
   275  	dims() (r, c int)
   276  	at(i, j int) float64
   277  	bandwidth() (kl, ku int)
   278  }
   279  
   280  func sameBand(a, b band) bool {
   281  	ar, ac := a.dims()
   282  	br, bc := b.dims()
   283  	if ar != br || ac != bc {
   284  		return false
   285  	}
   286  	akl, aku := a.bandwidth()
   287  	bkl, bku := b.bandwidth()
   288  	if akl != bkl || aku != bku {
   289  		return false
   290  	}
   291  	for i := 0; i < ar; i++ {
   292  		for j := 0; j < ac; j++ {
   293  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   294  				return false
   295  			}
   296  		}
   297  	}
   298  	return true
   299  }
   300  
   301  var bandTests = []Band{
   302  	{Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 1, Data: []float64{
   303  		1,
   304  		2,
   305  		3,
   306  	}},
   307  	{Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{
   308  		1,
   309  		2,
   310  		3,
   311  	}},
   312  	{Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []float64{
   313  		1,
   314  		2,
   315  		3,
   316  	}},
   317  	{Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 2, Data: []float64{
   318  		1, 2,
   319  		3, 4,
   320  		5, 6,
   321  	}},
   322  	{Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 2, Data: []float64{
   323  		1, 2,
   324  		3, 4,
   325  		5, 6,
   326  	}},
   327  	{Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 3, Data: []float64{
   328  		-1, 2, 3,
   329  		4, 5, 6,
   330  		7, 8, 9,
   331  	}},
   332  	{Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 3, Data: []float64{
   333  		-1, 2, 3,
   334  		4, 5, 6,
   335  		7, 8, -2,
   336  		9, -3, -4,
   337  	}},
   338  	{Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 4, Data: []float64{
   339  		-2, -1, 3, 4,
   340  		-3, 5, 6, 7,
   341  		8, 9, 10, 11,
   342  	}},
   343  	{Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 4, Data: []float64{
   344  		-2, -1, 2, 3,
   345  		-3, 4, 5, 6,
   346  		7, 8, 9, -4,
   347  		10, 11, -5, -6,
   348  	}},
   349  
   350  	{Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 5, Data: []float64{
   351  		1, 0, 0, 0, 0,
   352  		2, 0, 0, 0, 0,
   353  		3, 0, 0, 0, 0,
   354  	}},
   355  	{Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{
   356  		1, 0, 0, 0, 0,
   357  		2, 0, 0, 0, 0,
   358  		3, 0, 0, 0, 0,
   359  	}},
   360  	{Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []float64{
   361  		1, 0, 0, 0, 0,
   362  		2, 0, 0, 0, 0,
   363  		3, 0, 0, 0, 0,
   364  	}},
   365  	{Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 5, Data: []float64{
   366  		1, 2, 0, 0, 0,
   367  		3, 4, 0, 0, 0,
   368  		5, 6, 0, 0, 0,
   369  	}},
   370  	{Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 5, Data: []float64{
   371  		1, 2, 0, 0, 0,
   372  		3, 4, 0, 0, 0,
   373  		5, 6, 0, 0, 0,
   374  	}},
   375  	{Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 5, Data: []float64{
   376  		-1, 2, 3, 0, 0,
   377  		4, 5, 6, 0, 0,
   378  		7, 8, 9, 0, 0,
   379  	}},
   380  	{Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 5, Data: []float64{
   381  		-1, 2, 3, 0, 0,
   382  		4, 5, 6, 0, 0,
   383  		7, 8, -2, 0, 0,
   384  		9, -3, -4, 0, 0,
   385  	}},
   386  	{Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 5, Data: []float64{
   387  		-2, -1, 3, 4, 0,
   388  		-3, 5, 6, 7, 0,
   389  		8, 9, 10, 11, 0,
   390  	}},
   391  	{Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 5, Data: []float64{
   392  		-2, -1, 2, 3, 0,
   393  		-3, 4, 5, 6, 0,
   394  		7, 8, 9, -4, 0,
   395  		10, 11, -5, -6, 0,
   396  	}},
   397  }
   398  
   399  func TestConvertBand(t *testing.T) {
   400  	for _, test := range bandTests {
   401  		colmajor := newBandColsFrom(test)
   402  		if !sameBand(colmajor, test) {
   403  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   404  				colmajor, test)
   405  		}
   406  		rowmajor := newBandFrom(colmajor)
   407  		if !sameBand(rowmajor, test) {
   408  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   409  				rowmajor, test)
   410  		}
   411  	}
   412  }
   413  
   414  func newTriangularBandFrom(a TriangularBandCols) TriangularBand {
   415  	t := TriangularBand{
   416  		N:      a.N,
   417  		K:      a.K,
   418  		Stride: a.K + 1,
   419  		Data:   make([]float64, a.N*(a.K+1)),
   420  		Uplo:   a.Uplo,
   421  		Diag:   a.Diag,
   422  	}
   423  	for i := range t.Data {
   424  		t.Data[i] = math.NaN()
   425  	}
   426  	t.From(a)
   427  	return t
   428  }
   429  
   430  func (m TriangularBand) n() (n int) { return m.N }
   431  func (m TriangularBand) at(i, j int) float64 {
   432  	if m.Diag == blas.Unit && i == j {
   433  		return 1
   434  	}
   435  	b := Band{
   436  		Rows: m.N, Cols: m.N,
   437  		Stride: m.Stride,
   438  		Data:   m.Data,
   439  	}
   440  	switch m.Uplo {
   441  	default:
   442  		panic("blas64: bad BLAS uplo")
   443  	case blas.Upper:
   444  		if i > j {
   445  			return 0
   446  		}
   447  		b.KU = m.K
   448  	case blas.Lower:
   449  		if i < j {
   450  			return 0
   451  		}
   452  		b.KL = m.K
   453  	}
   454  	return b.at(i, j)
   455  }
   456  func (m TriangularBand) bandwidth() (k int) { return m.K }
   457  func (m TriangularBand) uplo() blas.Uplo    { return m.Uplo }
   458  func (m TriangularBand) diag() blas.Diag    { return m.Diag }
   459  
   460  func newTriangularBandColsFrom(a TriangularBand) TriangularBandCols {
   461  	t := TriangularBandCols{
   462  		N:      a.N,
   463  		K:      a.K,
   464  		Stride: a.K + 1,
   465  		Data:   make([]float64, a.N*(a.K+1)),
   466  		Uplo:   a.Uplo,
   467  		Diag:   a.Diag,
   468  	}
   469  	for i := range t.Data {
   470  		t.Data[i] = math.NaN()
   471  	}
   472  	t.From(a)
   473  	return t
   474  }
   475  
   476  func (m TriangularBandCols) n() (n int) { return m.N }
   477  func (m TriangularBandCols) at(i, j int) float64 {
   478  	if m.Diag == blas.Unit && i == j {
   479  		return 1
   480  	}
   481  	b := BandCols{
   482  		Rows: m.N, Cols: m.N,
   483  		Stride: m.Stride,
   484  		Data:   m.Data,
   485  	}
   486  	switch m.Uplo {
   487  	default:
   488  		panic("blas64: bad BLAS uplo")
   489  	case blas.Upper:
   490  		if i > j {
   491  			return 0
   492  		}
   493  		b.KU = m.K
   494  	case blas.Lower:
   495  		if i < j {
   496  			return 0
   497  		}
   498  		b.KL = m.K
   499  	}
   500  	return b.at(i, j)
   501  }
   502  func (m TriangularBandCols) bandwidth() (k int) { return m.K }
   503  func (m TriangularBandCols) uplo() blas.Uplo    { return m.Uplo }
   504  func (m TriangularBandCols) diag() blas.Diag    { return m.Diag }
   505  
   506  type triangularBand interface {
   507  	n() (n int)
   508  	at(i, j int) float64
   509  	bandwidth() (k int)
   510  	uplo() blas.Uplo
   511  	diag() blas.Diag
   512  }
   513  
   514  func sameTriangularBand(a, b triangularBand) bool {
   515  	an := a.n()
   516  	bn := b.n()
   517  	if an != bn {
   518  		return false
   519  	}
   520  	if a.uplo() != b.uplo() {
   521  		return false
   522  	}
   523  	if a.diag() != b.diag() {
   524  		return false
   525  	}
   526  	ak := a.bandwidth()
   527  	bk := b.bandwidth()
   528  	if ak != bk {
   529  		return false
   530  	}
   531  	for i := 0; i < an; i++ {
   532  		for j := 0; j < an; j++ {
   533  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   534  				return false
   535  			}
   536  		}
   537  	}
   538  	return true
   539  }
   540  
   541  var triangularBandTests = []TriangularBand{
   542  	{N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{
   543  		1,
   544  		2,
   545  		3,
   546  	}},
   547  	{N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{
   548  		1,
   549  		2,
   550  		3,
   551  	}},
   552  	{N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{
   553  		1, 2,
   554  		3, 4,
   555  		5, -1,
   556  	}},
   557  	{N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{
   558  		-1, 1,
   559  		2, 3,
   560  		4, 5,
   561  	}},
   562  	{N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{
   563  		1, 2, 3,
   564  		4, 5, -1,
   565  		6, -2, -3,
   566  	}},
   567  	{N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{
   568  		-2, -1, 1,
   569  		-3, 2, 4,
   570  		3, 5, 6,
   571  	}},
   572  
   573  	{N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{
   574  		1, 0, 0, 0, 0,
   575  		2, 0, 0, 0, 0,
   576  		3, 0, 0, 0, 0,
   577  	}},
   578  	{N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{
   579  		1, 0, 0, 0, 0,
   580  		2, 0, 0, 0, 0,
   581  		3, 0, 0, 0, 0,
   582  	}},
   583  	{N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{
   584  		1, 2, 0, 0, 0,
   585  		3, 4, 0, 0, 0,
   586  		5, -1, 0, 0, 0,
   587  	}},
   588  	{N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{
   589  		-1, 1, 0, 0, 0,
   590  		2, 3, 0, 0, 0,
   591  		4, 5, 0, 0, 0,
   592  	}},
   593  	{N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{
   594  		1, 2, 3, 0, 0,
   595  		4, 5, -1, 0, 0,
   596  		6, -2, -3, 0, 0,
   597  	}},
   598  	{N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{
   599  		-2, -1, 1, 0, 0,
   600  		-3, 2, 4, 0, 0,
   601  		3, 5, 6, 0, 0,
   602  	}},
   603  }
   604  
   605  func TestConvertTriBand(t *testing.T) {
   606  	for _, test := range triangularBandTests {
   607  		colmajor := newTriangularBandColsFrom(test)
   608  		if !sameTriangularBand(colmajor, test) {
   609  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   610  				colmajor, test)
   611  		}
   612  		rowmajor := newTriangularBandFrom(colmajor)
   613  		if !sameTriangularBand(rowmajor, test) {
   614  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   615  				rowmajor, test)
   616  		}
   617  	}
   618  }