gitee.com/quant1x/num@v0.3.2/asm/src/floats_avx512.c (about)

     1  // Copyright 2022 gorse Project Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include <immintrin.h>
    16  #include <stdint.h>
    17  
    18  void _mm512_mul_const_add_to(float *a, float *b, float *c, int64_t n)
    19  {
    20      int epoch = n / 16;
    21      int remain = n % 16;
    22      for (int i = 0; i < epoch; i++)
    23      {
    24          __m512 v1 = _mm512_loadu_ps(a);
    25          __m512 v2 = _mm512_set1_ps(*b);
    26          __m512 v3 = _mm512_loadu_ps(c);
    27          __m512 v = _mm512_fmadd_ps(v1, v2, v3);
    28          _mm512_storeu_ps(c, v);
    29          a += 16;
    30          c += 16;
    31      }
    32      if (remain >= 8)
    33      {
    34          __m256 v1 = _mm256_loadu_ps(a);
    35          __m256 v2 = _mm256_broadcast_ss(b);
    36          __m256 v3 = _mm256_loadu_ps(c);
    37          __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3);
    38          _mm256_storeu_ps(c, v);
    39          a += 8;
    40          c += 8;
    41          remain -= 8;
    42      }
    43      for (int i = 0; i < remain; i++)
    44      {
    45          c[i] += a[i] * b[0];
    46      }
    47  }
    48  
    49  void _mm512_mul_const_to(float *a, float *b, float *c, int64_t n)
    50  {
    51      int epoch = n / 16;
    52      int remain = n % 16;
    53      for (int i = 0; i < epoch; i++)
    54      {
    55          __m512 v1 = _mm512_loadu_ps(a);
    56          __m512 v2 = _mm512_set1_ps(*b);
    57          __m512 v = _mm512_mul_ps(v1, v2);
    58          _mm512_storeu_ps(c, v);
    59          a += 16;
    60          c += 16;
    61      }
    62      if (remain >= 8)
    63      {
    64          __m256 v1 = _mm256_loadu_ps(a);
    65          __m256 v2 = _mm256_broadcast_ss(b);
    66          __m256 v = _mm256_mul_ps(v1, v2);
    67          _mm256_storeu_ps(c, v);
    68          a += 8;
    69          c += 8;
    70          remain -= 8;
    71      }
    72      for (int i = 0; i < remain; i++)
    73      {
    74          c[i] = a[i] * b[0];
    75      }
    76  }
    77  
    78  void _mm512_mul_const(float *a, float *b, int64_t n)
    79  {
    80      int epoch = n / 16;
    81      int remain = n % 16;
    82      for (int i = 0; i < epoch; i++)
    83      {
    84          __m512 v1 = _mm512_loadu_ps(a);
    85          __m512 v2 = _mm512_set1_ps(*b);
    86          __m512 v = _mm512_mul_ps(v1, v2);
    87          _mm512_storeu_ps(a, v);
    88          a += 16;
    89      }
    90      if (remain >= 8)
    91      {
    92          __m256 v1 = _mm256_loadu_ps(a);
    93          __m256 v2 = _mm256_broadcast_ss(b);
    94          __m256 v = _mm256_mul_ps(v1, v2);
    95          _mm256_storeu_ps(a, v);
    96          a += 8;
    97          remain -= 8;
    98      }
    99      for (int i = 0; i < remain; i++)
   100      {
   101          a[i] *= b[0];
   102      }
   103  }
   104  
   105  void _mm512_mul_to(float *a, float *b, float *c, int64_t n)
   106  {
   107      int epoch = n / 16;
   108      int remain = n % 16;
   109      for (int i = 0; i < epoch; i++)
   110      {
   111          __m512 v1 = _mm512_loadu_ps(a);
   112          __m512 v2 = _mm512_loadu_ps(b);
   113          __m512 v = _mm512_mul_ps(v1, v2);
   114          _mm512_storeu_ps(c, v);
   115          a += 16;
   116          b += 16;
   117          c += 16;
   118      }
   119      if (remain >= 8)
   120      {
   121          __m256 v1 = _mm256_loadu_ps(a);
   122          __m256 v2 = _mm256_loadu_ps(b);
   123          __m256 v = _mm256_mul_ps(v1, v2);
   124          _mm256_storeu_ps(c, v);
   125          a += 8;
   126          b += 8;
   127          c += 8;
   128          remain -= 8;
   129      }
   130      for (int i = 0; i < remain; i++)
   131      {
   132          c[i] = a[i] * b[i];
   133      }
   134  }
   135  
   136  void _mm512_dot(float *a, float *b, int64_t n, float *ret)
   137  {
   138      int epoch = n / 16;
   139      int remain = n % 16;
   140      __m512 s;
   141      if (epoch > 0)
   142      {
   143          __m512 v1 = _mm512_loadu_ps(a);
   144          __m512 v2 = _mm512_loadu_ps(b);
   145          s = _mm512_mul_ps(v1, v2);
   146          a += 16;
   147          b += 16;
   148      }
   149      for (int i = 1; i < epoch; i++)
   150      {
   151          __m512 v1 = _mm512_loadu_ps(a);
   152          __m512 v2 = _mm512_loadu_ps(b);
   153          s = _mm512_fmadd_ps(v1, v2, s);
   154          a += 16;
   155          b += 16;
   156      }
   157      __m256 sf_e_d_c_b_a_9_8 = _mm512_extractf32x8_ps(s, 1);
   158      __m256 s7_6_5_4_3_2_1_0 = _mm512_castps512_ps256(s);
   159      __m256 s7f_6e_5d_4c_3b_2a_19_08 = _mm256_add_ps(sf_e_d_c_b_a_9_8, s7_6_5_4_3_2_1_0);
   160      __m128 s7f_6e_5d_4c = _mm256_extractf128_ps(s7f_6e_5d_4c_3b_2a_19_08, 1);
   161      __m128 s3b_2a_19_08 = _mm256_castps256_ps128(s7f_6e_5d_4c_3b_2a_19_08);
   162      __m128 s37bf_26ae_159d_048c = _mm_add_ps(s7f_6e_5d_4c, s3b_2a_19_08);
   163      __m128 sxx_159d_048c = s37bf_26ae_159d_048c;
   164      __m128 sxx_37bf_26ae = _mm_movehl_ps(sxx_159d_048c, s37bf_26ae_159d_048c);
   165      const __m128 sxx_13579bdf_02468ace = _mm_add_ps(sxx_159d_048c, sxx_37bf_26ae);
   166      const __m128 sxxx_02468ace = sxx_13579bdf_02468ace;
   167      const __m128 sxxx_13579bdf = _mm_shuffle_ps(sxx_13579bdf_02468ace, sxx_13579bdf_02468ace, 0x1);
   168      __m128 sxxx_0123456789abcdef = _mm_add_ss(sxxx_02468ace, sxxx_13579bdf);
   169      *ret = _mm_cvtss_f32(sxxx_0123456789abcdef);
   170  
   171      if (remain >= 8)
   172      {
   173          __m256 s;
   174          __m256 v1 = _mm256_loadu_ps(a);
   175          __m256 v2 = _mm256_loadu_ps(b);
   176          s = _mm256_mul_ps(v1, v2);
   177          a += 8;
   178          b += 8;
   179          __m128 s7_6_5_4 = _mm256_extractf128_ps(s, 1);
   180          __m128 s3_2_1_0 = _mm256_castps256_ps128(s);
   181          __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0);
   182          __m128 sxx_15_04 = s37_26_15_04;
   183          __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04);
   184          const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26);
   185          const __m128 sxxx_0246 = sxx_1357_0246;
   186          const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1);
   187          __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357);
   188          *ret += _mm_cvtss_f32(sxxx_01234567);
   189          remain -= 8;
   190      }
   191  
   192      for (int i = 0; i < remain; i++)
   193      {
   194          *ret += a[i] * b[i];
   195      }
   196  }