gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/blas/testblas/common.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"math"
     9  	"math/cmplx"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/floats/scalar"
    16  )
    17  
// throwPanic will throw unexpected panics if true, or will just report them as errors if false
const throwPanic = true

// znan is the complex NaN returned by cmplx.NaN. Helpers in this file use it
// as the fill value for padding and out-of-range elements of the vectors and
// matrices they construct.
var znan = cmplx.NaN()
    22  
    23  func dTolEqual(a, b float64) bool {
    24  	if math.IsNaN(a) && math.IsNaN(b) {
    25  		return true
    26  	}
    27  	if a == b {
    28  		return true
    29  	}
    30  	m := math.Max(math.Abs(a), math.Abs(b))
    31  	if m > 1 {
    32  		a /= m
    33  		b /= m
    34  	}
    35  	if math.Abs(a-b) < 1e-14 {
    36  		return true
    37  	}
    38  	return false
    39  }
    40  
    41  func dSliceTolEqual(a, b []float64) bool {
    42  	if len(a) != len(b) {
    43  		return false
    44  	}
    45  	for i := range a {
    46  		if !dTolEqual(a[i], b[i]) {
    47  			return false
    48  		}
    49  	}
    50  	return true
    51  }
    52  
    53  func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool {
    54  	ia := 0
    55  	ib := 0
    56  	if inca <= 0 {
    57  		ia = -(n - 1) * inca
    58  	}
    59  	if incb <= 0 {
    60  		ib = -(n - 1) * incb
    61  	}
    62  	for i := 0; i < n; i++ {
    63  		if !dTolEqual(a[ia], b[ib]) {
    64  			return false
    65  		}
    66  		ia += inca
    67  		ib += incb
    68  	}
    69  	return true
    70  }
    71  
    72  func dSliceEqual(a, b []float64) bool {
    73  	if len(a) != len(b) {
    74  		return false
    75  	}
    76  	for i := range a {
    77  		if !dTolEqual(a[i], b[i]) {
    78  			return false
    79  		}
    80  	}
    81  	return true
    82  }
    83  
    84  func dCopyTwoTmp(x, xTmp, y, yTmp []float64) {
    85  	if len(x) != len(xTmp) {
    86  		panic("x size mismatch")
    87  	}
    88  	if len(y) != len(yTmp) {
    89  		panic("y size mismatch")
    90  	}
    91  	copy(xTmp, x)
    92  	copy(yTmp, y)
    93  }
    94  
    95  // returns true if the function panics
    96  func panics(f func()) (b bool) {
    97  	defer func() {
    98  		err := recover()
    99  		if err != nil {
   100  			b = true
   101  		}
   102  	}()
   103  	f()
   104  	return
   105  }
   106  
   107  func testpanics(f func(), name string, t *testing.T) {
   108  	b := panics(f)
   109  	if !b {
   110  		t.Errorf("%v should panic and does not", name)
   111  	}
   112  }
   113  
   114  func sliceOfSliceCopy(a [][]float64) [][]float64 {
   115  	n := make([][]float64, len(a))
   116  	for i := range a {
   117  		n[i] = make([]float64, len(a[i]))
   118  		copy(n[i], a[i])
   119  	}
   120  	return n
   121  }
   122  
   123  func sliceCopy(a []float64) []float64 {
   124  	n := make([]float64, len(a))
   125  	copy(n, a)
   126  	return n
   127  }
   128  
   129  func flatten(a [][]float64) []float64 {
   130  	if len(a) == 0 {
   131  		return nil
   132  	}
   133  	m := len(a)
   134  	n := len(a[0])
   135  	s := make([]float64, m*n)
   136  	for i := 0; i < m; i++ {
   137  		for j := 0; j < n; j++ {
   138  			s[i*n+j] = a[i][j]
   139  		}
   140  	}
   141  	return s
   142  }
   143  
   144  func unflatten(a []float64, m, n int) [][]float64 {
   145  	s := make([][]float64, m)
   146  	for i := 0; i < m; i++ {
   147  		s[i] = make([]float64, n)
   148  		for j := 0; j < n; j++ {
   149  			s[i][j] = a[i*n+j]
   150  		}
   151  	}
   152  	return s
   153  }
   154  
   155  // flattenTriangular turns the upper or lower triangle of a dense slice of slice
   156  // into a single slice with packed storage. a must be a square matrix.
   157  func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 {
   158  	m := len(a)
   159  	aFlat := make([]float64, m*(m+1)/2)
   160  	var k int
   161  	if ul == blas.Upper {
   162  		for i := 0; i < m; i++ {
   163  			k += copy(aFlat[k:], a[i][i:])
   164  		}
   165  		return aFlat
   166  	}
   167  	for i := 0; i < m; i++ {
   168  		k += copy(aFlat[k:], a[i][:i+1])
   169  	}
   170  	return aFlat
   171  }
   172  
   173  // flattenBanded turns a dense banded slice of slice into the compact banded matrix format
   174  func flattenBanded(a [][]float64, ku, kl int) []float64 {
   175  	m := len(a)
   176  	n := len(a[0])
   177  	if ku < 0 || kl < 0 {
   178  		panic("testblas: negative band length")
   179  	}
   180  	nRows := m
   181  	nCols := (ku + kl + 1)
   182  	aflat := make([]float64, nRows*nCols)
   183  	for i := range aflat {
   184  		aflat[i] = math.NaN()
   185  	}
   186  	// loop over the rows, and then the bands
   187  	// elements in the ith row stay in the ith row
   188  	// order in bands is kept
   189  	for i := 0; i < nRows; i++ {
   190  		min := -kl
   191  		if i-kl < 0 {
   192  			min = -i
   193  		}
   194  		max := ku
   195  		if i+ku >= n {
   196  			max = n - i - 1
   197  		}
   198  		for j := min; j <= max; j++ {
   199  			col := kl + j
   200  			aflat[i*nCols+col] = a[i][i+j]
   201  		}
   202  	}
   203  	return aflat
   204  }
   205  
   206  // makeIncremented takes a float64 slice with inc == 1 and makes an incremented version
   207  // and adds extra values on the end
   208  func makeIncremented(x []float64, inc int, extra int) []float64 {
   209  	if inc == 0 {
   210  		panic("zero inc")
   211  	}
   212  	absinc := inc
   213  	if absinc < 0 {
   214  		absinc = -inc
   215  	}
   216  	xcopy := make([]float64, len(x))
   217  	if inc > 0 {
   218  		copy(xcopy, x)
   219  	} else {
   220  		for i := 0; i < len(x); i++ {
   221  			xcopy[i] = x[len(x)-i-1]
   222  		}
   223  	}
   224  
   225  	// don't use NaN because it makes comparison hard
   226  	// Do use a weird unique value for easier debugging
   227  	counter := 100.0
   228  	var xnew []float64
   229  	for i, v := range xcopy {
   230  		xnew = append(xnew, v)
   231  		if i != len(x)-1 {
   232  			for j := 0; j < absinc-1; j++ {
   233  				xnew = append(xnew, counter)
   234  				counter++
   235  			}
   236  		}
   237  	}
   238  	for i := 0; i < extra; i++ {
   239  		xnew = append(xnew, counter)
   240  		counter++
   241  	}
   242  	return xnew
   243  }
   244  
   245  // makeIncremented32 takes a float32 slice with inc == 1 and makes an incremented version
   246  // and adds extra values on the end
   247  func makeIncremented32(x []float32, inc int, extra int) []float32 {
   248  	if inc == 0 {
   249  		panic("zero inc")
   250  	}
   251  	absinc := inc
   252  	if absinc < 0 {
   253  		absinc = -inc
   254  	}
   255  	xcopy := make([]float32, len(x))
   256  	if inc > 0 {
   257  		copy(xcopy, x)
   258  	} else {
   259  		for i := 0; i < len(x); i++ {
   260  			xcopy[i] = x[len(x)-i-1]
   261  		}
   262  	}
   263  
   264  	// don't use NaN because it makes comparison hard
   265  	// Do use a weird unique value for easier debugging
   266  	var counter float32 = 100.0
   267  	var xnew []float32
   268  	for i, v := range xcopy {
   269  		xnew = append(xnew, v)
   270  		if i != len(x)-1 {
   271  			for j := 0; j < absinc-1; j++ {
   272  				xnew = append(xnew, counter)
   273  				counter++
   274  			}
   275  		}
   276  	}
   277  	for i := 0; i < extra; i++ {
   278  		xnew = append(xnew, counter)
   279  		counter++
   280  	}
   281  	return xnew
   282  }
   283  
   284  func abs(x int) int {
   285  	if x < 0 {
   286  		return -x
   287  	}
   288  	return x
   289  }
   290  
   291  func allPairs(x, y []int) [][2]int {
   292  	var p [][2]int
   293  	for _, v0 := range x {
   294  		for _, v1 := range y {
   295  			p = append(p, [2]int{v0, v1})
   296  		}
   297  	}
   298  	return p
   299  }
   300  
   301  func sameFloat64(a, b float64) bool {
   302  	return a == b || math.IsNaN(a) && math.IsNaN(b)
   303  }
   304  
   305  func sameComplex128(x, y complex128) bool {
   306  	return sameFloat64(real(x), real(y)) && sameFloat64(imag(x), imag(y))
   307  }
   308  
   309  func zsame(x, y []complex128) bool {
   310  	if len(x) != len(y) {
   311  		return false
   312  	}
   313  	for i, v := range x {
   314  		w := y[i]
   315  		if !sameComplex128(v, w) {
   316  			return false
   317  		}
   318  	}
   319  	return true
   320  }
   321  
   322  // zSameAtNonstrided returns whether elements at non-stride positions of vectors
   323  // x and y are same.
   324  func zSameAtNonstrided(x, y []complex128, inc int) bool {
   325  	if len(x) != len(y) {
   326  		return false
   327  	}
   328  	if inc < 0 {
   329  		inc = -inc
   330  	}
   331  	for i, v := range x {
   332  		if i%inc == 0 {
   333  			continue
   334  		}
   335  		w := y[i]
   336  		if !sameComplex128(v, w) {
   337  			return false
   338  		}
   339  	}
   340  	return true
   341  }
   342  
   343  // zEqualApproxAtStrided returns whether elements at stride positions of vectors
   344  // x and y are approximately equal within tol.
   345  func zEqualApproxAtStrided(x, y []complex128, inc int, tol float64) bool {
   346  	if len(x) != len(y) {
   347  		return false
   348  	}
   349  	if inc < 0 {
   350  		inc = -inc
   351  	}
   352  	for i := 0; i < len(x); i += inc {
   353  		v := x[i]
   354  		w := y[i]
   355  		if !(cmplx.Abs(v-w) <= tol) {
   356  			return false
   357  		}
   358  	}
   359  	return true
   360  }
   361  
   362  func makeZVector(data []complex128, inc int) []complex128 {
   363  	if inc == 0 {
   364  		panic("bad test")
   365  	}
   366  	if len(data) == 0 {
   367  		return nil
   368  	}
   369  	inc = abs(inc)
   370  	x := make([]complex128, (len(data)-1)*inc+1)
   371  	for i := range x {
   372  		x[i] = znan
   373  	}
   374  	for i, v := range data {
   375  		x[i*inc] = v
   376  	}
   377  	return x
   378  }
   379  
   380  func makeZGeneral(data []complex128, m, n int, ld int) []complex128 {
   381  	if m < 0 || n < 0 {
   382  		panic("bad test")
   383  	}
   384  	if data != nil && len(data) != m*n {
   385  		panic("bad test")
   386  	}
   387  	if ld < max(1, n) {
   388  		panic("bad test")
   389  	}
   390  	if m == 0 || n == 0 {
   391  		return nil
   392  	}
   393  	a := make([]complex128, (m-1)*ld+n)
   394  	for i := range a {
   395  		a[i] = znan
   396  	}
   397  	if data != nil {
   398  		for i := 0; i < m; i++ {
   399  			copy(a[i*ld:i*ld+n], data[i*n:i*n+n])
   400  		}
   401  	}
   402  	return a
   403  }
   404  
   405  // zPack returns the uplo triangle of an n×n matrix A in packed format.
   406  func zPack(uplo blas.Uplo, n int, a []complex128, lda int) []complex128 {
   407  	if n == 0 {
   408  		return nil
   409  	}
   410  	ap := make([]complex128, n*(n+1)/2)
   411  	var ii int
   412  	if uplo == blas.Upper {
   413  		for i := 0; i < n; i++ {
   414  			for j := i; j < n; j++ {
   415  				ap[ii] = a[i*lda+j]
   416  				ii++
   417  			}
   418  		}
   419  	} else {
   420  		for i := 0; i < n; i++ {
   421  			for j := 0; j <= i; j++ {
   422  				ap[ii] = a[i*lda+j]
   423  				ii++
   424  			}
   425  		}
   426  	}
   427  	return ap
   428  }
   429  
   430  // zUnpackAsHermitian returns an n×n general Hermitian matrix (with stride n)
   431  // whose packed uplo triangle is stored on entry in ap.
   432  func zUnpackAsHermitian(uplo blas.Uplo, n int, ap []complex128) []complex128 {
   433  	if n == 0 {
   434  		return nil
   435  	}
   436  	a := make([]complex128, n*n)
   437  	lda := n
   438  	var ii int
   439  	if uplo == blas.Upper {
   440  		for i := 0; i < n; i++ {
   441  			for j := i; j < n; j++ {
   442  				a[i*lda+j] = ap[ii]
   443  				if i != j {
   444  					a[j*lda+i] = cmplx.Conj(ap[ii])
   445  				}
   446  				ii++
   447  			}
   448  		}
   449  	} else {
   450  		for i := 0; i < n; i++ {
   451  			for j := 0; j <= i; j++ {
   452  				a[i*lda+j] = ap[ii]
   453  				if i != j {
   454  					a[j*lda+i] = cmplx.Conj(ap[ii])
   455  				}
   456  				ii++
   457  			}
   458  		}
   459  	}
   460  	return a
   461  }
   462  
   463  // zPackBand returns the (kL+1+kU) band of an m×n general matrix A in band
   464  // matrix format with ldab stride. Out-of-range elements are filled with NaN.
   465  func zPackBand(kL, kU, ldab int, m, n int, a []complex128, lda int) []complex128 {
   466  	if m == 0 || n == 0 {
   467  		return nil
   468  	}
   469  	nRow := min(m, n+kL)
   470  	ab := make([]complex128, (nRow-1)*ldab+kL+1+kU)
   471  	for i := range ab {
   472  		ab[i] = znan
   473  	}
   474  	for i := 0; i < m; i++ {
   475  		off := max(0, kL-i)
   476  		var k int
   477  		for j := max(0, i-kL); j < min(n, i+kU+1); j++ {
   478  			ab[i*ldab+off+k] = a[i*lda+j]
   479  			k++
   480  		}
   481  	}
   482  	return ab
   483  }
   484  
   485  // zPackTriBand returns in band matrix format the (k+1) band in the uplo
   486  // triangle of an n×n matrix A. Out-of-range elements are filled with NaN.
   487  func zPackTriBand(k, ldab int, uplo blas.Uplo, n int, a []complex128, lda int) []complex128 {
   488  	if n == 0 {
   489  		return nil
   490  	}
   491  	ab := make([]complex128, (n-1)*ldab+k+1)
   492  	for i := range ab {
   493  		ab[i] = znan
   494  	}
   495  	if uplo == blas.Upper {
   496  		for i := 0; i < n; i++ {
   497  			var k int
   498  			for j := i; j < min(n, i+k+1); j++ {
   499  				ab[i*ldab+k] = a[i*lda+j]
   500  				k++
   501  			}
   502  		}
   503  	} else {
   504  		for i := 0; i < n; i++ {
   505  			off := max(0, k-i)
   506  			var kk int
   507  			for j := max(0, i-k); j <= i; j++ {
   508  				ab[i*ldab+off+kk] = a[i*lda+j]
   509  				kk++
   510  			}
   511  		}
   512  	}
   513  	return ab
   514  }
   515  
   516  // zEqualApprox returns whether the slices a and b are approximately equal.
   517  func zEqualApprox(a, b []complex128, tol float64) bool {
   518  	if len(a) != len(b) {
   519  		panic("mismatched slice length")
   520  	}
   521  	for i, ai := range a {
   522  		if !scalar.EqualWithinAbs(cmplx.Abs(ai), cmplx.Abs(b[i]), tol) {
   523  			return false
   524  		}
   525  	}
   526  	return true
   527  }
   528  
   529  // rndComplex128 returns a complex128 with random components.
   530  func rndComplex128(rnd *rand.Rand) complex128 {
   531  	return complex(rnd.NormFloat64(), rnd.NormFloat64())
   532  }
   533  
   534  // zmm returns the result of one of the matrix-matrix operations
   535  //
   536  //	alpha * op(A) * op(B) + beta * C
   537  //
   538  // where op(X) is one of
   539  //
   540  //	op(X) = X  or  op(X) = Xᵀ  or  op(X) = Xᴴ,
   541  //
   542  // alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
   543  // op(B) a k×n matrix and C an m×n matrix.
   544  //
   545  // The returned slice is newly allocated, has the same length as c and the
   546  // matrix it represents has the stride ldc. Out-of-range elements are equal to
   547  // those of C to ease comparison of results from BLAS Level 3 functions.
   548  func zmm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) []complex128 {
   549  	r := make([]complex128, len(c))
   550  	copy(r, c)
   551  	for i := 0; i < m; i++ {
   552  		for j := 0; j < n; j++ {
   553  			r[i*ldc+j] = 0
   554  		}
   555  	}
   556  	switch tA {
   557  	case blas.NoTrans:
   558  		switch tB {
   559  		case blas.NoTrans:
   560  			for i := 0; i < m; i++ {
   561  				for j := 0; j < n; j++ {
   562  					for l := 0; l < k; l++ {
   563  						r[i*ldc+j] += a[i*lda+l] * b[l*ldb+j]
   564  					}
   565  				}
   566  			}
   567  		case blas.Trans:
   568  			for i := 0; i < m; i++ {
   569  				for j := 0; j < n; j++ {
   570  					for l := 0; l < k; l++ {
   571  						r[i*ldc+j] += a[i*lda+l] * b[j*ldb+l]
   572  					}
   573  				}
   574  			}
   575  		case blas.ConjTrans:
   576  			for i := 0; i < m; i++ {
   577  				for j := 0; j < n; j++ {
   578  					for l := 0; l < k; l++ {
   579  						r[i*ldc+j] += a[i*lda+l] * cmplx.Conj(b[j*ldb+l])
   580  					}
   581  				}
   582  			}
   583  		}
   584  	case blas.Trans:
   585  		switch tB {
   586  		case blas.NoTrans:
   587  			for i := 0; i < m; i++ {
   588  				for j := 0; j < n; j++ {
   589  					for l := 0; l < k; l++ {
   590  						r[i*ldc+j] += a[l*lda+i] * b[l*ldb+j]
   591  					}
   592  				}
   593  			}
   594  		case blas.Trans:
   595  			for i := 0; i < m; i++ {
   596  				for j := 0; j < n; j++ {
   597  					for l := 0; l < k; l++ {
   598  						r[i*ldc+j] += a[l*lda+i] * b[j*ldb+l]
   599  					}
   600  				}
   601  			}
   602  		case blas.ConjTrans:
   603  			for i := 0; i < m; i++ {
   604  				for j := 0; j < n; j++ {
   605  					for l := 0; l < k; l++ {
   606  						r[i*ldc+j] += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
   607  					}
   608  				}
   609  			}
   610  		}
   611  	case blas.ConjTrans:
   612  		switch tB {
   613  		case blas.NoTrans:
   614  			for i := 0; i < m; i++ {
   615  				for j := 0; j < n; j++ {
   616  					for l := 0; l < k; l++ {
   617  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
   618  					}
   619  				}
   620  			}
   621  		case blas.Trans:
   622  			for i := 0; i < m; i++ {
   623  				for j := 0; j < n; j++ {
   624  					for l := 0; l < k; l++ {
   625  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
   626  					}
   627  				}
   628  			}
   629  		case blas.ConjTrans:
   630  			for i := 0; i < m; i++ {
   631  				for j := 0; j < n; j++ {
   632  					for l := 0; l < k; l++ {
   633  						r[i*ldc+j] += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
   634  					}
   635  				}
   636  			}
   637  		}
   638  	}
   639  	for i := 0; i < m; i++ {
   640  		for j := 0; j < n; j++ {
   641  			r[i*ldc+j] = alpha * r[i*ldc+j]
   642  			if beta != 0 {
   643  				r[i*ldc+j] += beta * c[i*ldc+j]
   644  			}
   645  		}
   646  	}
   647  	return r
   648  }
   649  
   650  // transString returns a string representation of blas.Transpose.
   651  func transString(t blas.Transpose) string {
   652  	switch t {
   653  	case blas.NoTrans:
   654  		return "NoTrans"
   655  	case blas.Trans:
   656  		return "Trans"
   657  	case blas.ConjTrans:
   658  		return "ConjTrans"
   659  	}
   660  	return "unknown trans"
   661  }
   662  
   663  // uploString returns a string representation of blas.Uplo.
   664  func uploString(uplo blas.Uplo) string {
   665  	switch uplo {
   666  	case blas.Lower:
   667  		return "Lower"
   668  	case blas.Upper:
   669  		return "Upper"
   670  	}
   671  	return "unknown uplo"
   672  }
   673  
   674  // sideString returns a string representation of blas.Side.
   675  func sideString(side blas.Side) string {
   676  	switch side {
   677  	case blas.Left:
   678  		return "Left"
   679  	case blas.Right:
   680  		return "Right"
   681  	}
   682  	return "unknown side"
   683  }
   684  
   685  // diagString returns a string representation of blas.Diag.
   686  func diagString(diag blas.Diag) string {
   687  	switch diag {
   688  	case blas.Unit:
   689  		return "Unit"
   690  	case blas.NonUnit:
   691  		return "NonUnit"
   692  	}
   693  	return "unknown diag"
   694  }
   695  
   696  // zSameLowerTri returns whether n×n matrices A and B are same under the diagonal.
   697  func zSameLowerTri(n int, a []complex128, lda int, b []complex128, ldb int) bool {
   698  	for i := 1; i < n; i++ {
   699  		for j := 0; j < i; j++ {
   700  			aij := a[i*lda+j]
   701  			bij := b[i*ldb+j]
   702  			if !sameComplex128(aij, bij) {
   703  				return false
   704  			}
   705  		}
   706  	}
   707  	return true
   708  }
   709  
   710  // zSameUpperTri returns whether n×n matrices A and B are same above the diagonal.
   711  func zSameUpperTri(n int, a []complex128, lda int, b []complex128, ldb int) bool {
   712  	for i := 0; i < n-1; i++ {
   713  		for j := i + 1; j < n; j++ {
   714  			aij := a[i*lda+j]
   715  			bij := b[i*ldb+j]
   716  			if !sameComplex128(aij, bij) {
   717  				return false
   718  			}
   719  		}
   720  	}
   721  	return true
   722  }