gorgonia.org/gorgonia@v0.9.17/blase/blas_test.go (about) 1 package blase 2 3 import ( 4 "math/rand" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "gonum.org/v1/gonum/blas" 9 ) 10 11 const EPSILON float64 = 1e-10 12 13 func floatEquals(a, b float64) bool { 14 if (a-b) < EPSILON && (b-a) < EPSILON { 15 return true 16 } 17 return false 18 } 19 20 func floatsEqual(a, b []float64) bool { 21 if len(a) != len(b) { 22 return false 23 } 24 25 for i, v := range a { 26 if !floatEquals(v, b[i]) { 27 return false 28 } 29 } 30 return true 31 } 32 33 func randomFloat64(r, c int) []float64 { 34 retVal := make([]float64, r*c) 35 for i := range retVal { 36 retVal[i] = rand.Float64() 37 } 38 return retVal 39 } 40 41 func testDGEMM(t *testing.T, whichblas *context) (C, correct []float64) { 42 A := randomFloat64(2, 2) 43 B := randomFloat64(2, 3) 44 45 tA := blas.NoTrans 46 tB := blas.NoTrans 47 m := 2 48 n := 3 49 k := 2 50 alpha := 1.0 51 lda := 2 52 ldb := 3 53 beta := 0.0 54 ldc := 3 55 56 C = make([]float64, 2*3) 57 correct = []float64{ 58 A[0]*B[0] + A[1]*B[3], 59 A[0]*B[1] + A[1]*B[4], 60 A[0]*B[2] + A[1]*B[5], 61 62 A[2]*B[0] + A[3]*B[3], 63 A[2]*B[1] + A[3]*B[4], 64 A[2]*B[2] + A[3]*B[5], 65 } 66 67 whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) 68 return 69 } 70 71 func TestQueue(t *testing.T) { 72 assert := assert.New(t) 73 whichblas := Implementation() 74 75 workAvailable := whichblas.WorkAvailable() 76 go func() { 77 for range workAvailable { 78 whichblas.DoWork() 79 } 80 }() 81 82 var corrects [][]float64 83 var Cs [][]float64 84 for i := 0; i < 4; i++ { 85 C, correct := testDGEMM(t, whichblas) 86 Cs = append(Cs, C) 87 corrects = append(corrects, correct) 88 89 if i < workbufLen { 90 assert.True(floatsEqual(make([]float64, 6), C)) 91 } 92 } 93 whichblas.DoWork() 94 95 for i, C := range Cs { 96 assert.True(floatsEqual(corrects[i], C)) 97 } 98 }