github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/testblas/dgemm.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testblas 6 7 import ( 8 "testing" 9 10 "github.com/jingcheng-WU/gonum/blas" 11 ) 12 13 type Dgemmer interface { 14 Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) 15 } 16 17 type DgemmCase struct { 18 m, n, k int 19 alpha, beta float64 20 a [][]float64 21 b [][]float64 22 c [][]float64 23 ans [][]float64 24 } 25 26 var DgemmCases = []DgemmCase{ 27 28 { 29 m: 4, 30 n: 3, 31 k: 2, 32 alpha: 2, 33 beta: 0.5, 34 a: [][]float64{ 35 {1, 2}, 36 {4, 5}, 37 {7, 8}, 38 {10, 11}, 39 }, 40 b: [][]float64{ 41 {1, 5, 6}, 42 {5, -8, 8}, 43 }, 44 c: [][]float64{ 45 {4, 8, -9}, 46 {12, 16, -8}, 47 {1, 5, 15}, 48 {-3, -4, 7}, 49 }, 50 ans: [][]float64{ 51 {24, -18, 39.5}, 52 {64, -32, 124}, 53 {94.5, -55.5, 219.5}, 54 {128.5, -78, 299.5}, 55 }, 56 }, 57 { 58 m: 4, 59 n: 2, 60 k: 3, 61 alpha: 2, 62 beta: 0.5, 63 a: [][]float64{ 64 {1, 2, 3}, 65 {4, 5, 6}, 66 {7, 8, 9}, 67 {10, 11, 12}, 68 }, 69 b: [][]float64{ 70 {1, 5}, 71 {5, -8}, 72 {6, 2}, 73 }, 74 c: [][]float64{ 75 {4, 8}, 76 {12, 16}, 77 {1, 5}, 78 {-3, -4}, 79 }, 80 ans: [][]float64{ 81 {60, -6}, 82 {136, -8}, 83 {202.5, -19.5}, 84 {272.5, -30}, 85 }, 86 }, 87 { 88 m: 3, 89 n: 2, 90 k: 4, 91 alpha: 2, 92 beta: 0.5, 93 a: [][]float64{ 94 {1, 2, 3, 4}, 95 {4, 5, 6, 7}, 96 {8, 9, 10, 11}, 97 }, 98 b: [][]float64{ 99 {1, 5}, 100 {5, -8}, 101 {6, 2}, 102 {8, 10}, 103 }, 104 c: [][]float64{ 105 {4, 8}, 106 {12, 16}, 107 {9, -10}, 108 }, 109 ans: [][]float64{ 110 {124, 74}, 111 {248, 132}, 112 {406.5, 191}, 113 }, 114 }, 115 { 116 m: 3, 117 n: 4, 118 k: 2, 119 alpha: 2, 120 beta: 0.5, 121 a: [][]float64{ 122 {1, 2}, 123 {4, 5}, 124 {8, 9}, 125 }, 126 b: [][]float64{ 127 {1, 5, 2, 1}, 128 {5, -8, 2, 1}, 129 }, 130 c: [][]float64{ 131 {4, 8, 2, 2}, 132 {12, 16, 8, 9}, 133 {9, -10, 10, 10}, 134 }, 135 ans: [][]float64{ 136 {24, -18, 13, 7}, 137 {64, -32, 40, 22.5}, 138 {110.5, -69, 73, 39}, 139 }, 140 }, 141 { 142 m: 2, 143 n: 4, 144 k: 3, 145 alpha: 2, 146 beta: 0.5, 147 a: [][]float64{ 148 {1, 2, 3}, 149 {4, 5, 6}, 150 }, 151 b: [][]float64{ 152 {1, 5, 8, 8}, 153 {5, -8, 9, 10}, 154 {6, 2, -3, 2}, 155 }, 156 c: [][]float64{ 157 {4, 8, 7, 8}, 158 {12, 16, -2, 6}, 159 }, 160 ans: [][]float64{ 161 {60, -6, 37.5, 72}, 162 {136, -8, 117, 191}, 163 }, 164 }, 165 { 166 m: 2, 167 n: 3, 168 k: 4, 169 alpha: 2, 170 beta: 0.5, 171 a: [][]float64{ 172 {1, 2, 3, 4}, 173 {4, 5, 6, 7}, 174 }, 175 b: [][]float64{ 176 {1, 5, 8}, 177 {5, -8, 9}, 178 {6, 2, -3}, 179 {8, 10, 2}, 180 }, 181 c: [][]float64{ 182 {4, 8, 1}, 183 {12, 16, 6}, 184 }, 185 ans: [][]float64{ 186 {124, 74, 50.5}, 187 {248, 132, 149}, 188 }, 189 }, 190 } 191 192 // assumes [][]float64 is actually a matrix 193 func transpose(a [][]float64) [][]float64 { 194 b := make([][]float64, len(a[0])) 195 for i := range b { 196 b[i] = make([]float64, len(a)) 197 for j := range b[i] { 198 b[i][j] = a[j][i] 199 } 200 } 201 return b 202 } 203 204 func TestDgemm(t *testing.T, blasser Dgemmer) { 205 for i, test := range DgemmCases { 206 // Test that it passes row major 207 dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans, 208 test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans) 209 // Try with A transposed 210 dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans, 211 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans) 212 // Try with B transposed 213 dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans, 214 test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans) 215 // Try with both transposed 216 dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans, 217 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans) 218 } 219 } 220 221 func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int, 222 alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) { 223 224 aFlat := flatten(a) 225 aCopy := flatten(a) 226 bFlat := flatten(b) 227 bCopy := flatten(b) 228 cFlat := flatten(c) 229 ansFlat := flatten(ans) 230 lda := len(a[0]) 231 ldb := len(b[0]) 232 ldc := len(c[0]) 233 234 // Compute the matrix multiplication 235 blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc) 236 237 if !dSliceEqual(aFlat, aCopy) { 238 t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name) 239 } 240 if !dSliceEqual(bFlat, bCopy) { 241 t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name) 242 } 243 244 if !dSliceTolEqual(ansFlat, cFlat) { 245 t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat) 246 } 247 // TODO: Need to add a sub-slice test where don't use up full matrix 248 }