// gitee.com/quant1x/gox@v1.7.6/num/asm/_cpp/matrix.cpp

#include <cstddef>
#include <algorithm>

// Fixed-size 4x4 matrix; elements are addressed as m[row][column].
template<typename T>
struct Mat4 {
    T m[4][4];
};

// Computes the 4x4 product dst = x * y.
//
// dst is overwritten (not accumulated into). The pointers are declared
// __restrict, so the caller promises that dst, x and y do not alias.
template<typename T>
void Mat4Mul(Mat4<T>* __restrict dst, Mat4<T>* __restrict x, Mat4<T>* __restrict y) {
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            dst->m[i][j] = x->m[i][0] * y->m[0][j] + x->m[i][1] * y->m[1][j] +
                           x->m[i][2] * y->m[2][j] + x->m[i][3] * y->m[3][j];
        }
    }
}

// Accumulates the matrix product into dst: dst += x * y.
//
// All matrices are flat row-major arrays:
//   x   is m x n  (element (i,k) at x[i*n + k])
//   y   is n x p  (element (k,j) at y[k*p + j])
//   dst is m x p  (element (i,j) at dst[i*p + j])
//
// dst is NOT cleared first; the caller must zero it to obtain a plain product.
// The pointers are __restrict: dst, x and y must not alias.
template <typename T>
void MatMul(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n, size_t p) {
    for (size_t i = 0; i < m; i++) {
        for (size_t k = 0; k < n; k++) {
            for (size_t j = 0; j < p; j++) {  // note: dst is not set to zero
                dst[i*p + j] += x[i*n + k] * y[k*p + j];
            }
        }
    }
}

// Accumulates a matrix-vector product into dst: dst += x * y.
//
//   x   is an m x n row-major matrix (element (i,k) at x[i*n + k])
//   y   has n elements
//   dst has m elements
//
// dst is NOT cleared first; the caller must zero it to obtain a plain product.
// The pointers are __restrict: dst, x and y must not alias.
template <typename T>
void MatMulVec(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n) {
    for (size_t i = 0; i < m; i++) {
        for (size_t k = 0; k < n; k++) {  // note: dst is not set to zero
            dst[i] += x[i*n + k] * y[k];
        }
    }
}

// Same contract and memory layout as MatMul (dst += x * y, row-major, dst not
// cleared, no aliasing), but the iteration space is walked in tiles of
// TI x TJ x TK elements along the i, j and k dimensions.
//
// For every dst element the k-terms are still added in increasing k order,
// one k-tile after another.
template <typename T>
void MatMulTiled(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n, size_t p) {
    const size_t TI = 8;    // tile extent along i (rows of dst)
    const size_t TJ = 256;  // tile extent along j (columns of dst)
    const size_t TK = 256;  // tile extent along k (shared dimension)

    // All bounds and indices are size_t: the dimensions are size_t, and
    // narrowing them to int would break for sizes above INT_MAX.
    // Tile loops stop at the dimension itself, so no empty trailing tile is
    // visited and a zero dimension skips the work entirely.
    for (size_t I = 0; I < m; I += TI) {
        // I < m holds here, so m - I cannot wrap; clamp the last partial tile.
        const size_t maxI = I + std::min(TI, m - I);
        for (size_t J = 0; J < p; J += TJ) {
            const size_t maxJ = J + std::min(TJ, p - J);
            for (size_t K = 0; K < n; K += TK) {
                const size_t maxK = K + std::min(TK, n - K);
                for (size_t i = I; i < maxI; i++) {
                    for (size_t k = K; k < maxK; k++) {
                        for (size_t j = J; j < maxJ; j++) {
                            dst[i*p + j] += x[i*n + k] * y[k*p + j];
                        }
                    }
                }
            }
        }
    }
}

// Non-template entry points, one per element type (double / float).
// Each forwards its arguments unchanged to the corresponding template above.

// 4x4 product dst = x * y; each pointer is reinterpreted as a Mat4<double>,
// i.e. it must address 16 contiguous doubles laid out as m[row][column].
void Mat4Mul_F64_V(double* dst, double* x, double* y) {
    Mat4Mul(reinterpret_cast<Mat4<double>*>(dst),
            reinterpret_cast<Mat4<double>*>(x),
            reinterpret_cast<Mat4<double>*>(y));
}

// 4x4 product dst = x * y; each pointer is reinterpreted as a Mat4<float>,
// i.e. it must address 16 contiguous floats laid out as m[row][column].
void Mat4Mul_F32_V(float* dst, float* x, float* y) {
    Mat4Mul(reinterpret_cast<Mat4<float>*>(dst),
            reinterpret_cast<Mat4<float>*>(x),
            reinterpret_cast<Mat4<float>*>(y));
}

// dst (m x p) += x (m x n) * y (n x p), double precision. See MatMul.
void MatMul_F64_V(double* dst, double* x, double* y, size_t m, size_t n, size_t p) {
    MatMul(dst, x, y, m, n, p);
}

// dst (m x p) += x (m x n) * y (n x p), single precision. See MatMul.
void MatMul_F32_V(float* dst, float* x, float* y, size_t m, size_t n, size_t p) {
    MatMul(dst, x, y, m, n, p);
}

// dst (m) += x (m x n) * y (n), double precision. See MatMulVec.
void MatMulVec_F64_V(double* dst, double* x, double* y, size_t m, size_t n) {
    MatMulVec(dst, x, y, m, n);
}

// dst (m) += x (m x n) * y (n), single precision. See MatMulVec.
void MatMulVec_F32_V(float* dst, float* x, float* y, size_t m, size_t n) {
    MatMulVec(dst, x, y, m, n);
}

// Tiled dst (m x p) += x (m x n) * y (n x p), double precision. See MatMulTiled.
void MatMulTiled_F64_V(double* dst, double* x, double* y, size_t m, size_t n, size_t p) {
    MatMulTiled(dst, x, y, m, n, p);
}

// Tiled dst (m x p) += x (m x n) * y (n x p), single precision. See MatMulTiled.
void MatMulTiled_F32_V(float* dst, float* x, float* y, size_t m, size_t n, size_t p) {
    MatMulTiled(dst, x, y, m, n, p);
}