gonum.org/v1/gonum@v0.14.0/blas/testblas/dgemm.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "testing" 9 10 "gonum.org/v1/gonum/blas" 11 ) 12 13 type Dgemmer interface { 14 Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) 15 } 16 17 type DgemmCase struct { 18 m, n, k int 19 alpha, beta float64 20 a [][]float64 21 b [][]float64 22 c [][]float64 23 ans [][]float64 24 } 25 26 var DgemmCases = []DgemmCase{ 27 28 { 29 m: 4, 30 n: 3, 31 k: 2, 32 alpha: 2, 33 beta: 0.5, 34 a: [][]float64{ 35 {1, 2}, 36 {4, 5}, 37 {7, 8}, 38 {10, 11}, 39 }, 40 b: [][]float64{ 41 {1, 5, 6}, 42 {5, -8, 8}, 43 }, 44 c: [][]float64{ 45 {4, 8, -9}, 46 {12, 16, -8}, 47 {1, 5, 15}, 48 {-3, -4, 7}, 49 }, 50 ans: [][]float64{ 51 {24, -18, 39.5}, 52 {64, -32, 124}, 53 {94.5, -55.5, 219.5}, 54 {128.5, -78, 299.5}, 55 }, 56 }, 57 { 58 m: 4, 59 n: 2, 60 k: 3, 61 alpha: 2, 62 beta: 0.5, 63 a: [][]float64{ 64 {1, 2, 3}, 65 {4, 5, 6}, 66 {7, 8, 9}, 67 {10, 11, 12}, 68 }, 69 b: [][]float64{ 70 {1, 5}, 71 {5, -8}, 72 {6, 2}, 73 }, 74 c: [][]float64{ 75 {4, 8}, 76 {12, 16}, 77 {1, 5}, 78 {-3, -4}, 79 }, 80 ans: [][]float64{ 81 {60, -6}, 82 {136, -8}, 83 {202.5, -19.5}, 84 {272.5, -30}, 85 }, 86 }, 87 { 88 m: 3, 89 n: 2, 90 k: 4, 91 alpha: 2, 92 beta: 0.5, 93 a: [][]float64{ 94 {1, 2, 3, 4}, 95 {4, 5, 6, 7}, 96 {8, 9, 10, 11}, 97 }, 98 b: [][]float64{ 99 {1, 5}, 100 {5, -8}, 101 {6, 2}, 102 {8, 10}, 103 }, 104 c: [][]float64{ 105 {4, 8}, 106 {12, 16}, 107 {9, -10}, 108 }, 109 ans: [][]float64{ 110 {124, 74}, 111 {248, 132}, 112 {406.5, 191}, 113 }, 114 }, 115 { 116 m: 3, 117 n: 4, 118 k: 2, 119 alpha: 2, 120 beta: 0.5, 121 a: [][]float64{ 122 {1, 2}, 123 {4, 5}, 124 {8, 9}, 125 }, 126 b: [][]float64{ 127 {1, 5, 2, 1}, 128 {5, -8, 2, 1}, 129 }, 130 c: [][]float64{ 131 {4, 8, 2, 2}, 132 {12, 16, 8, 9}, 133 {9, -10, 10, 10}, 134 }, 135 ans: [][]float64{ 136 {24, -18, 13, 7}, 137 {64, -32, 40, 22.5}, 138 {110.5, -69, 73, 39}, 139 }, 140 }, 141 { 142 m: 2, 143 n: 4, 144 k: 3, 145 alpha: 2, 146 beta: 0.5, 147 a: [][]float64{ 148 {1, 2, 3}, 149 {4, 5, 6}, 150 }, 151 b: [][]float64{ 152 {1, 5, 8, 8}, 153 {5, -8, 9, 10}, 154 {6, 2, -3, 2}, 155 }, 156 c: [][]float64{ 157 {4, 8, 7, 8}, 158 {12, 16, -2, 6}, 159 }, 160 ans: [][]float64{ 161 {60, -6, 37.5, 72}, 162 {136, -8, 117, 191}, 163 }, 164 }, 165 { 166 m: 2, 167 n: 3, 168 k: 4, 169 alpha: 2, 170 beta: 0.5, 171 a: [][]float64{ 172 {1, 2, 3, 4}, 173 {4, 5, 6, 7}, 174 }, 175 b: [][]float64{ 176 {1, 5, 8}, 177 {5, -8, 9}, 178 {6, 2, -3}, 179 {8, 10, 2}, 180 }, 181 c: [][]float64{ 182 {4, 8, 1}, 183 {12, 16, 6}, 184 }, 185 ans: [][]float64{ 186 {124, 74, 50.5}, 187 {248, 132, 149}, 188 }, 189 }, 190 { 191 m: 2, 192 n: 3, 193 k: 4, 194 alpha: 2, 195 beta: 0, 196 a: [][]float64{ 197 {1, 2, 3, 4}, 198 {4, 5, 6, 7}, 199 }, 200 b: [][]float64{ 201 {1, 5, 8}, 202 {5, -8, 9}, 203 {6, 2, -3}, 204 {8, 10, 2}, 205 }, 206 c: [][]float64{ 207 {4, 8, 1}, 208 {12, 16, 6}, 209 }, 210 ans: [][]float64{ 211 {122, 70, 50}, 212 {242, 124, 146}, 213 }, 214 }, 215 } 216 217 // assumes [][]float64 is actually a matrix 218 func transpose(a [][]float64) [][]float64 { 219 b := make([][]float64, len(a[0])) 220 for i := range b { 221 b[i] = make([]float64, len(a)) 222 for j := range b[i] { 223 b[i][j] = a[j][i] 224 } 225 } 226 return b 227 } 228 229 func TestDgemm(t *testing.T, blasser Dgemmer) { 230 for i, test := range DgemmCases { 231 // Test that it passes row major 232 dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans, 233 test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans) 234 // Try with A transposed 235 dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans, 236 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans) 237 // Try with B transposed 238 dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans, 239 test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans) 240 // Try with both transposed 241 dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans, 242 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans) 243 } 244 } 245 246 func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int, 247 alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) { 248 249 aFlat := flatten(a) 250 aCopy := flatten(a) 251 bFlat := flatten(b) 252 bCopy := flatten(b) 253 cFlat := flatten(c) 254 ansFlat := flatten(ans) 255 lda := len(a[0]) 256 ldb := len(b[0]) 257 ldc := len(c[0]) 258 259 // Compute the matrix multiplication 260 blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc) 261 262 if !dSliceEqual(aFlat, aCopy) { 263 t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name) 264 } 265 if !dSliceEqual(bFlat, bCopy) { 266 t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name) 267 } 268 269 if !dSliceTolEqual(ansFlat, cFlat) { 270 t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat) 271 } 272 // TODO: Need to add a sub-slice test where don't use up full matrix 273 }