gonum.org/v1/gonum@v0.14.0/blas/gonum/dgemm.go

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"runtime"
	"sync"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/internal/asm/f64"
)

// Dgemm performs one of the matrix-matrix operations
//
//	C = alpha * A * B + beta * C
//	C = alpha * Aᵀ * B + beta * C
//	C = alpha * A * Bᵀ + beta * C
//	C = alpha * Aᵀ * Bᵀ + beta * C
//
// where A is an m×k or k×m dense matrix, B is a k×n or n×k dense matrix, C is
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		if lda < max(1, m) {
			panic(badLdA)
		}
	} else {
		if lda < max(1, k) {
			panic(badLdA)
		}
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		if ldb < max(1, k) {
			panic(badLdB)
		}
	} else {
		if ldb < max(1, n) {
			panic(badLdB)
		}
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if aTrans {
		if len(a) < (k-1)*lda+m {
			panic(shortA)
		}
	} else {
		if len(a) < (m-1)*lda+k {
			panic(shortA)
		}
	}
	if bTrans {
		if len(b) < (n-1)*ldb+k {
			panic(shortB)
		}
	} else {
		if len(b) < (k-1)*ldb+n {
			panic(shortB)
		}
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// Scale c.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
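
// The following example is an editorial addition, not part of the original
// file: a minimal sketch of calling Dgemm on row-major data, assuming the
// zero value of Implementation is ready to use (as elsewhere in this package).
// It computes C = 1*A*B + 0*C with A 2×3 (lda=3), B 3×2 (ldb=2), and C 2×2
// (ldc=2).
func dgemmUsageSketch() []float64 {
	var impl Implementation
	a := []float64{
		1, 2, 3,
		4, 5, 6,
	}
	b := []float64{
		7, 8,
		9, 10,
		11, 12,
	}
	c := make([]float64, 4) // 2×2 result; beta = 0, so the initial contents are ignored.
	impl.Dgemm(blas.NoTrans, blas.NoTrans, 2, 2, 3, 1, a, 3, b, 2, 0, c, 2)
	// c is now {58, 64, 139, 154}.
	return c
}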
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// dgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the
	// sub-blocks. In all cases,
	//
	//	A = [ A_11 A_12 ... A_1j
	//	      A_21 A_22 ... A_2j
	//	       ...
	//	      A_i1 A_i2 ... A_ij ]
	//
	// and similarly for B. All of the submatrix sizes are blockSize×blockSize
	// except at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	//
	//	Cij = \sum_k Aik Bkj,	(A * B)
	//	Cij = \sum_k Aki Bkj,	(Aᵀ * B)
	//	Cij = \sum_k Aik Bjk,	(A * Bᵀ)
	//	Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This partitioning
	// allows Cij to be updated in-place without race conditions. Instead of
	// launching every block computation at once, a buffered channel is used as
	// a semaphore to bound the number of concurrently running goroutines.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// workerLimit acts as a limit on the number of concurrent workers,
	// with the limit set to the number of procs available.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	// wg is used to wait for all workers to finish.
	var wg sync.WaitGroup
	wg.Add(parBlocks)
	defer wg.Wait()

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			workerLimit <- struct{}{}
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit
				}()

				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

				cSub := sliceView64(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k.
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					var aSub, bSub []float64
					if aTrans {
						aSub = sliceView64(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView64(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
					}
					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}(i, j)
		}
	}
}
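
// The following sketch is an editorial addition, not part of the original
// file. It isolates the concurrency pattern used by dgemmParallel above: a
// buffered channel acts as a counting semaphore that bounds the number of
// in-flight goroutines to GOMAXPROCS, while a WaitGroup blocks until every
// submitted task has finished.
func boundedParallelSketch(tasks []func()) {
	limit := make(chan struct{}, runtime.GOMAXPROCS(0))
	var wg sync.WaitGroup
	wg.Add(len(tasks))
	for _, task := range tasks {
		limit <- struct{}{} // Acquire a slot; blocks while all slots are busy.
		go func(task func()) {
			defer func() {
				wg.Done()
				<-limit // Release the slot for the next task.
			}()
			task()
		}(task)
	}
	wg.Wait()
}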
// dgemmSerial is the serial (single-goroutine) matrix multiply, dispatching
// on the transposition of a and b.
func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	switch {
	case !aTrans && !bTrans:
		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && !bTrans:
		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case !aTrans && bTrans:
		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && bTrans:
		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	default:
		panic("unreachable")
	}
}

// dgemmSerialNotNot is dgemmSerial where neither a nor b is transposed.
func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of the literal [i*stride+j] because it is
	// approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
		for l, v := range a[i*lda : i*lda+k] {
			tmp := alpha * v
			if tmp != 0 {
				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
			}
		}
	}
}

// dgemmSerialTransNot is dgemmSerial where a is transposed and b is not.
func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of the literal [i*stride+j] because it is
	// approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		btmp := b[l*ldb : l*ldb+n]
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f64.AxpyUnitary(tmp, btmp, ctmp)
			}
		}
	}
}

// dgemmSerialNotTrans is dgemmSerial where a is not transposed and b is.
func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of the literal [i*stride+j] because it is
	// approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		atmp := a[i*lda : i*lda+k]
		ctmp := c[i*ldc : i*ldc+n]
		for j := 0; j < n; j++ {
			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
		}
	}
}

// dgemmSerialTransTrans is dgemmSerial where both a and b are transposed.
func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of the literal [i*stride+j] because it is
	// approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
			}
		}
	}
}

// sliceView64 returns a view into the r×c submatrix of a (with stride lda)
// whose top-left element is at row i, column j.
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	return a[i*lda+j : (i+r-1)*lda+j+c]
}
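
// The following reference kernel is an editorial addition, not part of the
// original file. It is the plain triple loop that dgemmSerialNotNot above is
// equivalent to; the axpy formulation used there instead traverses rows of b
// and c so that the innermost loop runs over contiguous memory.
func dgemmNaiveSketch(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			var sum float64
			for l := 0; l < k; l++ {
				sum += a[i*lda+l] * b[l*ldb+j]
			}
			c[i*ldc+j] += alpha * sum
		}
	}
}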