gitee.com/quant1x/num@v0.3.2/internal/functions/matrix.go (about)

     1  package functions
     2  
     3  import (
     4  	"gitee.com/quant1x/num/internal/constraints"
     5  	"runtime"
     6  	"sync"
     7  )
     8  
     9  var numCPU int = runtime.NumCPU()
    10  
    11  // matMulParallel runs matrix multiply in parallel by dividing the input rows
    12  func matMulParallel[T constraints.Float](
    13  	dst, x, y []T, m, n, p int,
    14  	vecMul func(dst, x, y []T, m, n int),
    15  	matMul func(dst, x, y []T, m, n, p int),
    16  ) {
    17  	if m < 4 || m*p*n < 100_000 {
    18  		if p == 1 {
    19  			vecMul(dst, x, y, m, n)
    20  		} else {
    21  			matMul(dst, x, y, m, n, p)
    22  		}
    23  		return
    24  	}
    25  
    26  	rowsPerCPU, rem := m/numCPU, m%numCPU
    27  	i := 0
    28  	var wg sync.WaitGroup
    29  	for c := 0; c < numCPU && i < m; c++ {
    30  		numRows := rowsPerCPU
    31  		if c < rem {
    32  			numRows += 1
    33  		}
    34  		dstStart := i * p
    35  		dstEnd := (i + numRows) * p
    36  		xStart := i * n
    37  		xEnd := (i + numRows) * n
    38  
    39  		wg.Add(1)
    40  		go func() {
    41  			if p == 1 {
    42  				vecMul(dst[dstStart:dstEnd], x[xStart:xEnd], y, numRows, n)
    43  			} else {
    44  				matMul(dst[dstStart:dstEnd], x[xStart:xEnd], y, numRows, n, p)
    45  			}
    46  			wg.Done()
    47  		}()
    48  
    49  		i += numRows
    50  	}
    51  	wg.Wait()
    52  }
    53  
    54  func MatMul_Go[T constraints.Float](dst, x, y []T, m, n, p int) {
    55  	for i := 0; i < m; i++ {
    56  		for k := 0; k < n; k++ {
    57  			for j := 0; j < p; j++ { // dst not set to zero
    58  				dst[i*p+j] += x[i*n+k] * y[k*p+j]
    59  			}
    60  		}
    61  	}
    62  }
    63  
    64  func MatMulVec_Go[T constraints.Float](dst, x, y []T, m, n int) {
    65  	for i := 0; i < m; i++ {
    66  		for k := 0; k < n; k++ { // note: dst is not set to zero
    67  			dst[i] += x[i*n+k] * y[k]
    68  		}
    69  	}
    70  }
    71  
    72  func Mat4Mul_Go[T constraints.Float](dst, x, y []T) {
    73  	for i := 0; i < 4; i++ {
    74  		for j := 0; j < 4; j++ {
    75  			dst[i*4+j] = x[i*4]*y[j] + x[i*4+1]*y[1*4+j] +
    76  				x[i*4+2]*y[2*4+j] + x[i*4+3]*y[3*4+j]
    77  		}
    78  	}
    79  }
    80  
    81  func MatMul_Parallel_Go[T constraints.Float](dst, x, y []T, m, n, p int) {
    82  	matMulParallel(dst, x, y, m, n, p, MatMulVec_Go[T], MatMul_Go[T])
    83  }