github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/dgemm.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"runtime"
     9  	"sync"
    10  
    11  	"github.com/jingcheng-WU/gonum/blas"
    12  	"github.com/jingcheng-WU/gonum/internal/asm/f64"
    13  )
    14  
    15  // Dgemm performs one of the matrix-matrix operations
    16  //  C = alpha * A * B + beta * C
    17  //  C = alpha * Aᵀ * B + beta * C
    18  //  C = alpha * A * Bᵀ + beta * C
    19  //  C = alpha * Aᵀ * Bᵀ + beta * C
    20  // where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
    21  // an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
    22  // B are transposed.
    23  func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
    24  	switch tA {
    25  	default:
    26  		panic(badTranspose)
    27  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    28  	}
    29  	switch tB {
    30  	default:
    31  		panic(badTranspose)
    32  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    33  	}
    34  	if m < 0 {
    35  		panic(mLT0)
    36  	}
    37  	if n < 0 {
    38  		panic(nLT0)
    39  	}
    40  	if k < 0 {
    41  		panic(kLT0)
    42  	}
    43  	aTrans := tA == blas.Trans || tA == blas.ConjTrans
    44  	if aTrans {
    45  		if lda < max(1, m) {
    46  			panic(badLdA)
    47  		}
    48  	} else {
    49  		if lda < max(1, k) {
    50  			panic(badLdA)
    51  		}
    52  	}
    53  	bTrans := tB == blas.Trans || tB == blas.ConjTrans
    54  	if bTrans {
    55  		if ldb < max(1, k) {
    56  			panic(badLdB)
    57  		}
    58  	} else {
    59  		if ldb < max(1, n) {
    60  			panic(badLdB)
    61  		}
    62  	}
    63  	if ldc < max(1, n) {
    64  		panic(badLdC)
    65  	}
    66  
    67  	// Quick return if possible.
    68  	if m == 0 || n == 0 {
    69  		return
    70  	}
    71  
    72  	// For zero matrix size the following slice length checks are trivially satisfied.
    73  	if aTrans {
    74  		if len(a) < (k-1)*lda+m {
    75  			panic(shortA)
    76  		}
    77  	} else {
    78  		if len(a) < (m-1)*lda+k {
    79  			panic(shortA)
    80  		}
    81  	}
    82  	if bTrans {
    83  		if len(b) < (n-1)*ldb+k {
    84  			panic(shortB)
    85  		}
    86  	} else {
    87  		if len(b) < (k-1)*ldb+n {
    88  			panic(shortB)
    89  		}
    90  	}
    91  	if len(c) < (m-1)*ldc+n {
    92  		panic(shortC)
    93  	}
    94  
    95  	// Quick return if possible.
    96  	if (alpha == 0 || k == 0) && beta == 1 {
    97  		return
    98  	}
    99  
   100  	// scale c
   101  	if beta != 1 {
   102  		if beta == 0 {
   103  			for i := 0; i < m; i++ {
   104  				ctmp := c[i*ldc : i*ldc+n]
   105  				for j := range ctmp {
   106  					ctmp[j] = 0
   107  				}
   108  			}
   109  		} else {
   110  			for i := 0; i < m; i++ {
   111  				ctmp := c[i*ldc : i*ldc+n]
   112  				for j := range ctmp {
   113  					ctmp[j] *= beta
   114  				}
   115  			}
   116  		}
   117  	}
   118  
   119  	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   120  }
   121  
   122  func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   123  	// dgemmParallel computes a parallel matrix multiplication by partitioning
   124  	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
   125  	// In all cases,
   126  	// A = [ 	A_11	A_12 ... 	A_1j
   127  	//			A_21	A_22 ...	A_2j
   128  	//				...
   129  	//			A_i1	A_i2 ...	A_ij]
   130  	//
   131  	// and same for B. All of the submatrix sizes are blockSize×blockSize except
   132  	// at the edges.
   133  	//
   134  	// In all cases, there is one dimension for each matrix along which
   135  	// C must be updated sequentially.
   136  	// Cij = \sum_k Aik Bki,	(A * B)
   137  	// Cij = \sum_k Aki Bkj,	(Aᵀ * B)
   138  	// Cij = \sum_k Aik Bjk,	(A * Bᵀ)
   139  	// Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
   140  	//
   141  	// This code computes one {i, j} block sequentially along the k dimension,
   142  	// and computes all of the {i, j} blocks concurrently. This
   143  	// partitioning allows Cij to be updated in-place without race-conditions.
   144  	// Instead of launching a goroutine for each possible concurrent computation,
   145  	// a number of worker goroutines are created and channels are used to pass
   146  	// available and completed cases.
   147  	//
   148  	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
   149  	// multiplies, though this code does not copy matrices to attempt to eliminate
   150  	// cache misses.
   151  
   152  	maxKLen := k
   153  	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
   154  	if parBlocks < minParBlock {
   155  		// The matrix multiplication is small in the dimensions where it can be
   156  		// computed concurrently. Just do it in serial.
   157  		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   158  		return
   159  	}
   160  
   161  	// workerLimit acts a number of maximum concurrent workers,
   162  	// with the limit set to the number of procs available.
   163  	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
   164  
   165  	// wg is used to wait for all
   166  	var wg sync.WaitGroup
   167  	wg.Add(parBlocks)
   168  	defer wg.Wait()
   169  
   170  	for i := 0; i < m; i += blockSize {
   171  		for j := 0; j < n; j += blockSize {
   172  			workerLimit <- struct{}{}
   173  			go func(i, j int) {
   174  				defer func() {
   175  					wg.Done()
   176  					<-workerLimit
   177  				}()
   178  
   179  				leni := blockSize
   180  				if i+leni > m {
   181  					leni = m - i
   182  				}
   183  				lenj := blockSize
   184  				if j+lenj > n {
   185  					lenj = n - j
   186  				}
   187  
   188  				cSub := sliceView64(c, ldc, i, j, leni, lenj)
   189  
   190  				// Compute A_ik B_kj for all k
   191  				for k := 0; k < maxKLen; k += blockSize {
   192  					lenk := blockSize
   193  					if k+lenk > maxKLen {
   194  						lenk = maxKLen - k
   195  					}
   196  					var aSub, bSub []float64
   197  					if aTrans {
   198  						aSub = sliceView64(a, lda, k, i, lenk, leni)
   199  					} else {
   200  						aSub = sliceView64(a, lda, i, k, leni, lenk)
   201  					}
   202  					if bTrans {
   203  						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
   204  					} else {
   205  						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
   206  					}
   207  					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
   208  				}
   209  			}(i, j)
   210  		}
   211  	}
   212  }
   213  
   214  // dgemmSerial is serial matrix multiply
   215  func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   216  	switch {
   217  	case !aTrans && !bTrans:
   218  		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   219  		return
   220  	case aTrans && !bTrans:
   221  		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   222  		return
   223  	case !aTrans && bTrans:
   224  		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   225  		return
   226  	case aTrans && bTrans:
   227  		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   228  		return
   229  	default:
   230  		panic("unreachable")
   231  	}
   232  }
   233  
   234  // dgemmSerial where neither a nor b are transposed
   235  func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   236  	// This style is used instead of the literal [i*stride +j]) is used because
   237  	// approximately 5 times faster as of go 1.3.
   238  	for i := 0; i < m; i++ {
   239  		ctmp := c[i*ldc : i*ldc+n]
   240  		for l, v := range a[i*lda : i*lda+k] {
   241  			tmp := alpha * v
   242  			if tmp != 0 {
   243  				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
   244  			}
   245  		}
   246  	}
   247  }
   248  
   249  // dgemmSerial where neither a is transposed and b is not
   250  func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   251  	// This style is used instead of the literal [i*stride +j]) is used because
   252  	// approximately 5 times faster as of go 1.3.
   253  	for l := 0; l < k; l++ {
   254  		btmp := b[l*ldb : l*ldb+n]
   255  		for i, v := range a[l*lda : l*lda+m] {
   256  			tmp := alpha * v
   257  			if tmp != 0 {
   258  				ctmp := c[i*ldc : i*ldc+n]
   259  				f64.AxpyUnitary(tmp, btmp, ctmp)
   260  			}
   261  		}
   262  	}
   263  }
   264  
   265  // dgemmSerial where neither a is not transposed and b is
   266  func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   267  	// This style is used instead of the literal [i*stride +j]) is used because
   268  	// approximately 5 times faster as of go 1.3.
   269  	for i := 0; i < m; i++ {
   270  		atmp := a[i*lda : i*lda+k]
   271  		ctmp := c[i*ldc : i*ldc+n]
   272  		for j := 0; j < n; j++ {
   273  			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
   274  		}
   275  	}
   276  }
   277  
   278  // dgemmSerial where both are transposed
   279  func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   280  	// This style is used instead of the literal [i*stride +j]) is used because
   281  	// approximately 5 times faster as of go 1.3.
   282  	for l := 0; l < k; l++ {
   283  		for i, v := range a[l*lda : l*lda+m] {
   284  			tmp := alpha * v
   285  			if tmp != 0 {
   286  				ctmp := c[i*ldc : i*ldc+n]
   287  				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
   288  			}
   289  		}
   290  	}
   291  }
   292  
   293  func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
   294  	return a[i*lda+j : (i+r-1)*lda+j+c]
   295  }