gorgonia.org/gorgonia@v0.9.17/blase/blas_test.go (about)

     1  package blase
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"gonum.org/v1/gonum/blas"
     9  )
    10  
    11  const EPSILON float64 = 1e-10
    12  
    13  func floatEquals(a, b float64) bool {
    14  	if (a-b) < EPSILON && (b-a) < EPSILON {
    15  		return true
    16  	}
    17  	return false
    18  }
    19  
    20  func floatsEqual(a, b []float64) bool {
    21  	if len(a) != len(b) {
    22  		return false
    23  	}
    24  
    25  	for i, v := range a {
    26  		if !floatEquals(v, b[i]) {
    27  			return false
    28  		}
    29  	}
    30  	return true
    31  }
    32  
    33  func randomFloat64(r, c int) []float64 {
    34  	retVal := make([]float64, r*c)
    35  	for i := range retVal {
    36  		retVal[i] = rand.Float64()
    37  	}
    38  	return retVal
    39  }
    40  
    41  func testDGEMM(t *testing.T, whichblas *context) (C, correct []float64) {
    42  	A := randomFloat64(2, 2)
    43  	B := randomFloat64(2, 3)
    44  
    45  	tA := blas.NoTrans
    46  	tB := blas.NoTrans
    47  	m := 2
    48  	n := 3
    49  	k := 2
    50  	alpha := 1.0
    51  	lda := 2
    52  	ldb := 3
    53  	beta := 0.0
    54  	ldc := 3
    55  
    56  	C = make([]float64, 2*3)
    57  	correct = []float64{
    58  		A[0]*B[0] + A[1]*B[3],
    59  		A[0]*B[1] + A[1]*B[4],
    60  		A[0]*B[2] + A[1]*B[5],
    61  
    62  		A[2]*B[0] + A[3]*B[3],
    63  		A[2]*B[1] + A[3]*B[4],
    64  		A[2]*B[2] + A[3]*B[5],
    65  	}
    66  
    67  	whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
    68  	return
    69  }
    70  
// TestQueue checks that Dgemm calls issued through Implementation() are not
// applied immediately. For the first workbufLen calls, the output buffer must
// still be all zeros right after the call. After an explicit DoWork, every
// call's output must match its hand-computed reference product.
func TestQueue(t *testing.T) {
	assert := assert.New(t)
	whichblas := Implementation()

	// Call DoWork in the background each time a value arrives on the
	// WorkAvailable channel.
	// NOTE(review): this goroutine exits only if that channel is closed. Its
	// closing is not visible here, so the goroutine may outlive the test.
	// Stopping it could block later users if Implementation() returns shared
	// state, so confirm before changing this.
	// NOTE(review): DoWork may also run here concurrently with the Dgemm and
	// DoWork calls below. Confirm the context is safe for that (go test -race).
	workAvailable := whichblas.WorkAvailable()
	go func() {
		for range workAvailable {
			whichblas.DoWork()
		}
	}()

	var corrects [][]float64
	var Cs [][]float64
	for i := 0; i < 4; i++ {
		C, correct := testDGEMM(t, whichblas)
		Cs = append(Cs, C)
		corrects = append(corrects, correct)

		// Presumably the first workbufLen calls are only buffered, so C must
		// still be zero-filled at this point. TODO confirm against the queue
		// implementation. Calls beyond workbufLen are not checked here.
		if i < workbufLen {
			assert.True(floatsEqual(make([]float64, 6), C))
		}
	}
	// Presumably executes whatever is still queued before results are checked.
	whichblas.DoWork()

	for i, C := range Cs {
		assert.True(floatsEqual(corrects[i], C))
	}
}