gonum.org/v1/gonum@v0.14.0/blas/gonum/pardgemm_test.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/floats" 14 ) 15 16 func TestDgemmParallel(t *testing.T) { 17 rnd := rand.New(rand.NewSource(1)) 18 for i, test := range []struct { 19 m int 20 n int 21 k int 22 alpha float64 23 tA blas.Transpose 24 tB blas.Transpose 25 }{ 26 { 27 m: 3, 28 n: 4, 29 k: 2, 30 alpha: 2.5, 31 tA: blas.NoTrans, 32 tB: blas.NoTrans, 33 }, 34 { 35 m: blockSize*2 + 5, 36 n: 3, 37 k: 2, 38 alpha: 2.5, 39 tA: blas.NoTrans, 40 tB: blas.NoTrans, 41 }, 42 { 43 m: 3, 44 n: blockSize * 2, 45 k: 2, 46 alpha: 2.5, 47 tA: blas.NoTrans, 48 tB: blas.NoTrans, 49 }, 50 { 51 m: 2, 52 n: 3, 53 k: blockSize*3 - 2, 54 alpha: 2.5, 55 tA: blas.NoTrans, 56 tB: blas.NoTrans, 57 }, 58 { 59 m: blockSize * minParBlock, 60 n: 3, 61 k: 2, 62 alpha: 2.5, 63 tA: blas.NoTrans, 64 tB: blas.NoTrans, 65 }, 66 { 67 m: 3, 68 n: blockSize * minParBlock, 69 k: 2, 70 alpha: 2.5, 71 tA: blas.NoTrans, 72 tB: blas.NoTrans, 73 }, 74 { 75 m: 2, 76 n: 3, 77 k: blockSize * minParBlock, 78 alpha: 2.5, 79 tA: blas.NoTrans, 80 tB: blas.NoTrans, 81 }, 82 { 83 m: blockSize*minParBlock + 1, 84 n: blockSize * minParBlock, 85 k: 3, 86 alpha: 2.5, 87 tA: blas.NoTrans, 88 tB: blas.NoTrans, 89 }, 90 { 91 m: 3, 92 n: blockSize*minParBlock + 2, 93 k: blockSize * 3, 94 alpha: 2.5, 95 tA: blas.NoTrans, 96 tB: blas.NoTrans, 97 }, 98 { 99 m: blockSize * minParBlock, 100 n: 3, 101 k: blockSize * minParBlock, 102 alpha: 2.5, 103 tA: blas.NoTrans, 104 tB: blas.NoTrans, 105 }, 106 { 107 m: blockSize * minParBlock, 108 n: blockSize * minParBlock, 109 k: blockSize * 3, 110 alpha: 2.5, 111 tA: blas.NoTrans, 112 tB: blas.NoTrans, 113 }, 114 { 115 m: blockSize + blockSize/2, 116 n: blockSize + blockSize/2, 117 k: blockSize + blockSize/2, 118 alpha: 2.5, 119 tA: blas.NoTrans, 120 tB: blas.NoTrans, 121 }, 122 } { 123 testMatchParallelSerial(t, rnd, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha) 124 testMatchParallelSerial(t, rnd, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha) 125 testMatchParallelSerial(t, rnd, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha) 126 testMatchParallelSerial(t, rnd, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha) 127 } 128 } 129 130 func testMatchParallelSerial(t *testing.T, rnd *rand.Rand, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) { 131 var ( 132 rowA, colA int 133 rowB, colB int 134 ) 135 if tA == blas.NoTrans { 136 rowA = m 137 colA = k 138 } else { 139 rowA = k 140 colA = m 141 } 142 if tB == blas.NoTrans { 143 rowB = k 144 colB = n 145 } else { 146 rowB = n 147 colB = k 148 } 149 150 lda := colA 151 a := randmat(rowA, colA, lda, rnd) 152 aCopy := make([]float64, len(a)) 153 copy(aCopy, a) 154 155 ldb := colB 156 b := randmat(rowB, colB, ldb, rnd) 157 bCopy := make([]float64, len(b)) 158 copy(bCopy, b) 159 160 ldc := n 161 c := randmat(m, n, ldc, rnd) 162 want := make([]float64, len(c)) 163 copy(want, c) 164 165 dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a, lda, b, ldb, want, ldc, alpha) 166 dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 167 168 if !floats.Equal(a, aCopy) { 169 t.Errorf("Case %v: a changed during call to dgemmParallel", i) 170 } 171 if !floats.Equal(b, bCopy) { 172 t.Errorf("Case %v: b changed during call to dgemmParallel", i) 173 } 174 if !floats.EqualApprox(c, want, 1e-12) { 175 t.Errorf("Case %v: answer not equal parallel and serial", i) 176 } 177 } 178 179 func randmat(r, c, stride int, rnd *rand.Rand) []float64 { 180 data := make([]float64, r*stride+c) 181 for i := range data { 182 data[i] = rnd.NormFloat64() 183 } 184 return data 185 }