github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/dot_avx512_amd64.c (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  #include <immintrin.h>
    13  
    14  void dot_512(float *a, float *b, float *res, long *len)
    15  {
    16      int n = *len;
    17      float sum = 0;
    18  
    19      // fast path for small dimensions
    20      if (n < 8)
    21      {
    22          do
    23          {
    24              sum += a[0] * b[0];
    25              n--;
    26              a++;
    27              b++;
    28          } while (n);
    29  
    30          *res = sum;
    31          return;
    32      }
    33  
    34      // Create 4 registers to store the results
    35      __m256 acc[4];
    36      acc[0] = _mm256_setzero_ps();
    37      acc[1] = _mm256_setzero_ps();
    38      acc[2] = _mm256_setzero_ps();
    39      acc[3] = _mm256_setzero_ps();
    40  
    41      if (n >= 128)
    42      {
    43          // create 8 registers
    44          __m512 acc5[8];
    45          acc5[0] = _mm512_setzero_ps();
    46          acc5[1] = _mm512_setzero_ps();
    47          acc5[2] = _mm512_setzero_ps();
    48          acc5[3] = _mm512_setzero_ps();
    49          acc5[4] = _mm512_setzero_ps();
    50          acc5[5] = _mm512_setzero_ps();
    51          acc5[6] = _mm512_setzero_ps();
    52          acc5[7] = _mm512_setzero_ps();
    53  
    54          // Process 128 floats at a time
    55          do
    56          {
    57              __m512 a_vec0 = _mm512_loadu_ps(a);
    58              __m512 a_vec1 = _mm512_loadu_ps(a + 16);
    59              __m512 a_vec2 = _mm512_loadu_ps(a + 32);
    60              __m512 a_vec3 = _mm512_loadu_ps(a + 48);
    61              __m512 a_vec4 = _mm512_loadu_ps(a + 64);
    62              __m512 a_vec5 = _mm512_loadu_ps(a + 80);
    63              __m512 a_vec6 = _mm512_loadu_ps(a + 96);
    64              __m512 a_vec7 = _mm512_loadu_ps(a + 112);
    65  
    66              __m512 b_vec0 = _mm512_loadu_ps(b);
    67              __m512 b_vec1 = _mm512_loadu_ps(b + 16);
    68              __m512 b_vec2 = _mm512_loadu_ps(b + 32);
    69              __m512 b_vec3 = _mm512_loadu_ps(b + 48);
    70              __m512 b_vec4 = _mm512_loadu_ps(b + 64);
    71              __m512 b_vec5 = _mm512_loadu_ps(b + 80);
    72              __m512 b_vec6 = _mm512_loadu_ps(b + 96);
    73              __m512 b_vec7 = _mm512_loadu_ps(b + 112);
    74  
    75              acc5[0] = _mm512_fmadd_ps(a_vec0, b_vec0, acc5[0]);
    76              acc5[1] = _mm512_fmadd_ps(a_vec1, b_vec1, acc5[1]);
    77              acc5[2] = _mm512_fmadd_ps(a_vec2, b_vec2, acc5[2]);
    78              acc5[3] = _mm512_fmadd_ps(a_vec3, b_vec3, acc5[3]);
    79              acc5[4] = _mm512_fmadd_ps(a_vec4, b_vec4, acc5[4]);
    80              acc5[5] = _mm512_fmadd_ps(a_vec5, b_vec5, acc5[5]);
    81              acc5[6] = _mm512_fmadd_ps(a_vec6, b_vec6, acc5[6]);
    82              acc5[7] = _mm512_fmadd_ps(a_vec7, b_vec7, acc5[7]);
    83  
    84              n -= 128;
    85              a += 128;
    86              b += 128;
    87          } while (n >= 128);
    88  
    89          acc5[0] = _mm512_add_ps(acc5[1], acc5[0]);
    90          acc5[2] = _mm512_add_ps(acc5[3], acc5[2]);
    91          acc5[4] = _mm512_add_ps(acc5[5], acc5[4]);
    92          acc5[6] = _mm512_add_ps(acc5[7], acc5[6]);
    93          acc5[0] = _mm512_add_ps(acc5[2], acc5[0]);
    94          acc5[4] = _mm512_add_ps(acc5[6], acc5[4]);
    95          acc5[0] = _mm512_add_ps(acc5[4], acc5[0]);
    96  
    97          __m256 low = _mm512_castps512_ps256(acc5[0]);
    98          __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(acc5[0]), 1));
    99  
   100          acc[0] = _mm256_add_ps(low, acc[0]);
   101          acc[0] = _mm256_add_ps(high, acc[0]);
   102  
   103          if (!n)
   104          {
   105              // Reduce and store the result
   106              acc[0] = _mm256_add_ps(acc[1], acc[0]);
   107              acc[2] = _mm256_add_ps(acc[3], acc[2]);
   108              acc[0] = _mm256_add_ps(acc[2], acc[0]);
   109  
   110              __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
   111              __m256 t2 = _mm256_hadd_ps(t1, t1);
   112              __m128 t3 = _mm256_extractf128_ps(t2, 1);
   113              __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
   114              sum += _mm_cvtss_f32(t4);
   115  
   116              *res = sum;
   117              return;
   118          }
   119      }
   120  
   121      while (n >= 32)
   122      {
   123          // Unroll loop for 32 floats
   124          __m256 a_vec0 = _mm256_loadu_ps(a);
   125          __m256 a_vec1 = _mm256_loadu_ps(a + 8);
   126          __m256 a_vec2 = _mm256_loadu_ps(a + 16);
   127          __m256 a_vec3 = _mm256_loadu_ps(a + 24);
   128  
   129          __m256 b_vec0 = _mm256_loadu_ps(b);
   130          __m256 b_vec1 = _mm256_loadu_ps(b + 8);
   131          __m256 b_vec2 = _mm256_loadu_ps(b + 16);
   132          __m256 b_vec3 = _mm256_loadu_ps(b + 24);
   133  
   134          acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);
   135          acc[1] = _mm256_fmadd_ps(a_vec1, b_vec1, acc[1]);
   136          acc[2] = _mm256_fmadd_ps(a_vec2, b_vec2, acc[2]);
   137          acc[3] = _mm256_fmadd_ps(a_vec3, b_vec3, acc[3]);
   138  
   139          n -= 32;
   140          a += 32;
   141          b += 32;
   142      }
   143  
   144      // Process 8 floats at a time
   145      while (n >= 8)
   146      {
   147          __m256 a_vec0 = _mm256_loadu_ps(a);
   148          __m256 b_vec0 = _mm256_loadu_ps(b);
   149  
   150          acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);
   151  
   152          n -= 8;
   153          a += 8;
   154          b += 8;
   155      }
   156  
   157      // Tail
   158      while (n)
   159      {
   160          sum += a[0] * b[0];
   161          n--;
   162          a++;
   163          b++;
   164      }
   165  
   166      // Reduce and store the result
   167      acc[0] = _mm256_add_ps(acc[1], acc[0]);
   168      acc[2] = _mm256_add_ps(acc[3], acc[2]);
   169      acc[0] = _mm256_add_ps(acc[2], acc[0]);
   170      __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
   171      __m256 t2 = _mm256_hadd_ps(t1, t1);
   172      __m128 t3 = _mm256_extractf128_ps(t2, 1);
   173      __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
   174      sum += _mm_cvtss_f32(t4);
   175  
   176      *res = sum;
   177  }