gitee.com/quant1x/gox@v1.7.6/num/asm/_cpp/matrix.cpp (about)

     1  #include <cstddef>
     2  #include <algorithm>
     3  
     4  template<typename T>
     5  struct Mat4 {
     6      T m[4][4];
     7  };
     8  
     9  template<typename T>
    10  void Mat4Mul(Mat4<T>* __restrict dst, Mat4<T>* __restrict x, Mat4<T>* __restrict y) {
    11      for (int i = 0; i < 4; i++) {
    12          for (int j = 0; j < 4; j++) {
    13              dst->m[i][j] = x->m[i][0] * y->m[0][j] + x->m[i][1] * y->m[1][j] +
    14                             x->m[i][2] * y->m[2][j] + x->m[i][3] * y->m[3][j];
    15          }
    16      }
    17  }
    18  
    19  template <typename T>
    20  void MatMul(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n, size_t p) {
    21      for (size_t i = 0; i < m; i++) {
    22          for (size_t k = 0; k < n; k++) {
    23              for (size_t j = 0; j < p; j++) {  // note: dst is not set to zero
    24                  dst[i*p + j] += x[i*n + k] * y[k*p + j];
    25              }
    26          }
    27      }
    28  }
    29  
    30  template <typename T>
    31  void MatMulVec(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n) {
    32      for (size_t i = 0; i < m; i++) {
    33          for (size_t k = 0; k < n; k++) {  // note: dst is not set to zero
    34              dst[i] += x[i*n + k] * y[k];
    35          }
    36      }
    37  }
    38  
    39  template <typename T>
    40  void MatMulTiled(T* __restrict dst, T* __restrict x, T* __restrict y, size_t m, size_t n, size_t p) {
    41      size_t TI = 8;
    42      size_t TJ = 256;
    43      size_t TK = 256;
    44      for (size_t I = 0; I < m + TI - 1; I += TI) {
    45          for (size_t J = 0; J < p + TJ - 1; J += TJ) {
    46              for (size_t K = 0; K < n + TK - 1; K += TK) {
    47                  const int maxI = std::min(I + TI, m);
    48                  const int maxK = std::min(K + TK, n);
    49                  const int maxJ = std::min(J + TJ, p);
    50                  for (size_t i = I; i < maxI; i++) {
    51                      for (int k = K; k < maxK; k++) {
    52                          for (int j = J; j < maxJ; j++) {
    53                              dst[i*p + j] += x[i*n + k] * y[k*p + j];
    54                          }
    55                      }
    56                  }
    57              }
    58          }
    59      }
    60  }
    61  
    62  void Mat4Mul_F64_V(double* dst, double* x, double* y) {
    63      Mat4Mul((Mat4<double>*)dst, (Mat4<double>*)x, (Mat4<double>*)y);
    64  }
    65  
    66  void Mat4Mul_F32_V(float* dst, float* x, float* y) {
    67      Mat4Mul((Mat4<float>*)dst, (Mat4<float>*)x, (Mat4<float>*)y);
    68  }
    69  
    70  void MatMul_F64_V(double* dst, double* x, double* y, size_t m, size_t n, size_t p) {
    71      MatMul(dst, x, y, m, n, p);
    72  }
    73  
    74  void MatMul_F32_V(float* dst, float* x, float* y, size_t m, size_t n, size_t p) {
    75      MatMul(dst, x, y, m, n, p);
    76  }
    77  
    78  void MatMulVec_F64_V(double* dst, double* x, double* y, size_t m, size_t n) {
    79      MatMulVec(dst, x, y, m, n);
    80  }
    81  
    82  void MatMulVec_F32_V(float* dst, float* x, float* y, size_t m, size_t n) {
    83      MatMulVec(dst, x, y, m, n);
    84  }
    85  
    86  void MatMulTiled_F64_V(double* dst, double* x, double* y, size_t m, size_t n, size_t p) {
    87      MatMulTiled(dst, x, y, m, n, p);
    88  }
    89  
    90  void MatMulTiled_F32_V(float* dst, float* x, float* y, size_t m, size_t n, size_t p) {
    91      MatMulTiled(dst, x, y, m, n, p);
    92  }