gonum.org/v1/gonum@v0.14.0/blas/testblas/common.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"math"
     9  	"math/cmplx"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/floats/scalar"
    16  )
    17  
    18  // throwPanic will throw unexpected panics if true, or will just report them as errors if false
    19  const throwPanic = true
    20  
    21  var znan = cmplx.NaN()
    22  
    23  func dTolEqual(a, b float64) bool {
    24  	if math.IsNaN(a) && math.IsNaN(b) {
    25  		return true
    26  	}
    27  	if a == b {
    28  		return true
    29  	}
    30  	m := math.Max(math.Abs(a), math.Abs(b))
    31  	if m > 1 {
    32  		a /= m
    33  		b /= m
    34  	}
    35  	if math.Abs(a-b) < 1e-14 {
    36  		return true
    37  	}
    38  	return false
    39  }
    40  
    41  func dSliceTolEqual(a, b []float64) bool {
    42  	if len(a) != len(b) {
    43  		return false
    44  	}
    45  	for i := range a {
    46  		if !dTolEqual(a[i], b[i]) {
    47  			return false
    48  		}
    49  	}
    50  	return true
    51  }
    52  
    53  func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool {
    54  	ia := 0
    55  	ib := 0
    56  	if inca <= 0 {
    57  		ia = -(n - 1) * inca
    58  	}
    59  	if incb <= 0 {
    60  		ib = -(n - 1) * incb
    61  	}
    62  	for i := 0; i < n; i++ {
    63  		if !dTolEqual(a[ia], b[ib]) {
    64  			return false
    65  		}
    66  		ia += inca
    67  		ib += incb
    68  	}
    69  	return true
    70  }
    71  
    72  func dSliceEqual(a, b []float64) bool {
    73  	if len(a) != len(b) {
    74  		return false
    75  	}
    76  	for i := range a {
    77  		if !dTolEqual(a[i], b[i]) {
    78  			return false
    79  		}
    80  	}
    81  	return true
    82  }
    83  
    84  func dCopyTwoTmp(x, xTmp, y, yTmp []float64) {
    85  	if len(x) != len(xTmp) {
    86  		panic("x size mismatch")
    87  	}
    88  	if len(y) != len(yTmp) {
    89  		panic("y size mismatch")
    90  	}
    91  	copy(xTmp, x)
    92  	copy(yTmp, y)
    93  }
    94  
    95  // returns true if the function panics
    96  func panics(f func()) (b bool) {
    97  	defer func() {
    98  		err := recover()
    99  		if err != nil {
   100  			b = true
   101  		}
   102  	}()
   103  	f()
   104  	return
   105  }
   106  
   107  func testpanics(f func(), name string, t *testing.T) {
   108  	b := panics(f)
   109  	if !b {
   110  		t.Errorf("%v should panic and does not", name)
   111  	}
   112  }
   113  
   114  func sliceOfSliceCopy(a [][]float64) [][]float64 {
   115  	n := make([][]float64, len(a))
   116  	for i := range a {
   117  		n[i] = make([]float64, len(a[i]))
   118  		copy(n[i], a[i])
   119  	}
   120  	return n
   121  }
   122  
   123  func sliceCopy(a []float64) []float64 {
   124  	n := make([]float64, len(a))
   125  	copy(n, a)
   126  	return n
   127  }
   128  
   129  func flatten(a [][]float64) []float64 {
   130  	if len(a) == 0 {
   131  		return nil
   132  	}
   133  	m := len(a)
   134  	n := len(a[0])
   135  	s := make([]float64, m*n)
   136  	for i := 0; i < m; i++ {
   137  		for j := 0; j < n; j++ {
   138  			s[i*n+j] = a[i][j]
   139  		}
   140  	}
   141  	return s
   142  }
   143  
   144  func unflatten(a []float64, m, n int) [][]float64 {
   145  	s := make([][]float64, m)
   146  	for i := 0; i < m; i++ {
   147  		s[i] = make([]float64, n)
   148  		for j := 0; j < n; j++ {
   149  			s[i][j] = a[i*n+j]
   150  		}
   151  	}
   152  	return s
   153  }
   154  
   155  // flattenTriangular turns the upper or lower triangle of a dense slice of slice
   156  // into a single slice with packed storage. a must be a square matrix.
   157  func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 {
   158  	m := len(a)
   159  	aFlat := make([]float64, m*(m+1)/2)
   160  	var k int
   161  	if ul == blas.Upper {
   162  		for i := 0; i < m; i++ {
   163  			k += copy(aFlat[k:], a[i][i:])
   164  		}
   165  		return aFlat
   166  	}
   167  	for i := 0; i < m; i++ {
   168  		k += copy(aFlat[k:], a[i][:i+1])
   169  	}
   170  	return aFlat
   171  }
   172  
   173  // flattenBanded turns a dense banded slice of slice into the compact banded matrix format
   174  func flattenBanded(a [][]float64, ku, kl int) []float64 {
   175  	m := len(a)
   176  	n := len(a[0])
   177  	if ku < 0 || kl < 0 {
   178  		panic("testblas: negative band length")
   179  	}
   180  	nRows := m
   181  	nCols := (ku + kl + 1)
   182  	aflat := make([]float64, nRows*nCols)
   183  	for i := range aflat {
   184  		aflat[i] = math.NaN()
   185  	}
   186  	// loop over the rows, and then the bands
   187  	// elements in the ith row stay in the ith row
   188  	// order in bands is kept
   189  	for i := 0; i < nRows; i++ {
   190  		min := -kl
   191  		if i-kl < 0 {
   192  			min = -i
   193  		}
   194  		max := ku
   195  		if i+ku >= n {
   196  			max = n - i - 1
   197  		}
   198  		for j := min; j <= max; j++ {
   199  			col := kl + j
   200  			aflat[i*nCols+col] = a[i][i+j]
   201  		}
   202  	}
   203  	return aflat
   204  }
   205  
   206  // makeIncremented takes a float64 slice with inc == 1 and makes an incremented version
   207  // and adds extra values on the end
   208  func makeIncremented(x []float64, inc int, extra int) []float64 {
   209  	if inc == 0 {
   210  		panic("zero inc")
   211  	}
   212  	absinc := inc
   213  	if absinc < 0 {
   214  		absinc = -inc
   215  	}
   216  	xcopy := make([]float64, len(x))
   217  	if inc > 0 {
   218  		copy(xcopy, x)
   219  	} else {
   220  		for i := 0; i < len(x); i++ {
   221  			xcopy[i] = x[len(x)-i-1]
   222  		}
   223  	}
   224  
   225  	// don't use NaN because it makes comparison hard
   226  	// Do use a weird unique value for easier debugging
   227  	counter := 100.0
   228  	var xnew []float64
   229  	for i, v := range xcopy {
   230  		xnew = append(xnew, v)
   231  		if i != len(x)-1 {
   232  			for j := 0; j < absinc-1; j++ {
   233  				xnew = append(xnew, counter)
   234  				counter++
   235  			}
   236  		}
   237  	}
   238  	for i := 0; i < extra; i++ {
   239  		xnew = append(xnew, counter)
   240  		counter++
   241  	}
   242  	return xnew
   243  }
   244  
   245  // makeIncremented32 takes a float32 slice with inc == 1 and makes an incremented version
   246  // and adds extra values on the end
   247  func makeIncremented32(x []float32, inc int, extra int) []float32 {
   248  	if inc == 0 {
   249  		panic("zero inc")
   250  	}
   251  	absinc := inc
   252  	if absinc < 0 {
   253  		absinc = -inc
   254  	}
   255  	xcopy := make([]float32, len(x))
   256  	if inc > 0 {
   257  		copy(xcopy, x)
   258  	} else {
   259  		for i := 0; i < len(x); i++ {
   260  			xcopy[i] = x[len(x)-i-1]
   261  		}
   262  	}
   263  
   264  	// don't use NaN because it makes comparison hard
   265  	// Do use a weird unique value for easier debugging
   266  	var counter float32 = 100.0
   267  	var xnew []float32
   268  	for i, v := range xcopy {
   269  		xnew = append(xnew, v)
   270  		if i != len(x)-1 {
   271  			for j := 0; j < absinc-1; j++ {
   272  				xnew = append(xnew, counter)
   273  				counter++
   274  			}
   275  		}
   276  	}
   277  	for i := 0; i < extra; i++ {
   278  		xnew = append(xnew, counter)
   279  		counter++
   280  	}
   281  	return xnew
   282  }
   283  
   284  func abs(x int) int {
   285  	if x < 0 {
   286  		return -x
   287  	}
   288  	return x
   289  }
   290  
   291  func allPairs(x, y []int) [][2]int {
   292  	var p [][2]int
   293  	for _, v0 := range x {
   294  		for _, v1 := range y {
   295  			p = append(p, [2]int{v0, v1})
   296  		}
   297  	}
   298  	return p
   299  }
   300  
   301  func sameFloat64(a, b float64) bool {
   302  	return a == b || math.IsNaN(a) && math.IsNaN(b)
   303  }
   304  
   305  func sameComplex128(x, y complex128) bool {
   306  	return sameFloat64(real(x), real(y)) && sameFloat64(imag(x), imag(y))
   307  }
   308  
   309  func zsame(x, y []complex128) bool {
   310  	if len(x) != len(y) {
   311  		return false
   312  	}
   313  	for i, v := range x {
   314  		w := y[i]
   315  		if !sameComplex128(v, w) {
   316  			return false
   317  		}
   318  	}
   319  	return true
   320  }
   321  
   322  // zSameAtNonstrided returns whether elements at non-stride positions of vectors
   323  // x and y are same.
   324  func zSameAtNonstrided(x, y []complex128, inc int) bool {
   325  	if len(x) != len(y) {
   326  		return false
   327  	}
   328  	if inc < 0 {
   329  		inc = -inc
   330  	}
   331  	for i, v := range x {
   332  		if i%inc == 0 {
   333  			continue
   334  		}
   335  		w := y[i]
   336  		if !sameComplex128(v, w) {
   337  			return false
   338  		}
   339  	}
   340  	return true
   341  }
   342  
   343  // zEqualApproxAtStrided returns whether elements at stride positions of vectors
   344  // x and y are approximately equal within tol.
   345  func zEqualApproxAtStrided(x, y []complex128, inc int, tol float64) bool {
   346  	if len(x) != len(y) {
   347  		return false
   348  	}
   349  	if inc < 0 {
   350  		inc = -inc
   351  	}
   352  	for i := 0; i < len(x); i += inc {
   353  		v := x[i]
   354  		w := y[i]
   355  		if !(cmplx.Abs(v-w) <= tol) {
   356  			return false
   357  		}
   358  	}
   359  	return true
   360  }
   361  
   362  func makeZVector(data []complex128, inc int) []complex128 {
   363  	if inc == 0 {
   364  		panic("bad test")
   365  	}
   366  	if len(data) == 0 {
   367  		return nil
   368  	}
   369  	inc = abs(inc)
   370  	x := make([]complex128, (len(data)-1)*inc+1)
   371  	for i := range x {
   372  		x[i] = znan
   373  	}
   374  	for i, v := range data {
   375  		x[i*inc] = v
   376  	}
   377  	return x
   378  }
   379  
   380  func makeZGeneral(data []complex128, m, n int, ld int) []complex128 {
   381  	if m < 0 || n < 0 {
   382  		panic("bad test")
   383  	}
   384  	if data != nil && len(data) != m*n {
   385  		panic("bad test")
   386  	}
   387  	if ld < max(1, n) {
   388  		panic("bad test")
   389  	}
   390  	if m == 0 || n == 0 {
   391  		return nil
   392  	}
   393  	a := make([]complex128, (m-1)*ld+n)
   394  	for i := range a {
   395  		a[i] = znan
   396  	}
   397  	if data != nil {
   398  		for i := 0; i < m; i++ {
   399  			copy(a[i*ld:i*ld+n], data[i*n:i*n+n])
   400  		}
   401  	}
   402  	return a
   403  }
   404  
   405  func max(a, b int) int {
   406  	if a < b {
   407  		return b
   408  	}
   409  	return a
   410  }
   411  
   412  func min(a, b int) int {
   413  	if a < b {
   414  		return a
   415  	}
   416  	return b
   417  }
   418  
   419  // zPack returns the uplo triangle of an n×n matrix A in packed format.
   420  func zPack(uplo blas.Uplo, n int, a []complex128, lda int) []complex128 {
   421  	if n == 0 {
   422  		return nil
   423  	}
   424  	ap := make([]complex128, n*(n+1)/2)
   425  	var ii int
   426  	if uplo == blas.Upper {
   427  		for i := 0; i < n; i++ {
   428  			for j := i; j < n; j++ {
   429  				ap[ii] = a[i*lda+j]
   430  				ii++
   431  			}
   432  		}
   433  	} else {
   434  		for i := 0; i < n; i++ {
   435  			for j := 0; j <= i; j++ {
   436  				ap[ii] = a[i*lda+j]
   437  				ii++
   438  			}
   439  		}
   440  	}
   441  	return ap
   442  }
   443  
   444  // zUnpackAsHermitian returns an n×n general Hermitian matrix (with stride n)
   445  // whose packed uplo triangle is stored on entry in ap.
   446  func zUnpackAsHermitian(uplo blas.Uplo, n int, ap []complex128) []complex128 {
   447  	if n == 0 {
   448  		return nil
   449  	}
   450  	a := make([]complex128, n*n)
   451  	lda := n
   452  	var ii int
   453  	if uplo == blas.Upper {
   454  		for i := 0; i < n; i++ {
   455  			for j := i; j < n; j++ {
   456  				a[i*lda+j] = ap[ii]
   457  				if i != j {
   458  					a[j*lda+i] = cmplx.Conj(ap[ii])
   459  				}
   460  				ii++
   461  			}
   462  		}
   463  	} else {
   464  		for i := 0; i < n; i++ {
   465  			for j := 0; j <= i; j++ {
   466  				a[i*lda+j] = ap[ii]
   467  				if i != j {
   468  					a[j*lda+i] = cmplx.Conj(ap[ii])
   469  				}
   470  				ii++
   471  			}
   472  		}
   473  	}
   474  	return a
   475  }
   476  
   477  // zPackBand returns the (kL+1+kU) band of an m×n general matrix A in band
   478  // matrix format with ldab stride. Out-of-range elements are filled with NaN.
   479  func zPackBand(kL, kU, ldab int, m, n int, a []complex128, lda int) []complex128 {
   480  	if m == 0 || n == 0 {
   481  		return nil
   482  	}
   483  	nRow := min(m, n+kL)
   484  	ab := make([]complex128, (nRow-1)*ldab+kL+1+kU)
   485  	for i := range ab {
   486  		ab[i] = znan
   487  	}
   488  	for i := 0; i < m; i++ {
   489  		off := max(0, kL-i)
   490  		var k int
   491  		for j := max(0, i-kL); j < min(n, i+kU+1); j++ {
   492  			ab[i*ldab+off+k] = a[i*lda+j]
   493  			k++
   494  		}
   495  	}
   496  	return ab
   497  }
   498  
   499  // zPackTriBand returns in band matrix format the (k+1) band in the uplo
   500  // triangle of an n×n matrix A. Out-of-range elements are filled with NaN.
   501  func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) []complex128 {
   502  	if n == 0 {
   503  		return nil
   504  	}
   505  	ab := make([]complex128, (n-1)*ldab+k+1)
   506  	for i := range ab {
   507  		ab[i] = znan
   508  	}
   509  	if uplo == blas.Upper {
   510  		for i := 0; i < n; i++ {
   511  			var k int
   512  			for j := i; j < min(n, i+k+1); j++ {
   513  				ab[i*ldab+k] = a[i*lda+j]
   514  				k++
   515  			}
   516  		}
   517  	} else {
   518  		for i := 0; i < n; i++ {
   519  			off := max(0, k-i)
   520  			var kk int
   521  			for j := max(0, i-k); j <= i; j++ {
   522  				ab[i*ldab+off+kk] = a[i*lda+j]
   523  				kk++
   524  			}
   525  		}
   526  	}
   527  	return ab
   528  }
   529  
   530  // zEqualApprox returns whether the slices a and b are approximately equal.
   531  func zEqualApprox(a, b []complex128, tol float64) bool {
   532  	if len(a) != len(b) {
   533  		panic("mismatched slice length")
   534  	}
   535  	for i, ai := range a {
   536  		if !scalar.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) {
   537  			return false
   538  		}
   539  	}
   540  	return true
   541  }
   542  
   543  // rndComplex128 returns a complex128 with random components.
   544  func rndComplex128(rnd *rand.Rand) complex128 {
   545  	return complex(rnd.NormFloat64(), rnd.NormFloat64())
   546  }
   547  
   548  // zmm returns the result of one of the matrix-matrix operations
   549  //
   550  //	alpha * op(A) * op(B) + beta * C
   551  //
   552  // where op(X) is one of
   553  //
   554  //	op(X) = X  or  op(X) = Xᵀ  or  op(X) = Xᴴ,
   555  //
   556  // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
   557  // op(B) a k×n matrix and C an m×n matrix.
   558  //
   559  // The returned slice is newly allocated, has the same length as c and the
   560  // matrix it represents has the stride ldc. Out-of-range elements are equal to
   561  // those of C to ease comparison of results from BLAS Level 3 functions.
   562  func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 {
   563  	r := make([]complex128, len(c))
   564  	copy(r, c)
   565  	for i := 0; i < m; i++ {
   566  		for j := 0; j < n; j++ {
   567  			r[i*ldc+j] = 0
   568  		}
   569  	}
   570  	switch tA {
   571  	case blas.NoTrans:
   572  		switch tB {
   573  		case blas.NoTrans:
   574  			for i := 0; i < m; i++ {
   575  				for j := 0; j < n; j++ {
   576  					for l := 0; l < k; l++ {
   577  						r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j]
   578  					}
   579  				}
   580  			}
   581  		case blas.Trans:
   582  			for i := 0; i < m; i++ {
   583  				for j := 0; j < n; j++ {
   584  					for l := 0; l < k; l++ {
   585  						r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l]
   586  					}
   587  				}
   588  			}
   589  		case blas.ConjTrans:
   590  			for i := 0; i < m; i++ {
   591  				for j := 0; j < n; j++ {
   592  					for l := 0; l < k; l++ {
   593  						r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l])
   594  					}
   595  				}
   596  			}
   597  		}
   598  	case blas.Trans:
   599  		switch tB {
   600  		case blas.NoTrans:
   601  			for i := 0; i < m; i++ {
   602  				for j := 0; j < n; j++ {
   603  					for l := 0; l < k; l++ {
   604  						r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j]
   605  					}
   606  				}
   607  			}
   608  		case blas.Trans:
   609  			for i := 0; i < m; i++ {
   610  				for j := 0; j < n; j++ {
   611  					for l := 0; l < k; l++ {
   612  						r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l]
   613  					}
   614  				}
   615  			}
   616  		case blas.ConjTrans:
   617  			for i := 0; i < m; i++ {
   618  				for j := 0; j < n; j++ {
   619  					for l := 0; l < k; l++ {
   620  						r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
   621  					}
   622  				}
   623  			}
   624  		}
   625  	case blas.ConjTrans:
   626  		switch tB {
   627  		case blas.NoTrans:
   628  			for i := 0; i < m; i++ {
   629  				for j := 0; j < n; j++ {
   630  					for l := 0; l < k; l++ {
   631  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
   632  					}
   633  				}
   634  			}
   635  		case blas.Trans:
   636  			for i := 0; i < m; i++ {
   637  				for j := 0; j < n; j++ {
   638  					for l := 0; l < k; l++ {
   639  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
   640  					}
   641  				}
   642  			}
   643  		case blas.ConjTrans:
   644  			for i := 0; i < m; i++ {
   645  				for j := 0; j < n; j++ {
   646  					for l := 0; l < k; l++ {
   647  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
   648  					}
   649  				}
   650  			}
   651  		}
   652  	}
   653  	for i := 0; i < m; i++ {
   654  		for j := 0; j < n; j++ {
   655  			r[i*ldc+j] = alpha * r[i*ldc+j]
   656  			if beta != 0 {
   657  				r[i*ldc+j] += beta * c[i*ldc+j]
   658  			}
   659  		}
   660  	}
   661  	return r
   662  }
   663  
   664  // transString returns a string representation of blas.Transpose.
   665  func transString(t blas.Transpose) string {
   666  	switch t {
   667  	case blas.NoTrans:
   668  		return "NoTrans"
   669  	case blas.Trans:
   670  		return "Trans"
   671  	case blas.ConjTrans:
   672  		return "ConjTrans"
   673  	}
   674  	return "unknown trans"
   675  }
   676  
   677  // uploString returns a string representation of blas.Uplo.
   678  func uploString(uplo blas.Uplo) string {
   679  	switch uplo {
   680  	case blas.Lower:
   681  		return "Lower"
   682  	case blas.Upper:
   683  		return "Upper"
   684  	}
   685  	return "unknown uplo"
   686  }
   687  
   688  // sideString returns a string representation of blas.Side.
   689  func sideString(side blas.Side) string {
   690  	switch side {
   691  	case blas.Left:
   692  		return "Left"
   693  	case blas.Right:
   694  		return "Right"
   695  	}
   696  	return "unknown side"
   697  }
   698  
   699  // diagString returns a string representation of blas.Diag.
   700  func diagString(diag blas.Diag) string {
   701  	switch diag {
   702  	case blas.Unit:
   703  		return "Unit"
   704  	case blas.NonUnit:
   705  		return "NonUnit"
   706  	}
   707  	return "unknown diag"
   708  }
   709  
   710  // zSameLowerTri returns whether n×n matrices A and B are same under the diagonal.
   711  func zSameLowerTri(n int, a []complex128, lda int, b []complex128, ldb int) bool {
   712  	for i := 1; i < n; i++ {
   713  		for j := 0; j < i; j++ {
   714  			aij := a[i*lda+j]
   715  			bij := b[i*ldb+j]
   716  			if !sameComplex128(aij, bij) {
   717  				return false
   718  			}
   719  		}
   720  	}
   721  	return true
   722  }
   723  
   724  // zSameUpperTri returns whether n×n matrices A and B are same above the diagonal.
   725  func zSameUpperTri(n int, a []complex128, lda int, b []complex128, ldb int) bool {
   726  	for i := 0; i < n-1; i++ {
   727  		for j := i + 1; j < n; j++ {
   728  			aij := a[i*lda+j]
   729  			bij := b[i*ldb+j]
   730  			if !sameComplex128(aij, bij) {
   731  				return false
   732  			}
   733  		}
   734  	}
   735  	return true
   736  }