github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/sgemm.go (about) 1 // Code generated by "go generate github.com/jingcheng-WU/gonum/blas/gonum”; DO NOT EDIT. 2 3 // Copyright ©2014 The Gonum Authors. All rights reserved. 4 // Use of this source code is governed by a BSD-style 5 // license that can be found in the LICENSE file. 6 7 package gonum 8 9 import ( 10 "runtime" 11 "sync" 12 13 "github.com/jingcheng-WU/gonum/blas" 14 "github.com/jingcheng-WU/gonum/internal/asm/f32" 15 ) 16 17 // Sgemm performs one of the matrix-matrix operations 18 // C = alpha * A * B + beta * C 19 // C = alpha * Aᵀ * B + beta * C 20 // C = alpha * A * Bᵀ + beta * C 21 // C = alpha * Aᵀ * Bᵀ + beta * C 22 // where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is 23 // an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or 24 // B are transposed. 25 // 26 // Float32 implementations are autogenerated and not directly tested. 27 func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { 28 switch tA { 29 default: 30 panic(badTranspose) 31 case blas.NoTrans, blas.Trans, blas.ConjTrans: 32 } 33 switch tB { 34 default: 35 panic(badTranspose) 36 case blas.NoTrans, blas.Trans, blas.ConjTrans: 37 } 38 if m < 0 { 39 panic(mLT0) 40 } 41 if n < 0 { 42 panic(nLT0) 43 } 44 if k < 0 { 45 panic(kLT0) 46 } 47 aTrans := tA == blas.Trans || tA == blas.ConjTrans 48 if aTrans { 49 if lda < max(1, m) { 50 panic(badLdA) 51 } 52 } else { 53 if lda < max(1, k) { 54 panic(badLdA) 55 } 56 } 57 bTrans := tB == blas.Trans || tB == blas.ConjTrans 58 if bTrans { 59 if ldb < max(1, k) { 60 panic(badLdB) 61 } 62 } else { 63 if ldb < max(1, n) { 64 panic(badLdB) 65 } 66 } 67 if ldc < max(1, n) { 68 panic(badLdC) 69 } 70 71 // Quick return if possible. 72 if m == 0 || n == 0 { 73 return 74 } 75 76 // For zero matrix size the following slice length checks are trivially satisfied. 77 if aTrans { 78 if len(a) < (k-1)*lda+m { 79 panic(shortA) 80 } 81 } else { 82 if len(a) < (m-1)*lda+k { 83 panic(shortA) 84 } 85 } 86 if bTrans { 87 if len(b) < (n-1)*ldb+k { 88 panic(shortB) 89 } 90 } else { 91 if len(b) < (k-1)*ldb+n { 92 panic(shortB) 93 } 94 } 95 if len(c) < (m-1)*ldc+n { 96 panic(shortC) 97 } 98 99 // Quick return if possible. 100 if (alpha == 0 || k == 0) && beta == 1 { 101 return 102 } 103 104 // scale c 105 if beta != 1 { 106 if beta == 0 { 107 for i := 0; i < m; i++ { 108 ctmp := c[i*ldc : i*ldc+n] 109 for j := range ctmp { 110 ctmp[j] = 0 111 } 112 } 113 } else { 114 for i := 0; i < m; i++ { 115 ctmp := c[i*ldc : i*ldc+n] 116 for j := range ctmp { 117 ctmp[j] *= beta 118 } 119 } 120 } 121 } 122 123 sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 124 } 125 126 func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 127 // dgemmParallel computes a parallel matrix multiplication by partitioning 128 // a and b into sub-blocks, and updating c with the multiplication of the sub-block 129 // In all cases, 130 // A = [ A_11 A_12 ... A_1j 131 // A_21 A_22 ... A_2j 132 // ... 133 // A_i1 A_i2 ... A_ij] 134 // 135 // and same for B. All of the submatrix sizes are blockSize×blockSize except 136 // at the edges. 137 // 138 // In all cases, there is one dimension for each matrix along which 139 // C must be updated sequentially. 140 // Cij = \sum_k Aik Bki, (A * B) 141 // Cij = \sum_k Aki Bkj, (Aᵀ * B) 142 // Cij = \sum_k Aik Bjk, (A * Bᵀ) 143 // Cij = \sum_k Aki Bjk, (Aᵀ * Bᵀ) 144 // 145 // This code computes one {i, j} block sequentially along the k dimension, 146 // and computes all of the {i, j} blocks concurrently. This 147 // partitioning allows Cij to be updated in-place without race-conditions. 148 // Instead of launching a goroutine for each possible concurrent computation, 149 // a number of worker goroutines are created and channels are used to pass 150 // available and completed cases. 151 // 152 // http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix 153 // multiplies, though this code does not copy matrices to attempt to eliminate 154 // cache misses. 155 156 maxKLen := k 157 parBlocks := blocks(m, blockSize) * blocks(n, blockSize) 158 if parBlocks < minParBlock { 159 // The matrix multiplication is small in the dimensions where it can be 160 // computed concurrently. Just do it in serial. 161 sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 162 return 163 } 164 165 // workerLimit acts a number of maximum concurrent workers, 166 // with the limit set to the number of procs available. 167 workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) 168 169 // wg is used to wait for all 170 var wg sync.WaitGroup 171 wg.Add(parBlocks) 172 defer wg.Wait() 173 174 for i := 0; i < m; i += blockSize { 175 for j := 0; j < n; j += blockSize { 176 workerLimit <- struct{}{} 177 go func(i, j int) { 178 defer func() { 179 wg.Done() 180 <-workerLimit 181 }() 182 183 leni := blockSize 184 if i+leni > m { 185 leni = m - i 186 } 187 lenj := blockSize 188 if j+lenj > n { 189 lenj = n - j 190 } 191 192 cSub := sliceView32(c, ldc, i, j, leni, lenj) 193 194 // Compute A_ik B_kj for all k 195 for k := 0; k < maxKLen; k += blockSize { 196 lenk := blockSize 197 if k+lenk > maxKLen { 198 lenk = maxKLen - k 199 } 200 var aSub, bSub []float32 201 if aTrans { 202 aSub = sliceView32(a, lda, k, i, lenk, leni) 203 } else { 204 aSub = sliceView32(a, lda, i, k, leni, lenk) 205 } 206 if bTrans { 207 bSub = sliceView32(b, ldb, j, k, lenj, lenk) 208 } else { 209 bSub = sliceView32(b, ldb, k, j, lenk, lenj) 210 } 211 sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) 212 } 213 }(i, j) 214 } 215 } 216 } 217 218 // sgemmSerial is serial matrix multiply 219 func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 220 switch { 221 case !aTrans && !bTrans: 222 sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 223 return 224 case aTrans && !bTrans: 225 sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 226 return 227 case !aTrans && bTrans: 228 sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 229 return 230 case aTrans && bTrans: 231 sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 232 return 233 default: 234 panic("unreachable") 235 } 236 } 237 238 // sgemmSerial where neither a nor b are transposed 239 func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 240 // This style is used instead of the literal [i*stride +j]) is used because 241 // approximately 5 times faster as of go 1.3. 242 for i := 0; i < m; i++ { 243 ctmp := c[i*ldc : i*ldc+n] 244 for l, v := range a[i*lda : i*lda+k] { 245 tmp := alpha * v 246 if tmp != 0 { 247 f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp) 248 } 249 } 250 } 251 } 252 253 // sgemmSerial where neither a is transposed and b is not 254 func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 255 // This style is used instead of the literal [i*stride +j]) is used because 256 // approximately 5 times faster as of go 1.3. 257 for l := 0; l < k; l++ { 258 btmp := b[l*ldb : l*ldb+n] 259 for i, v := range a[l*lda : l*lda+m] { 260 tmp := alpha * v 261 if tmp != 0 { 262 ctmp := c[i*ldc : i*ldc+n] 263 f32.AxpyUnitary(tmp, btmp, ctmp) 264 } 265 } 266 } 267 } 268 269 // sgemmSerial where neither a is not transposed and b is 270 func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 271 // This style is used instead of the literal [i*stride +j]) is used because 272 // approximately 5 times faster as of go 1.3. 273 for i := 0; i < m; i++ { 274 atmp := a[i*lda : i*lda+k] 275 ctmp := c[i*ldc : i*ldc+n] 276 for j := 0; j < n; j++ { 277 ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k]) 278 } 279 } 280 } 281 282 // sgemmSerial where both are transposed 283 func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 284 // This style is used instead of the literal [i*stride +j]) is used because 285 // approximately 5 times faster as of go 1.3. 286 for l := 0; l < k; l++ { 287 for i, v := range a[l*lda : l*lda+m] { 288 tmp := alpha * v 289 if tmp != 0 { 290 ctmp := c[i*ldc : i*ldc+n] 291 f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0) 292 } 293 } 294 } 295 } 296 297 func sliceView32(a []float32, lda, i, j, r, c int) []float32 { 298 return a[i*lda+j : (i+r-1)*lda+j+c] 299 }