gitee.com/quant1x/num@v0.3.2/internal/functions/matrix.go (about) 1 package functions 2 3 import ( 4 "gitee.com/quant1x/num/internal/constraints" 5 "runtime" 6 "sync" 7 ) 8 9 var numCPU int = runtime.NumCPU() 10 11 // matMulParallel runs matrix multiply in parallel by dividing the input rows 12 func matMulParallel[T constraints.Float]( 13 dst, x, y []T, m, n, p int, 14 vecMul func(dst, x, y []T, m, n int), 15 matMul func(dst, x, y []T, m, n, p int), 16 ) { 17 if m < 4 || m*p*n < 100_000 { 18 if p == 1 { 19 vecMul(dst, x, y, m, n) 20 } else { 21 matMul(dst, x, y, m, n, p) 22 } 23 return 24 } 25 26 rowsPerCPU, rem := m/numCPU, m%numCPU 27 i := 0 28 var wg sync.WaitGroup 29 for c := 0; c < numCPU && i < m; c++ { 30 numRows := rowsPerCPU 31 if c < rem { 32 numRows += 1 33 } 34 dstStart := i * p 35 dstEnd := (i + numRows) * p 36 xStart := i * n 37 xEnd := (i + numRows) * n 38 39 wg.Add(1) 40 go func() { 41 if p == 1 { 42 vecMul(dst[dstStart:dstEnd], x[xStart:xEnd], y, numRows, n) 43 } else { 44 matMul(dst[dstStart:dstEnd], x[xStart:xEnd], y, numRows, n, p) 45 } 46 wg.Done() 47 }() 48 49 i += numRows 50 } 51 wg.Wait() 52 } 53 54 func MatMul_Go[T constraints.Float](dst, x, y []T, m, n, p int) { 55 for i := 0; i < m; i++ { 56 for k := 0; k < n; k++ { 57 for j := 0; j < p; j++ { // dst not set to zero 58 dst[i*p+j] += x[i*n+k] * y[k*p+j] 59 } 60 } 61 } 62 } 63 64 func MatMulVec_Go[T constraints.Float](dst, x, y []T, m, n int) { 65 for i := 0; i < m; i++ { 66 for k := 0; k < n; k++ { // note: dst is not set to zero 67 dst[i] += x[i*n+k] * y[k] 68 } 69 } 70 } 71 72 func Mat4Mul_Go[T constraints.Float](dst, x, y []T) { 73 for i := 0; i < 4; i++ { 74 for j := 0; j < 4; j++ { 75 dst[i*4+j] = x[i*4]*y[j] + x[i*4+1]*y[1*4+j] + 76 x[i*4+2]*y[2*4+j] + x[i*4+3]*y[3*4+j] 77 } 78 } 79 } 80 81 func MatMul_Parallel_Go[T constraints.Float](dst, x, y []T, m, n, p int) { 82 matMulParallel(dst, x, y, m, n, p, MatMulVec_Go[T], MatMul_Go[T]) 83 }