gonum.org/v1/gonum@v0.14.0/blas/gonum/dgemm.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"runtime"
     9  	"sync"
    10  
    11  	"gonum.org/v1/gonum/blas"
    12  	"gonum.org/v1/gonum/internal/asm/f64"
    13  )
    14  
    15  // Dgemm performs one of the matrix-matrix operations
    16  //
    17  //	C = alpha * A * B + beta * C
    18  //	C = alpha * Aᵀ * B + beta * C
    19  //	C = alpha * A * Bᵀ + beta * C
    20  //	C = alpha * Aᵀ * Bᵀ + beta * C
    21  //
    22  // where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
    23  // an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
    24  // B are transposed.
    25  func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
    26  	switch tA {
    27  	default:
    28  		panic(badTranspose)
    29  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    30  	}
    31  	switch tB {
    32  	default:
    33  		panic(badTranspose)
    34  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    35  	}
    36  	if m < 0 {
    37  		panic(mLT0)
    38  	}
    39  	if n < 0 {
    40  		panic(nLT0)
    41  	}
    42  	if k < 0 {
    43  		panic(kLT0)
    44  	}
    45  	aTrans := tA == blas.Trans || tA == blas.ConjTrans
    46  	if aTrans {
    47  		if lda < max(1, m) {
    48  			panic(badLdA)
    49  		}
    50  	} else {
    51  		if lda < max(1, k) {
    52  			panic(badLdA)
    53  		}
    54  	}
    55  	bTrans := tB == blas.Trans || tB == blas.ConjTrans
    56  	if bTrans {
    57  		if ldb < max(1, k) {
    58  			panic(badLdB)
    59  		}
    60  	} else {
    61  		if ldb < max(1, n) {
    62  			panic(badLdB)
    63  		}
    64  	}
    65  	if ldc < max(1, n) {
    66  		panic(badLdC)
    67  	}
    68  
    69  	// Quick return if possible.
    70  	if m == 0 || n == 0 {
    71  		return
    72  	}
    73  
    74  	// For zero matrix size the following slice length checks are trivially satisfied.
    75  	if aTrans {
    76  		if len(a) < (k-1)*lda+m {
    77  			panic(shortA)
    78  		}
    79  	} else {
    80  		if len(a) < (m-1)*lda+k {
    81  			panic(shortA)
    82  		}
    83  	}
    84  	if bTrans {
    85  		if len(b) < (n-1)*ldb+k {
    86  			panic(shortB)
    87  		}
    88  	} else {
    89  		if len(b) < (k-1)*ldb+n {
    90  			panic(shortB)
    91  		}
    92  	}
    93  	if len(c) < (m-1)*ldc+n {
    94  		panic(shortC)
    95  	}
    96  
    97  	// Quick return if possible.
    98  	if (alpha == 0 || k == 0) && beta == 1 {
    99  		return
   100  	}
   101  
   102  	// scale c
   103  	if beta != 1 {
   104  		if beta == 0 {
   105  			for i := 0; i < m; i++ {
   106  				ctmp := c[i*ldc : i*ldc+n]
   107  				for j := range ctmp {
   108  					ctmp[j] = 0
   109  				}
   110  			}
   111  		} else {
   112  			for i := 0; i < m; i++ {
   113  				ctmp := c[i*ldc : i*ldc+n]
   114  				for j := range ctmp {
   115  					ctmp[j] *= beta
   116  				}
   117  			}
   118  		}
   119  	}
   120  
   121  	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   122  }
   123  
   124  func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   125  	// dgemmParallel computes a parallel matrix multiplication by partitioning
   126  	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
   127  	// In all cases,
   128  	// A = [ 	A_11	A_12 ... 	A_1j
   129  	//			A_21	A_22 ...	A_2j
   130  	//				...
   131  	//			A_i1	A_i2 ...	A_ij]
   132  	//
   133  	// and same for B. All of the submatrix sizes are blockSize×blockSize except
   134  	// at the edges.
   135  	//
   136  	// In all cases, there is one dimension for each matrix along which
   137  	// C must be updated sequentially.
   138  	// Cij = \sum_k Aik Bki,	(A * B)
   139  	// Cij = \sum_k Aki Bkj,	(Aᵀ * B)
   140  	// Cij = \sum_k Aik Bjk,	(A * Bᵀ)
   141  	// Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
   142  	//
   143  	// This code computes one {i, j} block sequentially along the k dimension,
   144  	// and computes all of the {i, j} blocks concurrently. This
   145  	// partitioning allows Cij to be updated in-place without race-conditions.
   146  	// Instead of launching a goroutine for each possible concurrent computation,
   147  	// a number of worker goroutines are created and channels are used to pass
   148  	// available and completed cases.
   149  	//
   150  	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
   151  	// multiplies, though this code does not copy matrices to attempt to eliminate
   152  	// cache misses.
   153  
   154  	maxKLen := k
   155  	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
   156  	if parBlocks < minParBlock {
   157  		// The matrix multiplication is small in the dimensions where it can be
   158  		// computed concurrently. Just do it in serial.
   159  		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   160  		return
   161  	}
   162  
   163  	// workerLimit acts a number of maximum concurrent workers,
   164  	// with the limit set to the number of procs available.
   165  	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
   166  
   167  	// wg is used to wait for all
   168  	var wg sync.WaitGroup
   169  	wg.Add(parBlocks)
   170  	defer wg.Wait()
   171  
   172  	for i := 0; i < m; i += blockSize {
   173  		for j := 0; j < n; j += blockSize {
   174  			workerLimit <- struct{}{}
   175  			go func(i, j int) {
   176  				defer func() {
   177  					wg.Done()
   178  					<-workerLimit
   179  				}()
   180  
   181  				leni := blockSize
   182  				if i+leni > m {
   183  					leni = m - i
   184  				}
   185  				lenj := blockSize
   186  				if j+lenj > n {
   187  					lenj = n - j
   188  				}
   189  
   190  				cSub := sliceView64(c, ldc, i, j, leni, lenj)
   191  
   192  				// Compute A_ik B_kj for all k
   193  				for k := 0; k < maxKLen; k += blockSize {
   194  					lenk := blockSize
   195  					if k+lenk > maxKLen {
   196  						lenk = maxKLen - k
   197  					}
   198  					var aSub, bSub []float64
   199  					if aTrans {
   200  						aSub = sliceView64(a, lda, k, i, lenk, leni)
   201  					} else {
   202  						aSub = sliceView64(a, lda, i, k, leni, lenk)
   203  					}
   204  					if bTrans {
   205  						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
   206  					} else {
   207  						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
   208  					}
   209  					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
   210  				}
   211  			}(i, j)
   212  		}
   213  	}
   214  }
   215  
   216  // dgemmSerial is serial matrix multiply
   217  func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   218  	switch {
   219  	case !aTrans && !bTrans:
   220  		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   221  		return
   222  	case aTrans && !bTrans:
   223  		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   224  		return
   225  	case !aTrans && bTrans:
   226  		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   227  		return
   228  	case aTrans && bTrans:
   229  		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   230  		return
   231  	default:
   232  		panic("unreachable")
   233  	}
   234  }
   235  
   236  // dgemmSerial where neither a nor b are transposed
   237  func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   238  	// This style is used instead of the literal [i*stride +j]) is used because
   239  	// approximately 5 times faster as of go 1.3.
   240  	for i := 0; i < m; i++ {
   241  		ctmp := c[i*ldc : i*ldc+n]
   242  		for l, v := range a[i*lda : i*lda+k] {
   243  			tmp := alpha * v
   244  			if tmp != 0 {
   245  				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
   246  			}
   247  		}
   248  	}
   249  }
   250  
   251  // dgemmSerial where neither a is transposed and b is not
   252  func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   253  	// This style is used instead of the literal [i*stride +j]) is used because
   254  	// approximately 5 times faster as of go 1.3.
   255  	for l := 0; l < k; l++ {
   256  		btmp := b[l*ldb : l*ldb+n]
   257  		for i, v := range a[l*lda : l*lda+m] {
   258  			tmp := alpha * v
   259  			if tmp != 0 {
   260  				ctmp := c[i*ldc : i*ldc+n]
   261  				f64.AxpyUnitary(tmp, btmp, ctmp)
   262  			}
   263  		}
   264  	}
   265  }
   266  
   267  // dgemmSerial where neither a is not transposed and b is
   268  func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   269  	// This style is used instead of the literal [i*stride +j]) is used because
   270  	// approximately 5 times faster as of go 1.3.
   271  	for i := 0; i < m; i++ {
   272  		atmp := a[i*lda : i*lda+k]
   273  		ctmp := c[i*ldc : i*ldc+n]
   274  		for j := 0; j < n; j++ {
   275  			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
   276  		}
   277  	}
   278  }
   279  
   280  // dgemmSerial where both are transposed
   281  func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
   282  	// This style is used instead of the literal [i*stride +j]) is used because
   283  	// approximately 5 times faster as of go 1.3.
   284  	for l := 0; l < k; l++ {
   285  		for i, v := range a[l*lda : l*lda+m] {
   286  			tmp := alpha * v
   287  			if tmp != 0 {
   288  				ctmp := c[i*ldc : i*ldc+n]
   289  				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
   290  			}
   291  		}
   292  	}
   293  }
   294  
   295  func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
   296  	return a[i*lda+j : (i+r-1)*lda+j+c]
   297  }