github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/l2_avx512_amd64.c (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  #include <immintrin.h>
    13  
    /**
     * l2_512 computes the squared L2 distance between two float vectors, i.e.
     * the sum over i of (a[i] - b[i])^2. No square root is taken.
     *
     * Blocks of 128 floats are processed with AVX-512F intrinsics, blocks of
     * 32 and 8 floats with 256-bit AVX/FMA intrinsics, and any remainder with
     * scalar code.
     *
     * @param a    first input vector; *len floats are read (unaligned loads)
     * @param b    second input vector; *len floats are read (unaligned loads)
     * @param res  output location that receives the squared distance
     * @param len  pointer to the number of floats in each of a and b
     *
     * If *len is zero or negative, nothing is read from a or b and *res is
     * set to 0.
     */
    void l2_512(float *a, float *b, float *res, long *len)
    {
        // Keep the full width of *len instead of narrowing it to int.
        long n = *len;
        float sum = 0;

        // Nothing to compare. Without this guard the scalar fast path below
        // would be entered with n == 0 and read a[0]/b[0] out of bounds.
        if (n <= 0)
        {
            *res = sum;
            return;
        }

        // fast path for small dimensions: fewer than 8 floats cannot fill a
        // 256-bit register, so stay scalar
        if (n < 8)
        {
            while (n)
            {
                float diff = a[0] - b[0];
                float sq = diff * diff;
                sum += sq;
                n--;
                a++;
                b++;
            }

            *res = sum;
            return;
        }

        // Four independent 256-bit accumulators for the 32-float loop
        __m256 acc[4];
        acc[0] = _mm256_setzero_ps();
        acc[1] = _mm256_setzero_ps();
        acc[2] = _mm256_setzero_ps();
        acc[3] = _mm256_setzero_ps();

        if (n >= 128)
        {
            // Eight independent 512-bit accumulators (8 x 16 = 128 floats)
            __m512 acc5[8];
            acc5[0] = _mm512_setzero_ps();
            acc5[1] = _mm512_setzero_ps();
            acc5[2] = _mm512_setzero_ps();
            acc5[3] = _mm512_setzero_ps();
            acc5[4] = _mm512_setzero_ps();
            acc5[5] = _mm512_setzero_ps();
            acc5[6] = _mm512_setzero_ps();
            acc5[7] = _mm512_setzero_ps();

            // Process 128 floats at a time
            do
            {
                __m512 a_vec0 = _mm512_loadu_ps(a);
                __m512 a_vec1 = _mm512_loadu_ps(a + 16);
                __m512 a_vec2 = _mm512_loadu_ps(a + 32);
                __m512 a_vec3 = _mm512_loadu_ps(a + 48);
                __m512 a_vec4 = _mm512_loadu_ps(a + 64);
                __m512 a_vec5 = _mm512_loadu_ps(a + 80);
                __m512 a_vec6 = _mm512_loadu_ps(a + 96);
                __m512 a_vec7 = _mm512_loadu_ps(a + 112);

                __m512 b_vec0 = _mm512_loadu_ps(b);
                __m512 b_vec1 = _mm512_loadu_ps(b + 16);
                __m512 b_vec2 = _mm512_loadu_ps(b + 32);
                __m512 b_vec3 = _mm512_loadu_ps(b + 48);
                __m512 b_vec4 = _mm512_loadu_ps(b + 64);
                __m512 b_vec5 = _mm512_loadu_ps(b + 80);
                __m512 b_vec6 = _mm512_loadu_ps(b + 96);
                __m512 b_vec7 = _mm512_loadu_ps(b + 112);

                __m512 diff0 = _mm512_sub_ps(a_vec0, b_vec0);
                __m512 diff1 = _mm512_sub_ps(a_vec1, b_vec1);
                __m512 diff2 = _mm512_sub_ps(a_vec2, b_vec2);
                __m512 diff3 = _mm512_sub_ps(a_vec3, b_vec3);
                __m512 diff4 = _mm512_sub_ps(a_vec4, b_vec4);
                __m512 diff5 = _mm512_sub_ps(a_vec5, b_vec5);
                __m512 diff6 = _mm512_sub_ps(a_vec6, b_vec6);
                __m512 diff7 = _mm512_sub_ps(a_vec7, b_vec7);

                // acc += diff * diff
                acc5[0] = _mm512_fmadd_ps(diff0, diff0, acc5[0]);
                acc5[1] = _mm512_fmadd_ps(diff1, diff1, acc5[1]);
                acc5[2] = _mm512_fmadd_ps(diff2, diff2, acc5[2]);
                acc5[3] = _mm512_fmadd_ps(diff3, diff3, acc5[3]);
                acc5[4] = _mm512_fmadd_ps(diff4, diff4, acc5[4]);
                acc5[5] = _mm512_fmadd_ps(diff5, diff5, acc5[5]);
                acc5[6] = _mm512_fmadd_ps(diff6, diff6, acc5[6]);
                acc5[7] = _mm512_fmadd_ps(diff7, diff7, acc5[7]);

                n -= 128;
                a += 128;
                b += 128;
            } while (n >= 128);

            // Pairwise tree reduction of the eight accumulators into acc5[0]
            acc5[0] = _mm512_add_ps(acc5[1], acc5[0]);
            acc5[2] = _mm512_add_ps(acc5[3], acc5[2]);
            acc5[4] = _mm512_add_ps(acc5[5], acc5[4]);
            acc5[6] = _mm512_add_ps(acc5[7], acc5[6]);
            acc5[0] = _mm512_add_ps(acc5[2], acc5[0]);
            acc5[4] = _mm512_add_ps(acc5[6], acc5[4]);
            acc5[0] = _mm512_add_ps(acc5[4], acc5[0]);

            // Fold the 512-bit total into the first 256-bit accumulator. The
            // upper half is extracted through the double-precision view of
            // the register.
            __m256 low = _mm512_castps512_ps256(acc5[0]);
            __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(acc5[0]), 1));

            acc[0] = _mm256_add_ps(low, acc[0]);
            acc[0] = _mm256_add_ps(high, acc[0]);

            // When n is now 0 the loops below run zero times and control
            // reaches the shared reduction at the end of the function.
        }

        // Process 32 floats at a time
        while (n >= 32)
        {
            __m256 a_vec0 = _mm256_loadu_ps(a);
            __m256 a_vec1 = _mm256_loadu_ps(a + 8);
            __m256 a_vec2 = _mm256_loadu_ps(a + 16);
            __m256 a_vec3 = _mm256_loadu_ps(a + 24);

            __m256 b_vec0 = _mm256_loadu_ps(b);
            __m256 b_vec1 = _mm256_loadu_ps(b + 8);
            __m256 b_vec2 = _mm256_loadu_ps(b + 16);
            __m256 b_vec3 = _mm256_loadu_ps(b + 24);

            __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);
            __m256 diff1 = _mm256_sub_ps(a_vec1, b_vec1);
            __m256 diff2 = _mm256_sub_ps(a_vec2, b_vec2);
            __m256 diff3 = _mm256_sub_ps(a_vec3, b_vec3);

            acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);
            acc[1] = _mm256_fmadd_ps(diff1, diff1, acc[1]);
            acc[2] = _mm256_fmadd_ps(diff2, diff2, acc[2]);
            acc[3] = _mm256_fmadd_ps(diff3, diff3, acc[3]);

            n -= 32;
            a += 32;
            b += 32;
        }

        // Process 8 floats at a time
        while (n >= 8)
        {
            __m256 a_vec0 = _mm256_loadu_ps(a);
            __m256 b_vec0 = _mm256_loadu_ps(b);
            __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);

            acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);

            n -= 8;
            a += 8;
            b += 8;
        }

        // Tail: fewer than 8 floats remain, accumulate them in scalar code
        while (n)
        {
            float diff = a[0] - b[0];
            float sq = diff * diff;
            sum += sq;
            n--;
            a++;
            b++;
        }

        // Reduce the four 256-bit accumulators to one vector
        acc[0] = _mm256_add_ps(acc[1], acc[0]);
        acc[2] = _mm256_add_ps(acc[3], acc[2]);
        acc[0] = _mm256_add_ps(acc[2], acc[0]);

        // Two horizontal adds leave, in every element of each 128-bit half,
        // the sum of that half's four floats; adding the two halves and
        // taking element 0 gives the sum of all eight lanes.
        __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
        __m256 t2 = _mm256_hadd_ps(t1, t1);
        __m128 t3 = _mm256_extractf128_ps(t2, 1);
        __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
        sum += _mm_cvtss_f32(t4);

        *res = sum;
    }