github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/cblas64/conv_test.go (about)

     1  // Code generated by "go generate github.com/jingcheng-WU/gonum/blas”; DO NOT EDIT.
     2  
     3  // Copyright ©2015 The Gonum Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package cblas64
     8  
     9  import (
    10  	math "github.com/jingcheng-WU/gonum/internal/cmplx64"
    11  	"testing"
    12  
    13  	"github.com/jingcheng-WU/gonum/blas"
    14  )
    15  
    16  func newGeneralFrom(a GeneralCols) General {
    17  	t := General{
    18  		Rows:   a.Rows,
    19  		Cols:   a.Cols,
    20  		Stride: a.Cols,
    21  		Data:   make([]complex64, a.Rows*a.Cols),
    22  	}
    23  	t.From(a)
    24  	return t
    25  }
    26  
    27  func (m General) dims() (r, c int)      { return m.Rows, m.Cols }
    28  func (m General) at(i, j int) complex64 { return m.Data[i*m.Stride+j] }
    29  
    30  func newGeneralColsFrom(a General) GeneralCols {
    31  	t := GeneralCols{
    32  		Rows:   a.Rows,
    33  		Cols:   a.Cols,
    34  		Stride: a.Rows,
    35  		Data:   make([]complex64, a.Rows*a.Cols),
    36  	}
    37  	t.From(a)
    38  	return t
    39  }
    40  
    41  func (m GeneralCols) dims() (r, c int)      { return m.Rows, m.Cols }
    42  func (m GeneralCols) at(i, j int) complex64 { return m.Data[i+j*m.Stride] }
    43  
    44  type general interface {
    45  	dims() (r, c int)
    46  	at(i, j int) complex64
    47  }
    48  
    49  func sameGeneral(a, b general) bool {
    50  	ar, ac := a.dims()
    51  	br, bc := b.dims()
    52  	if ar != br || ac != bc {
    53  		return false
    54  	}
    55  	for i := 0; i < ar; i++ {
    56  		for j := 0; j < ac; j++ {
    57  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
    58  				return false
    59  			}
    60  		}
    61  	}
    62  	return true
    63  }
    64  
    65  var generalTests = []General{
    66  	{Rows: 2, Cols: 3, Stride: 3, Data: []complex64{
    67  		1, 2, 3,
    68  		4, 5, 6,
    69  	}},
    70  	{Rows: 3, Cols: 2, Stride: 2, Data: []complex64{
    71  		1, 2,
    72  		3, 4,
    73  		5, 6,
    74  	}},
    75  	{Rows: 3, Cols: 3, Stride: 3, Data: []complex64{
    76  		1, 2, 3,
    77  		4, 5, 6,
    78  		7, 8, 9,
    79  	}},
    80  	{Rows: 2, Cols: 3, Stride: 5, Data: []complex64{
    81  		1, 2, 3, 0, 0,
    82  		4, 5, 6, 0, 0,
    83  	}},
    84  	{Rows: 3, Cols: 2, Stride: 5, Data: []complex64{
    85  		1, 2, 0, 0, 0,
    86  		3, 4, 0, 0, 0,
    87  		5, 6, 0, 0, 0,
    88  	}},
    89  	{Rows: 3, Cols: 3, Stride: 5, Data: []complex64{
    90  		1, 2, 3, 0, 0,
    91  		4, 5, 6, 0, 0,
    92  		7, 8, 9, 0, 0,
    93  	}},
    94  }
    95  
    96  func TestConvertGeneral(t *testing.T) {
    97  	for _, test := range generalTests {
    98  		colmajor := newGeneralColsFrom(test)
    99  		if !sameGeneral(colmajor, test) {
   100  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   101  				colmajor, test)
   102  		}
   103  		rowmajor := newGeneralFrom(colmajor)
   104  		if !sameGeneral(rowmajor, test) {
   105  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   106  				rowmajor, test)
   107  		}
   108  	}
   109  }
   110  
   111  func newTriangularFrom(a TriangularCols) Triangular {
   112  	t := Triangular{
   113  		N:      a.N,
   114  		Stride: a.N,
   115  		Data:   make([]complex64, a.N*a.N),
   116  		Diag:   a.Diag,
   117  		Uplo:   a.Uplo,
   118  	}
   119  	t.From(a)
   120  	return t
   121  }
   122  
   123  func (m Triangular) n() int { return m.N }
   124  func (m Triangular) at(i, j int) complex64 {
   125  	if m.Diag == blas.Unit && i == j {
   126  		return 1
   127  	}
   128  	if m.Uplo == blas.Lower && i < j && j < m.N {
   129  		return 0
   130  	}
   131  	if m.Uplo == blas.Upper && i > j {
   132  		return 0
   133  	}
   134  	return m.Data[i*m.Stride+j]
   135  }
   136  func (m Triangular) uplo() blas.Uplo { return m.Uplo }
   137  func (m Triangular) diag() blas.Diag { return m.Diag }
   138  
   139  func newTriangularColsFrom(a Triangular) TriangularCols {
   140  	t := TriangularCols{
   141  		N:      a.N,
   142  		Stride: a.N,
   143  		Data:   make([]complex64, a.N*a.N),
   144  		Diag:   a.Diag,
   145  		Uplo:   a.Uplo,
   146  	}
   147  	t.From(a)
   148  	return t
   149  }
   150  
   151  func (m TriangularCols) n() int { return m.N }
   152  func (m TriangularCols) at(i, j int) complex64 {
   153  	if m.Diag == blas.Unit && i == j {
   154  		return 1
   155  	}
   156  	if m.Uplo == blas.Lower && i < j {
   157  		return 0
   158  	}
   159  	if m.Uplo == blas.Upper && i > j && i < m.N {
   160  		return 0
   161  	}
   162  	return m.Data[i+j*m.Stride]
   163  }
   164  func (m TriangularCols) uplo() blas.Uplo { return m.Uplo }
   165  func (m TriangularCols) diag() blas.Diag { return m.Diag }
   166  
   167  type triangular interface {
   168  	n() int
   169  	at(i, j int) complex64
   170  	uplo() blas.Uplo
   171  	diag() blas.Diag
   172  }
   173  
   174  func sameTriangular(a, b triangular) bool {
   175  	an := a.n()
   176  	bn := b.n()
   177  	if an != bn {
   178  		return false
   179  	}
   180  	for i := 0; i < an; i++ {
   181  		for j := 0; j < an; j++ {
   182  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   183  				return false
   184  			}
   185  		}
   186  	}
   187  	return true
   188  }
   189  
   190  var triangularTests = []Triangular{
   191  	{N: 3, Stride: 3, Data: []complex64{
   192  		1, 2, 3,
   193  		4, 5, 6,
   194  		7, 8, 9,
   195  	}},
   196  	{N: 3, Stride: 5, Data: []complex64{
   197  		1, 2, 3, 0, 0,
   198  		4, 5, 6, 0, 0,
   199  		7, 8, 9, 0, 0,
   200  	}},
   201  }
   202  
   203  func TestConvertTriangular(t *testing.T) {
   204  	for _, test := range triangularTests {
   205  		for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower, blas.All} {
   206  			for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} {
   207  				test.Uplo = uplo
   208  				test.Diag = diag
   209  				colmajor := newTriangularColsFrom(test)
   210  				if !sameTriangular(colmajor, test) {
   211  					t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   212  						colmajor, test)
   213  				}
   214  				rowmajor := newTriangularFrom(colmajor)
   215  				if !sameTriangular(rowmajor, test) {
   216  					t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   217  						rowmajor, test)
   218  				}
   219  			}
   220  		}
   221  	}
   222  }
   223  
   224  func newBandFrom(a BandCols) Band {
   225  	t := Band{
   226  		Rows:   a.Rows,
   227  		Cols:   a.Cols,
   228  		KL:     a.KL,
   229  		KU:     a.KU,
   230  		Stride: a.KL + a.KU + 1,
   231  		Data:   make([]complex64, a.Rows*(a.KL+a.KU+1)),
   232  	}
   233  	for i := range t.Data {
   234  		t.Data[i] = math.NaN()
   235  	}
   236  	t.From(a)
   237  	return t
   238  }
   239  
   240  func (m Band) dims() (r, c int) { return m.Rows, m.Cols }
   241  func (m Band) at(i, j int) complex64 {
   242  	pj := j + m.KL - i
   243  	if pj < 0 || m.KL+m.KU+1 <= pj {
   244  		return 0
   245  	}
   246  	return m.Data[i*m.Stride+pj]
   247  }
   248  func (m Band) bandwidth() (kl, ku int) { return m.KL, m.KU }
   249  
   250  func newBandColsFrom(a Band) BandCols {
   251  	t := BandCols{
   252  		Rows:   a.Rows,
   253  		Cols:   a.Cols,
   254  		KL:     a.KL,
   255  		KU:     a.KU,
   256  		Stride: a.KL + a.KU + 1,
   257  		Data:   make([]complex64, a.Cols*(a.KL+a.KU+1)),
   258  	}
   259  	for i := range t.Data {
   260  		t.Data[i] = math.NaN()
   261  	}
   262  	t.From(a)
   263  	return t
   264  }
   265  
   266  func (m BandCols) dims() (r, c int) { return m.Rows, m.Cols }
   267  func (m BandCols) at(i, j int) complex64 {
   268  	pj := i + m.KU - j
   269  	if pj < 0 || m.KL+m.KU+1 <= pj {
   270  		return 0
   271  	}
   272  	return m.Data[j*m.Stride+pj]
   273  }
   274  func (m BandCols) bandwidth() (kl, ku int) { return m.KL, m.KU }
   275  
   276  type band interface {
   277  	dims() (r, c int)
   278  	at(i, j int) complex64
   279  	bandwidth() (kl, ku int)
   280  }
   281  
   282  func sameBand(a, b band) bool {
   283  	ar, ac := a.dims()
   284  	br, bc := b.dims()
   285  	if ar != br || ac != bc {
   286  		return false
   287  	}
   288  	akl, aku := a.bandwidth()
   289  	bkl, bku := b.bandwidth()
   290  	if akl != bkl || aku != bku {
   291  		return false
   292  	}
   293  	for i := 0; i < ar; i++ {
   294  		for j := 0; j < ac; j++ {
   295  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   296  				return false
   297  			}
   298  		}
   299  	}
   300  	return true
   301  }
   302  
   303  var bandTests = []Band{
   304  	{Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 1, Data: []complex64{
   305  		1,
   306  		2,
   307  		3,
   308  	}},
   309  	{Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []complex64{
   310  		1,
   311  		2,
   312  		3,
   313  	}},
   314  	{Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 1, Data: []complex64{
   315  		1,
   316  		2,
   317  		3,
   318  	}},
   319  	{Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 2, Data: []complex64{
   320  		1, 2,
   321  		3, 4,
   322  		5, 6,
   323  	}},
   324  	{Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 2, Data: []complex64{
   325  		1, 2,
   326  		3, 4,
   327  		5, 6,
   328  	}},
   329  	{Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 3, Data: []complex64{
   330  		-1, 2, 3,
   331  		4, 5, 6,
   332  		7, 8, 9,
   333  	}},
   334  	{Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 3, Data: []complex64{
   335  		-1, 2, 3,
   336  		4, 5, 6,
   337  		7, 8, -2,
   338  		9, -3, -4,
   339  	}},
   340  	{Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 4, Data: []complex64{
   341  		-2, -1, 3, 4,
   342  		-3, 5, 6, 7,
   343  		8, 9, 10, 11,
   344  	}},
   345  	{Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 4, Data: []complex64{
   346  		-2, -1, 2, 3,
   347  		-3, 4, 5, 6,
   348  		7, 8, 9, -4,
   349  		10, 11, -5, -6,
   350  	}},
   351  
   352  	{Rows: 3, Cols: 4, KL: 0, KU: 0, Stride: 5, Data: []complex64{
   353  		1, 0, 0, 0, 0,
   354  		2, 0, 0, 0, 0,
   355  		3, 0, 0, 0, 0,
   356  	}},
   357  	{Rows: 3, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []complex64{
   358  		1, 0, 0, 0, 0,
   359  		2, 0, 0, 0, 0,
   360  		3, 0, 0, 0, 0,
   361  	}},
   362  	{Rows: 4, Cols: 3, KL: 0, KU: 0, Stride: 5, Data: []complex64{
   363  		1, 0, 0, 0, 0,
   364  		2, 0, 0, 0, 0,
   365  		3, 0, 0, 0, 0,
   366  	}},
   367  	{Rows: 4, Cols: 3, KL: 0, KU: 1, Stride: 5, Data: []complex64{
   368  		1, 2, 0, 0, 0,
   369  		3, 4, 0, 0, 0,
   370  		5, 6, 0, 0, 0,
   371  	}},
   372  	{Rows: 3, Cols: 4, KL: 0, KU: 1, Stride: 5, Data: []complex64{
   373  		1, 2, 0, 0, 0,
   374  		3, 4, 0, 0, 0,
   375  		5, 6, 0, 0, 0,
   376  	}},
   377  	{Rows: 3, Cols: 4, KL: 1, KU: 1, Stride: 5, Data: []complex64{
   378  		-1, 2, 3, 0, 0,
   379  		4, 5, 6, 0, 0,
   380  		7, 8, 9, 0, 0,
   381  	}},
   382  	{Rows: 4, Cols: 3, KL: 1, KU: 1, Stride: 5, Data: []complex64{
   383  		-1, 2, 3, 0, 0,
   384  		4, 5, 6, 0, 0,
   385  		7, 8, -2, 0, 0,
   386  		9, -3, -4, 0, 0,
   387  	}},
   388  	{Rows: 3, Cols: 4, KL: 2, KU: 1, Stride: 5, Data: []complex64{
   389  		-2, -1, 3, 4, 0,
   390  		-3, 5, 6, 7, 0,
   391  		8, 9, 10, 11, 0,
   392  	}},
   393  	{Rows: 4, Cols: 3, KL: 2, KU: 1, Stride: 5, Data: []complex64{
   394  		-2, -1, 2, 3, 0,
   395  		-3, 4, 5, 6, 0,
   396  		7, 8, 9, -4, 0,
   397  		10, 11, -5, -6, 0,
   398  	}},
   399  }
   400  
   401  func TestConvertBand(t *testing.T) {
   402  	for _, test := range bandTests {
   403  		colmajor := newBandColsFrom(test)
   404  		if !sameBand(colmajor, test) {
   405  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   406  				colmajor, test)
   407  		}
   408  		rowmajor := newBandFrom(colmajor)
   409  		if !sameBand(rowmajor, test) {
   410  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   411  				rowmajor, test)
   412  		}
   413  	}
   414  }
   415  
   416  func newTriangularBandFrom(a TriangularBandCols) TriangularBand {
   417  	t := TriangularBand{
   418  		N:      a.N,
   419  		K:      a.K,
   420  		Stride: a.K + 1,
   421  		Data:   make([]complex64, a.N*(a.K+1)),
   422  		Uplo:   a.Uplo,
   423  		Diag:   a.Diag,
   424  	}
   425  	for i := range t.Data {
   426  		t.Data[i] = math.NaN()
   427  	}
   428  	t.From(a)
   429  	return t
   430  }
   431  
   432  func (m TriangularBand) n() (n int) { return m.N }
   433  func (m TriangularBand) at(i, j int) complex64 {
   434  	if m.Diag == blas.Unit && i == j {
   435  		return 1
   436  	}
   437  	b := Band{
   438  		Rows: m.N, Cols: m.N,
   439  		Stride: m.Stride,
   440  		Data:   m.Data,
   441  	}
   442  	switch m.Uplo {
   443  	default:
   444  		panic("cblas64: bad BLAS uplo")
   445  	case blas.Upper:
   446  		if i > j {
   447  			return 0
   448  		}
   449  		b.KU = m.K
   450  	case blas.Lower:
   451  		if i < j {
   452  			return 0
   453  		}
   454  		b.KL = m.K
   455  	}
   456  	return b.at(i, j)
   457  }
   458  func (m TriangularBand) bandwidth() (k int) { return m.K }
   459  func (m TriangularBand) uplo() blas.Uplo    { return m.Uplo }
   460  func (m TriangularBand) diag() blas.Diag    { return m.Diag }
   461  
   462  func newTriangularBandColsFrom(a TriangularBand) TriangularBandCols {
   463  	t := TriangularBandCols{
   464  		N:      a.N,
   465  		K:      a.K,
   466  		Stride: a.K + 1,
   467  		Data:   make([]complex64, a.N*(a.K+1)),
   468  		Uplo:   a.Uplo,
   469  		Diag:   a.Diag,
   470  	}
   471  	for i := range t.Data {
   472  		t.Data[i] = math.NaN()
   473  	}
   474  	t.From(a)
   475  	return t
   476  }
   477  
   478  func (m TriangularBandCols) n() (n int) { return m.N }
   479  func (m TriangularBandCols) at(i, j int) complex64 {
   480  	if m.Diag == blas.Unit && i == j {
   481  		return 1
   482  	}
   483  	b := BandCols{
   484  		Rows: m.N, Cols: m.N,
   485  		Stride: m.Stride,
   486  		Data:   m.Data,
   487  	}
   488  	switch m.Uplo {
   489  	default:
   490  		panic("cblas64: bad BLAS uplo")
   491  	case blas.Upper:
   492  		if i > j {
   493  			return 0
   494  		}
   495  		b.KU = m.K
   496  	case blas.Lower:
   497  		if i < j {
   498  			return 0
   499  		}
   500  		b.KL = m.K
   501  	}
   502  	return b.at(i, j)
   503  }
   504  func (m TriangularBandCols) bandwidth() (k int) { return m.K }
   505  func (m TriangularBandCols) uplo() blas.Uplo    { return m.Uplo }
   506  func (m TriangularBandCols) diag() blas.Diag    { return m.Diag }
   507  
   508  type triangularBand interface {
   509  	n() (n int)
   510  	at(i, j int) complex64
   511  	bandwidth() (k int)
   512  	uplo() blas.Uplo
   513  	diag() blas.Diag
   514  }
   515  
   516  func sameTriangularBand(a, b triangularBand) bool {
   517  	an := a.n()
   518  	bn := b.n()
   519  	if an != bn {
   520  		return false
   521  	}
   522  	if a.uplo() != b.uplo() {
   523  		return false
   524  	}
   525  	if a.diag() != b.diag() {
   526  		return false
   527  	}
   528  	ak := a.bandwidth()
   529  	bk := b.bandwidth()
   530  	if ak != bk {
   531  		return false
   532  	}
   533  	for i := 0; i < an; i++ {
   534  		for j := 0; j < an; j++ {
   535  			if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
   536  				return false
   537  			}
   538  		}
   539  	}
   540  	return true
   541  }
   542  
   543  var triangularBandTests = []TriangularBand{
   544  	{N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []complex64{
   545  		1,
   546  		2,
   547  		3,
   548  	}},
   549  	{N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []complex64{
   550  		1,
   551  		2,
   552  		3,
   553  	}},
   554  	{N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []complex64{
   555  		1, 2,
   556  		3, 4,
   557  		5, -1,
   558  	}},
   559  	{N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []complex64{
   560  		-1, 1,
   561  		2, 3,
   562  		4, 5,
   563  	}},
   564  	{N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []complex64{
   565  		1, 2, 3,
   566  		4, 5, -1,
   567  		6, -2, -3,
   568  	}},
   569  	{N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []complex64{
   570  		-2, -1, 1,
   571  		-3, 2, 4,
   572  		3, 5, 6,
   573  	}},
   574  
   575  	{N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []complex64{
   576  		1, 0, 0, 0, 0,
   577  		2, 0, 0, 0, 0,
   578  		3, 0, 0, 0, 0,
   579  	}},
   580  	{N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []complex64{
   581  		1, 0, 0, 0, 0,
   582  		2, 0, 0, 0, 0,
   583  		3, 0, 0, 0, 0,
   584  	}},
   585  	{N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []complex64{
   586  		1, 2, 0, 0, 0,
   587  		3, 4, 0, 0, 0,
   588  		5, -1, 0, 0, 0,
   589  	}},
   590  	{N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []complex64{
   591  		-1, 1, 0, 0, 0,
   592  		2, 3, 0, 0, 0,
   593  		4, 5, 0, 0, 0,
   594  	}},
   595  	{N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []complex64{
   596  		1, 2, 3, 0, 0,
   597  		4, 5, -1, 0, 0,
   598  		6, -2, -3, 0, 0,
   599  	}},
   600  	{N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []complex64{
   601  		-2, -1, 1, 0, 0,
   602  		-3, 2, 4, 0, 0,
   603  		3, 5, 6, 0, 0,
   604  	}},
   605  }
   606  
   607  func TestConvertTriBand(t *testing.T) {
   608  	for _, test := range triangularBandTests {
   609  		colmajor := newTriangularBandColsFrom(test)
   610  		if !sameTriangularBand(colmajor, test) {
   611  			t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
   612  				colmajor, test)
   613  		}
   614  		rowmajor := newTriangularBandFrom(colmajor)
   615  		if !sameTriangularBand(rowmajor, test) {
   616  			t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
   617  				rowmajor, test)
   618  		}
   619  	}
   620  }