gitee.com/quant1x/num@v0.3.2/asm/src/floats_avx.c (about)

     1  // Copyright 2022 gorse Project Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  #include <immintrin.h>
    16  #include <stdint.h>
    17  
    18  void _mm256_mul_const_add_to(float *a, float *b, float *c, int64_t n)
    19  {
    20      int epoch = n / 8;
    21      int remain = n % 8;
    22      for (int i = 0; i < epoch; i++)
    23      {
    24          __m256 v1 = _mm256_loadu_ps(a);
    25          __m256 v2 = _mm256_broadcast_ss(b);
    26          __m256 v3 = _mm256_loadu_ps(c);
    27          __m256 v = _mm256_add_ps(_mm256_mul_ps(v1, v2), v3);
    28          _mm256_storeu_ps(c, v);
    29          a += 8;
    30          c += 8;
    31      }
    32      for (int i = 0; i < remain; i++)
    33      {
    34          c[i] += a[i] * b[0];
    35      }
    36  }
    37  
    38  void _mm256_mul_const_to(float *a, float *b, float *c, int64_t n)
    39  {
    40      int epoch = n / 8;
    41      int remain = n % 8;
    42      for (int i = 0; i < epoch; i++)
    43      {
    44          __m256 v1 = _mm256_loadu_ps(a);
    45          __m256 v2 = _mm256_broadcast_ss(b);
    46          __m256 v = _mm256_mul_ps(v1, v2);
    47          _mm256_storeu_ps(c, v);
    48          a += 8;
    49          c += 8;
    50      }
    51      for (int i = 0; i < remain; i++)
    52      {
    53          c[i] = a[i] * b[0];
    54      }
    55  }
    56  
    57  void _mm256_mul_const(float *a, float *b, int64_t n)
    58  {
    59      int epoch = n / 8;
    60      int remain = n % 8;
    61      for (int i = 0; i < epoch; i++)
    62      {
    63          __m256 v1 = _mm256_loadu_ps(a);
    64          __m256 v2 = _mm256_broadcast_ss(b);
    65          __m256 v = _mm256_mul_ps(v1, v2);
    66          _mm256_storeu_ps(a, v);
    67          a += 8;
    68      }
    69      for (int i = 0; i < remain; i++)
    70      {
    71          a[i] *= b[0];
    72      }
    73  }
    74  
    75  void _mm256_mul_to(float *a, float *b, float *c, int64_t n)
    76  {
    77      int epoch = n / 8;
    78      int remain = n % 8;
    79      for (int i = 0; i < epoch; i++)
    80      {
    81          __m256 v1 = _mm256_loadu_ps(a);
    82          __m256 v2 = _mm256_loadu_ps(b);
    83          __m256 v = _mm256_mul_ps(v1, v2);
    84          _mm256_storeu_ps(c, v);
    85          a += 8;
    86          b += 8;
    87          c += 8;
    88      }
    89      for (int i = 0; i < remain; i++)
    90      {
    91          c[i] = a[i] * b[i];
    92      }
    93  }
    94  
    95  void _mm256_dot(float *a, float *b, int64_t n, float *ret)
    96  {
    97      int epoch = n / 8;
    98      int remain = n % 8;
    99      __m256 s;
   100      if (epoch > 0)
   101      {
   102          __m256 v1 = _mm256_loadu_ps(a);
   103          __m256 v2 = _mm256_loadu_ps(b);
   104          s = _mm256_mul_ps(v1, v2);
   105          a += 8;
   106          b += 8;
   107      }
   108      for (int i = 1; i < epoch; i++)
   109      {
   110          __m256 v1 = _mm256_loadu_ps(a);
   111          __m256 v2 = _mm256_loadu_ps(b);
   112          s = _mm256_add_ps(_mm256_mul_ps(v1, v2), s);
   113          a += 8;
   114          b += 8;
   115      }
   116      __m128 s7_6_5_4 = _mm256_extractf128_ps(s, 1);
   117      __m128 s3_2_1_0 = _mm256_castps256_ps128(s);
   118      __m128 s37_26_15_04 = _mm_add_ps(s7_6_5_4, s3_2_1_0);
   119      __m128 sxx_15_04 = s37_26_15_04;
   120      __m128 sxx_37_26 = _mm_movehl_ps(s37_26_15_04, s37_26_15_04);
   121      const __m128 sxx_1357_0246 = _mm_add_ps(sxx_15_04, sxx_37_26);
   122      const __m128 sxxx_0246 = sxx_1357_0246;
   123      const __m128 sxxx_1357 = _mm_shuffle_ps(sxx_1357_0246, sxx_1357_0246, 0x1);
   124      __m128 sxxx_01234567 = _mm_add_ss(sxxx_0246, sxxx_1357);
   125      *ret = _mm_cvtss_f32(sxxx_01234567);
   126      for (int i = 0; i < remain; i++)
   127      {
   128          *ret += a[i] * b[i];
   129      }
   130  }