github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/sgemm.go (about)

     1  // Code generated by "go generate github.com/jingcheng-WU/gonum/blas/gonum”; DO NOT EDIT.
     2  
     3  // Copyright ©2014 The Gonum Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  
     7  package gonum
     8  
     9  import (
    10  	"runtime"
    11  	"sync"
    12  
    13  	"github.com/jingcheng-WU/gonum/blas"
    14  	"github.com/jingcheng-WU/gonum/internal/asm/f32"
    15  )
    16  
    17  // Sgemm performs one of the matrix-matrix operations
    18  //  C = alpha * A * B + beta * C
    19  //  C = alpha * Aᵀ * B + beta * C
    20  //  C = alpha * A * Bᵀ + beta * C
    21  //  C = alpha * Aᵀ * Bᵀ + beta * C
    22  // where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
    23  // an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
    24  // B are transposed.
    25  //
    26  // Float32 implementations are autogenerated and not directly tested.
    27  func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
    28  	switch tA {
    29  	default:
    30  		panic(badTranspose)
    31  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    32  	}
    33  	switch tB {
    34  	default:
    35  		panic(badTranspose)
    36  	case blas.NoTrans, blas.Trans, blas.ConjTrans:
    37  	}
    38  	if m < 0 {
    39  		panic(mLT0)
    40  	}
    41  	if n < 0 {
    42  		panic(nLT0)
    43  	}
    44  	if k < 0 {
    45  		panic(kLT0)
    46  	}
    47  	aTrans := tA == blas.Trans || tA == blas.ConjTrans
    48  	if aTrans {
    49  		if lda < max(1, m) {
    50  			panic(badLdA)
    51  		}
    52  	} else {
    53  		if lda < max(1, k) {
    54  			panic(badLdA)
    55  		}
    56  	}
    57  	bTrans := tB == blas.Trans || tB == blas.ConjTrans
    58  	if bTrans {
    59  		if ldb < max(1, k) {
    60  			panic(badLdB)
    61  		}
    62  	} else {
    63  		if ldb < max(1, n) {
    64  			panic(badLdB)
    65  		}
    66  	}
    67  	if ldc < max(1, n) {
    68  		panic(badLdC)
    69  	}
    70  
    71  	// Quick return if possible.
    72  	if m == 0 || n == 0 {
    73  		return
    74  	}
    75  
    76  	// For zero matrix size the following slice length checks are trivially satisfied.
    77  	if aTrans {
    78  		if len(a) < (k-1)*lda+m {
    79  			panic(shortA)
    80  		}
    81  	} else {
    82  		if len(a) < (m-1)*lda+k {
    83  			panic(shortA)
    84  		}
    85  	}
    86  	if bTrans {
    87  		if len(b) < (n-1)*ldb+k {
    88  			panic(shortB)
    89  		}
    90  	} else {
    91  		if len(b) < (k-1)*ldb+n {
    92  			panic(shortB)
    93  		}
    94  	}
    95  	if len(c) < (m-1)*ldc+n {
    96  		panic(shortC)
    97  	}
    98  
    99  	// Quick return if possible.
   100  	if (alpha == 0 || k == 0) && beta == 1 {
   101  		return
   102  	}
   103  
   104  	// scale c
   105  	if beta != 1 {
   106  		if beta == 0 {
   107  			for i := 0; i < m; i++ {
   108  				ctmp := c[i*ldc : i*ldc+n]
   109  				for j := range ctmp {
   110  					ctmp[j] = 0
   111  				}
   112  			}
   113  		} else {
   114  			for i := 0; i < m; i++ {
   115  				ctmp := c[i*ldc : i*ldc+n]
   116  				for j := range ctmp {
   117  					ctmp[j] *= beta
   118  				}
   119  			}
   120  		}
   121  	}
   122  
   123  	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   124  }
   125  
   126  func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   127  	// dgemmParallel computes a parallel matrix multiplication by partitioning
   128  	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
   129  	// In all cases,
   130  	// A = [ 	A_11	A_12 ... 	A_1j
   131  	//			A_21	A_22 ...	A_2j
   132  	//				...
   133  	//			A_i1	A_i2 ...	A_ij]
   134  	//
   135  	// and same for B. All of the submatrix sizes are blockSize×blockSize except
   136  	// at the edges.
   137  	//
   138  	// In all cases, there is one dimension for each matrix along which
   139  	// C must be updated sequentially.
   140  	// Cij = \sum_k Aik Bki,	(A * B)
   141  	// Cij = \sum_k Aki Bkj,	(Aᵀ * B)
   142  	// Cij = \sum_k Aik Bjk,	(A * Bᵀ)
   143  	// Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
   144  	//
   145  	// This code computes one {i, j} block sequentially along the k dimension,
   146  	// and computes all of the {i, j} blocks concurrently. This
   147  	// partitioning allows Cij to be updated in-place without race-conditions.
   148  	// Instead of launching a goroutine for each possible concurrent computation,
   149  	// a number of worker goroutines are created and channels are used to pass
   150  	// available and completed cases.
   151  	//
   152  	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
   153  	// multiplies, though this code does not copy matrices to attempt to eliminate
   154  	// cache misses.
   155  
   156  	maxKLen := k
   157  	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
   158  	if parBlocks < minParBlock {
   159  		// The matrix multiplication is small in the dimensions where it can be
   160  		// computed concurrently. Just do it in serial.
   161  		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   162  		return
   163  	}
   164  
   165  	// workerLimit acts a number of maximum concurrent workers,
   166  	// with the limit set to the number of procs available.
   167  	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
   168  
   169  	// wg is used to wait for all
   170  	var wg sync.WaitGroup
   171  	wg.Add(parBlocks)
   172  	defer wg.Wait()
   173  
   174  	for i := 0; i < m; i += blockSize {
   175  		for j := 0; j < n; j += blockSize {
   176  			workerLimit <- struct{}{}
   177  			go func(i, j int) {
   178  				defer func() {
   179  					wg.Done()
   180  					<-workerLimit
   181  				}()
   182  
   183  				leni := blockSize
   184  				if i+leni > m {
   185  					leni = m - i
   186  				}
   187  				lenj := blockSize
   188  				if j+lenj > n {
   189  					lenj = n - j
   190  				}
   191  
   192  				cSub := sliceView32(c, ldc, i, j, leni, lenj)
   193  
   194  				// Compute A_ik B_kj for all k
   195  				for k := 0; k < maxKLen; k += blockSize {
   196  					lenk := blockSize
   197  					if k+lenk > maxKLen {
   198  						lenk = maxKLen - k
   199  					}
   200  					var aSub, bSub []float32
   201  					if aTrans {
   202  						aSub = sliceView32(a, lda, k, i, lenk, leni)
   203  					} else {
   204  						aSub = sliceView32(a, lda, i, k, leni, lenk)
   205  					}
   206  					if bTrans {
   207  						bSub = sliceView32(b, ldb, j, k, lenj, lenk)
   208  					} else {
   209  						bSub = sliceView32(b, ldb, k, j, lenk, lenj)
   210  					}
   211  					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
   212  				}
   213  			}(i, j)
   214  		}
   215  	}
   216  }
   217  
   218  // sgemmSerial is serial matrix multiply
   219  func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   220  	switch {
   221  	case !aTrans && !bTrans:
   222  		sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   223  		return
   224  	case aTrans && !bTrans:
   225  		sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   226  		return
   227  	case !aTrans && bTrans:
   228  		sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   229  		return
   230  	case aTrans && bTrans:
   231  		sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
   232  		return
   233  	default:
   234  		panic("unreachable")
   235  	}
   236  }
   237  
   238  // sgemmSerial where neither a nor b are transposed
   239  func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   240  	// This style is used instead of the literal [i*stride +j]) is used because
   241  	// approximately 5 times faster as of go 1.3.
   242  	for i := 0; i < m; i++ {
   243  		ctmp := c[i*ldc : i*ldc+n]
   244  		for l, v := range a[i*lda : i*lda+k] {
   245  			tmp := alpha * v
   246  			if tmp != 0 {
   247  				f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
   248  			}
   249  		}
   250  	}
   251  }
   252  
   253  // sgemmSerial where neither a is transposed and b is not
   254  func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   255  	// This style is used instead of the literal [i*stride +j]) is used because
   256  	// approximately 5 times faster as of go 1.3.
   257  	for l := 0; l < k; l++ {
   258  		btmp := b[l*ldb : l*ldb+n]
   259  		for i, v := range a[l*lda : l*lda+m] {
   260  			tmp := alpha * v
   261  			if tmp != 0 {
   262  				ctmp := c[i*ldc : i*ldc+n]
   263  				f32.AxpyUnitary(tmp, btmp, ctmp)
   264  			}
   265  		}
   266  	}
   267  }
   268  
   269  // sgemmSerial where neither a is not transposed and b is
   270  func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   271  	// This style is used instead of the literal [i*stride +j]) is used because
   272  	// approximately 5 times faster as of go 1.3.
   273  	for i := 0; i < m; i++ {
   274  		atmp := a[i*lda : i*lda+k]
   275  		ctmp := c[i*ldc : i*ldc+n]
   276  		for j := 0; j < n; j++ {
   277  			ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
   278  		}
   279  	}
   280  }
   281  
   282  // sgemmSerial where both are transposed
   283  func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
   284  	// This style is used instead of the literal [i*stride +j]) is used because
   285  	// approximately 5 times faster as of go 1.3.
   286  	for l := 0; l < k; l++ {
   287  		for i, v := range a[l*lda : l*lda+m] {
   288  			tmp := alpha * v
   289  			if tmp != 0 {
   290  				ctmp := c[i*ldc : i*ldc+n]
   291  				f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
   292  			}
   293  		}
   294  	}
   295  }
   296  
   297  func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
   298  	return a[i*lda+j : (i+r-1)*lda+j+c]
   299  }