github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/dgemm.go

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"runtime"
	"sync"

	"github.com/jingcheng-WU/gonum/blas"
	"github.com/jingcheng-WU/gonum/internal/asm/f64"
)

// Dgemm performs one of the matrix-matrix operations
//  C = alpha * A * B + beta * C
//  C = alpha * Aᵀ * B + beta * C
//  C = alpha * A * Bᵀ + beta * C
//  C = alpha * Aᵀ * Bᵀ + beta * C
// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		if lda < max(1, m) {
			panic(badLdA)
		}
	} else {
		if lda < max(1, k) {
			panic(badLdA)
		}
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		if ldb < max(1, k) {
			panic(badLdB)
		}
	} else {
		if ldb < max(1, n) {
			panic(badLdB)
		}
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if aTrans {
		if len(a) < (k-1)*lda+m {
			panic(shortA)
		}
	} else {
		if len(a) < (m-1)*lda+k {
			panic(shortA)
		}
	}
	if bTrans {
		if len(b) < (n-1)*ldb+k {
			panic(shortB)
		}
	} else {
		if len(b) < (k-1)*ldb+n {
			panic(shortB)
		}
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// Scale c by beta.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}

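// A minimal usage sketch (assuming row-major storage with the conventional
// leading dimensions lda = k, ldb = n, ldc = n; values are illustrative only):
//
//	var impl Implementation
//	a := []float64{1, 2, 3, 4, 5, 6}    // 2×3
//	b := []float64{7, 8, 9, 10, 11, 12} // 3×2
//	c := make([]float64, 4)             // 2×2
//	impl.Dgemm(blas.NoTrans, blas.NoTrans, 2, 2, 3, 1, a, 3, b, 2, 0, c, 2)
//	// c == []float64{58, 64, 139, 154}
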
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// dgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the sub-blocks.
	// In all cases,
	//  A = [ A_11 A_12 ... A_1j
	//        A_21 A_22 ... A_2j
	//         ...
	//        A_i1 A_i2 ... A_ij ]
	//
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	//  Cij = \sum_k Aik Bkj, (A * B)
	//  Cij = \sum_k Aki Bkj, (Aᵀ * B)
	//  Cij = \sum_k Aik Bjk, (A * Bᵀ)
	//  Cij = \sum_k Aki Bjk, (Aᵀ * Bᵀ)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This partitioning
	// allows Cij to be updated in-place without race conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// workerLimit acts as a limit on the number of concurrent workers,
	// with the limit set to the number of procs available.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	// wg is used to wait for all of the blocks to finish.
	var wg sync.WaitGroup
	wg.Add(parBlocks)
	defer wg.Wait()

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			workerLimit <- struct{}{}
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit
				}()

				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

				cSub := sliceView64(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k.
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					var aSub, bSub []float64
					if aTrans {
						aSub = sliceView64(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView64(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
					}
					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}(i, j)
		}
	}
}

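// The bounded-worker pattern used in dgemmParallel above reduces to a small,
// generic sketch: a buffered channel acts as a semaphore capping in-flight
// goroutines at GOMAXPROCS, and a WaitGroup blocks until every block has been
// processed. The names work, block and process below are placeholders for
// illustration only.
//
//	limit := make(chan struct{}, runtime.GOMAXPROCS(0))
//	var wg sync.WaitGroup
//	wg.Add(len(work))
//	for _, blk := range work {
//		limit <- struct{}{} // blocks once GOMAXPROCS workers are busy
//		go func(blk block) {
//			defer func() {
//				wg.Done()
//				<-limit
//			}()
//			process(blk)
//		}(blk)
//	}
//	wg.Wait()
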
// dgemmSerial is the serial matrix multiply dispatcher.
func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	switch {
	case !aTrans && !bTrans:
		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && !bTrans:
		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case !aTrans && bTrans:
		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && bTrans:
		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	default:
		panic("unreachable")
	}
}

// dgemmSerialNotNot handles the case where neither a nor b is transposed.
func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of indexing with the literal a[i*lda+j] because
	// it is approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
		for l, v := range a[i*lda : i*lda+k] {
			tmp := alpha * v
			if tmp != 0 {
				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
			}
		}
	}
}

// dgemmSerialTransNot handles the case where a is transposed and b is not.
func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of indexing with the literal a[i*lda+j] because
	// it is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		btmp := b[l*ldb : l*ldb+n]
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f64.AxpyUnitary(tmp, btmp, ctmp)
			}
		}
	}
}

// dgemmSerialNotTrans handles the case where a is not transposed and b is.
func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of indexing with the literal a[i*lda+j] because
	// it is approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		atmp := a[i*lda : i*lda+k]
		ctmp := c[i*ldc : i*ldc+n]
		for j := 0; j < n; j++ {
			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
		}
	}
}

// dgemmSerialTransTrans handles the case where both a and b are transposed.
func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// This style is used instead of indexing with the literal a[i*lda+j] because
	// it is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
			}
		}
	}
}

// sliceView64 returns a view of the r×c submatrix of a starting at row i,
// column j, where a has row stride lda. The returned slice shares memory with
// a and retains the stride lda.
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	return a[i*lda+j : (i+r-1)*lda+j+c]
}
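
// As a worked example of the indexing in sliceView64 (values are illustrative
// only): for a 4×5 row-major matrix with lda = 5,
//
//	v := sliceView64(a, 5, 1, 2, 2, 3) // the 2×3 block starting at row 1, col 2
//	// v == a[7:15]; block row 0 is v[0:3] (a[7:10]), block row 1 is v[5:8]
//	// (a[12:15]), so consecutive block rows remain lda elements apart.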