gonum.org/v1/gonum@v0.14.0/blas/gonum/sgemm.go

// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"runtime"
	"sync"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/internal/asm/f32"
)

// Sgemm performs one of the matrix-matrix operations
//
//	C = alpha * A * B + beta * C
//	C = alpha * Aᵀ * B + beta * C
//	C = alpha * A * Bᵀ + beta * C
//	C = alpha * Aᵀ * Bᵀ + beta * C
//
// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
//
// Float32 implementations are autogenerated and not directly tested.
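//
// As a minimal usage sketch (a, b, and c are hypothetical row-major slices,
// sized so that A is 2×3, B is 3×2, and C is 2×2, with no transposes taken):
//
//	var impl Implementation
//	impl.Sgemm(blas.NoTrans, blas.NoTrans, 2, 2, 3, 1, a, 3, b, 2, 0, c, 2)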
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		if lda < max(1, m) {
			panic(badLdA)
		}
	} else {
		if lda < max(1, k) {
			panic(badLdA)
		}
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		if ldb < max(1, k) {
			panic(badLdB)
		}
	} else {
		if ldb < max(1, n) {
			panic(badLdB)
		}
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if aTrans {
		if len(a) < (k-1)*lda+m {
			panic(shortA)
		}
	} else {
		if len(a) < (m-1)*lda+k {
			panic(shortA)
		}
	}
	if bTrans {
		if len(b) < (n-1)*ldb+k {
			panic(shortB)
		}
	} else {
		if len(b) < (k-1)*ldb+n {
			panic(shortB)
		}
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// Scale c by beta. The beta == 0 case is handled separately so that c is
	// overwritten rather than read, matching BLAS semantics.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}

func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// sgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the
	// sub-blocks. In all cases,
	//
	//	A = [ A_11  A_12 ...  A_1j
	//	      A_21  A_22 ...  A_2j
	//	      ...
	//	      A_i1  A_i2 ...  A_ij ]
	//
	// and the same for B. All of the submatrix sizes are blockSize×blockSize
	// except at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	//
	//	Cij = \sum_k Aik Bkj,	(A * B)
	//	Cij = \sum_k Aki Bkj,	(Aᵀ * B)
	//	Cij = \sum_k Aik Bjk,	(A * Bᵀ)
	//	Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
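	// blocks rounds up, so parBlocks counts every blockSize×blockSize tile of
	// c, including the ragged tiles at the right and bottom edges.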
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// workerLimit acts as a limit on the number of concurrent workers,
	// with the limit set to the number of procs available.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	// wg is used to wait for all block updates of c to complete.
	var wg sync.WaitGroup
	wg.Add(parBlocks)
	defer wg.Wait()

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
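			// Acquire a worker slot; this send blocks while GOMAXPROCS(0)
			// goroutines are already working.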
			workerLimit <- struct{}{}
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit
				}()

				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

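				// cSub is the {i, j} tile of c. Each goroutine owns a
				// disjoint tile, so the updates below need no locking.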
				cSub := sliceView32(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					var aSub, bSub []float32
					if aTrans {
						aSub = sliceView32(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView32(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView32(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView32(b, ldb, k, j, lenk, lenj)
					}
					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}(i, j)
		}
	}
}

// sgemmSerial computes a serial matrix multiply, dispatching on the transpose
// state of a and b.
func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	switch {
	case !aTrans && !bTrans:
		sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && !bTrans:
		sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case !aTrans && bTrans:
		sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && bTrans:
		sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	default:
		panic("unreachable")
	}
}

// sgemmSerialNotNot is the kernel used when neither a nor b is transposed.
func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Accessing rows as slices is used instead of the literal c[i*ldc+j]
	// indexing because it was approximately 5 times faster as of go1.3.
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
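		// Row i of c accumulates alpha*a_il times row l of b for each l: one
		// unit-stride axpy per element of row i of a.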
		for l, v := range a[i*lda : i*lda+k] {
			tmp := alpha * v
			if tmp != 0 {
				f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
			}
		}
	}
}

// sgemmSerialTransNot is the kernel used when a is transposed and b is not.
func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Accessing rows as slices is used instead of the literal c[i*ldc+j]
	// indexing because it was approximately 5 times faster as of go1.3.
	for l := 0; l < k; l++ {
		btmp := b[l*ldb : l*ldb+n]
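		// Each element a_li = (Aᵀ)_il scales row l of b into row i of c.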
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f32.AxpyUnitary(tmp, btmp, ctmp)
			}
		}
	}
}

// sgemmSerialNotTrans is the kernel used when a is not transposed and b is.
func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Accessing rows as slices is used instead of the literal c[i*ldc+j]
	// indexing because it was approximately 5 times faster as of go1.3.
	for i := 0; i < m; i++ {
		atmp := a[i*lda : i*lda+k]
		ctmp := c[i*ldc : i*ldc+n]
		for j := 0; j < n; j++ {
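			// Row j of the stored b holds column j of Bᵀ, so both operands
			// of the dot product are unit stride.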
			ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
		}
	}
}

// sgemmSerialTransTrans is the kernel used when both a and b are transposed.
func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Accessing rows as slices is used instead of the literal c[i*ldc+j]
	// indexing because it was approximately 5 times faster as of go1.3.
	for l := 0; l < k; l++ {
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
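				// AxpyInc reads b with stride ldb, visiting b_jl = (Bᵀ)_lj
				// for j = 0..n-1, while writing row i of c with unit stride.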
				f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
			}
		}
	}
}

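// sliceView32 returns a view of the r×c submatrix of a with its upper left
// corner at row i, column j. The view shares a's backing array and keeps the
// stride lda, which is why callers pass the original lda/ldb/ldc alongside
// the views.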
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	return a[i*lda+j : (i+r-1)*lda+j+c]
}