gonum.org/v1/gonum@v0.14.0/blas/gonum/sgemm.go (about) 1 // Code generated by "go generate gonum.org/v1/gonum/blas/gonum”; DO NOT EDIT. 2 3 // Copyright ©2014 The Gonum Authors. All rights reserved. 4 // Use of this source code is governed by a BSD-style 5 // license that can be found in the LICENSE file. 6 7 package gonum 8 9 import ( 10 "runtime" 11 "sync" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/internal/asm/f32" 15 ) 16 17 // Sgemm performs one of the matrix-matrix operations 18 // 19 // C = alpha * A * B + beta * C 20 // C = alpha * Aᵀ * B + beta * C 21 // C = alpha * A * Bᵀ + beta * C 22 // C = alpha * Aᵀ * Bᵀ + beta * C 23 // 24 // where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is 25 // an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or 26 // B are transposed. 27 // 28 // Float32 implementations are autogenerated and not directly tested. 29 func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) { 30 switch tA { 31 default: 32 panic(badTranspose) 33 case blas.NoTrans, blas.Trans, blas.ConjTrans: 34 } 35 switch tB { 36 default: 37 panic(badTranspose) 38 case blas.NoTrans, blas.Trans, blas.ConjTrans: 39 } 40 if m < 0 { 41 panic(mLT0) 42 } 43 if n < 0 { 44 panic(nLT0) 45 } 46 if k < 0 { 47 panic(kLT0) 48 } 49 aTrans := tA == blas.Trans || tA == blas.ConjTrans 50 if aTrans { 51 if lda < max(1, m) { 52 panic(badLdA) 53 } 54 } else { 55 if lda < max(1, k) { 56 panic(badLdA) 57 } 58 } 59 bTrans := tB == blas.Trans || tB == blas.ConjTrans 60 if bTrans { 61 if ldb < max(1, k) { 62 panic(badLdB) 63 } 64 } else { 65 if ldb < max(1, n) { 66 panic(badLdB) 67 } 68 } 69 if ldc < max(1, n) { 70 panic(badLdC) 71 } 72 73 // Quick return if possible. 74 if m == 0 || n == 0 { 75 return 76 } 77 78 // For zero matrix size the following slice length checks are trivially satisfied. 79 if aTrans { 80 if len(a) < (k-1)*lda+m { 81 panic(shortA) 82 } 83 } else { 84 if len(a) < (m-1)*lda+k { 85 panic(shortA) 86 } 87 } 88 if bTrans { 89 if len(b) < (n-1)*ldb+k { 90 panic(shortB) 91 } 92 } else { 93 if len(b) < (k-1)*ldb+n { 94 panic(shortB) 95 } 96 } 97 if len(c) < (m-1)*ldc+n { 98 panic(shortC) 99 } 100 101 // Quick return if possible. 102 if (alpha == 0 || k == 0) && beta == 1 { 103 return 104 } 105 106 // scale c 107 if beta != 1 { 108 if beta == 0 { 109 for i := 0; i < m; i++ { 110 ctmp := c[i*ldc : i*ldc+n] 111 for j := range ctmp { 112 ctmp[j] = 0 113 } 114 } 115 } else { 116 for i := 0; i < m; i++ { 117 ctmp := c[i*ldc : i*ldc+n] 118 for j := range ctmp { 119 ctmp[j] *= beta 120 } 121 } 122 } 123 } 124 125 sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 126 } 127 128 func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 129 // dgemmParallel computes a parallel matrix multiplication by partitioning 130 // a and b into sub-blocks, and updating c with the multiplication of the sub-block 131 // In all cases, 132 // A = [ A_11 A_12 ... A_1j 133 // A_21 A_22 ... A_2j 134 // ... 135 // A_i1 A_i2 ... A_ij] 136 // 137 // and same for B. All of the submatrix sizes are blockSize×blockSize except 138 // at the edges. 139 // 140 // In all cases, there is one dimension for each matrix along which 141 // C must be updated sequentially. 142 // Cij = \sum_k Aik Bki, (A * B) 143 // Cij = \sum_k Aki Bkj, (Aᵀ * B) 144 // Cij = \sum_k Aik Bjk, (A * Bᵀ) 145 // Cij = \sum_k Aki Bjk, (Aᵀ * Bᵀ) 146 // 147 // This code computes one {i, j} block sequentially along the k dimension, 148 // and computes all of the {i, j} blocks concurrently. This 149 // partitioning allows Cij to be updated in-place without race-conditions. 150 // Instead of launching a goroutine for each possible concurrent computation, 151 // a number of worker goroutines are created and channels are used to pass 152 // available and completed cases. 153 // 154 // http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix 155 // multiplies, though this code does not copy matrices to attempt to eliminate 156 // cache misses. 157 158 maxKLen := k 159 parBlocks := blocks(m, blockSize) * blocks(n, blockSize) 160 if parBlocks < minParBlock { 161 // The matrix multiplication is small in the dimensions where it can be 162 // computed concurrently. Just do it in serial. 163 sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 164 return 165 } 166 167 // workerLimit acts a number of maximum concurrent workers, 168 // with the limit set to the number of procs available. 169 workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0)) 170 171 // wg is used to wait for all 172 var wg sync.WaitGroup 173 wg.Add(parBlocks) 174 defer wg.Wait() 175 176 for i := 0; i < m; i += blockSize { 177 for j := 0; j < n; j += blockSize { 178 workerLimit <- struct{}{} 179 go func(i, j int) { 180 defer func() { 181 wg.Done() 182 <-workerLimit 183 }() 184 185 leni := blockSize 186 if i+leni > m { 187 leni = m - i 188 } 189 lenj := blockSize 190 if j+lenj > n { 191 lenj = n - j 192 } 193 194 cSub := sliceView32(c, ldc, i, j, leni, lenj) 195 196 // Compute A_ik B_kj for all k 197 for k := 0; k < maxKLen; k += blockSize { 198 lenk := blockSize 199 if k+lenk > maxKLen { 200 lenk = maxKLen - k 201 } 202 var aSub, bSub []float32 203 if aTrans { 204 aSub = sliceView32(a, lda, k, i, lenk, leni) 205 } else { 206 aSub = sliceView32(a, lda, i, k, leni, lenk) 207 } 208 if bTrans { 209 bSub = sliceView32(b, ldb, j, k, lenj, lenk) 210 } else { 211 bSub = sliceView32(b, ldb, k, j, lenk, lenj) 212 } 213 sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) 214 } 215 }(i, j) 216 } 217 } 218 } 219 220 // sgemmSerial is serial matrix multiply 221 func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 222 switch { 223 case !aTrans && !bTrans: 224 sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 225 return 226 case aTrans && !bTrans: 227 sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 228 return 229 case !aTrans && bTrans: 230 sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 231 return 232 case aTrans && bTrans: 233 sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 234 return 235 default: 236 panic("unreachable") 237 } 238 } 239 240 // sgemmSerial where neither a nor b are transposed 241 func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 242 // This style is used instead of the literal [i*stride +j]) is used because 243 // approximately 5 times faster as of go 1.3. 244 for i := 0; i < m; i++ { 245 ctmp := c[i*ldc : i*ldc+n] 246 for l, v := range a[i*lda : i*lda+k] { 247 tmp := alpha * v 248 if tmp != 0 { 249 f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp) 250 } 251 } 252 } 253 } 254 255 // sgemmSerial where neither a is transposed and b is not 256 func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 257 // This style is used instead of the literal [i*stride +j]) is used because 258 // approximately 5 times faster as of go 1.3. 259 for l := 0; l < k; l++ { 260 btmp := b[l*ldb : l*ldb+n] 261 for i, v := range a[l*lda : l*lda+m] { 262 tmp := alpha * v 263 if tmp != 0 { 264 ctmp := c[i*ldc : i*ldc+n] 265 f32.AxpyUnitary(tmp, btmp, ctmp) 266 } 267 } 268 } 269 } 270 271 // sgemmSerial where neither a is not transposed and b is 272 func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 273 // This style is used instead of the literal [i*stride +j]) is used because 274 // approximately 5 times faster as of go 1.3. 275 for i := 0; i < m; i++ { 276 atmp := a[i*lda : i*lda+k] 277 ctmp := c[i*ldc : i*ldc+n] 278 for j := 0; j < n; j++ { 279 ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k]) 280 } 281 } 282 } 283 284 // sgemmSerial where both are transposed 285 func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) { 286 // This style is used instead of the literal [i*stride +j]) is used because 287 // approximately 5 times faster as of go 1.3. 288 for l := 0; l < k; l++ { 289 for i, v := range a[l*lda : l*lda+m] { 290 tmp := alpha * v 291 if tmp != 0 { 292 ctmp := c[i*ldc : i*ldc+n] 293 f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0) 294 } 295 } 296 } 297 } 298 299 func sliceView32(a []float32, lda, i, j, r, c int) []float32 { 300 return a[i*lda+j : (i+r-1)*lda+j+c] 301 }