github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/l2_avx256_amd64.c (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  #include <immintrin.h>
    13  
    14  void l2_256(float *a, float *b, float *res, long *len)
    15  {
    16      int n = *len;
    17      float sum = 0;
    18  
    19      // fast path for small dimensions
    20      if (n < 8)
    21      {
    22          do
    23          {
    24              float diff = a[0] - b[0];
    25              float sq = diff * diff;
    26              sum += sq;
    27              n--;
    28              a++;
    29              b++;
    30          } while (n);
    31  
    32          *res = sum;
    33          return;
    34      }
    35  
    36      // Create 4 registers to store the results
    37      __m256 acc[4];
    38      acc[0] = _mm256_setzero_ps();
    39      acc[1] = _mm256_setzero_ps();
    40      acc[2] = _mm256_setzero_ps();
    41      acc[3] = _mm256_setzero_ps();
    42  
    43      while (n >= 32)
    44      {
    45          // Unroll loop for 32 floats
    46          __m256 a_vec0 = _mm256_loadu_ps(a);
    47          __m256 a_vec1 = _mm256_loadu_ps(a + 8);
    48          __m256 a_vec2 = _mm256_loadu_ps(a + 16);
    49          __m256 a_vec3 = _mm256_loadu_ps(a + 24);
    50  
    51          __m256 b_vec0 = _mm256_loadu_ps(b);
    52          __m256 b_vec1 = _mm256_loadu_ps(b + 8);
    53          __m256 b_vec2 = _mm256_loadu_ps(b + 16);
    54          __m256 b_vec3 = _mm256_loadu_ps(b + 24);
    55  
    56          __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);
    57          __m256 diff1 = _mm256_sub_ps(a_vec1, b_vec1);
    58          __m256 diff2 = _mm256_sub_ps(a_vec2, b_vec2);
    59          __m256 diff3 = _mm256_sub_ps(a_vec3, b_vec3);
    60  
    61          acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);
    62          acc[1] = _mm256_fmadd_ps(diff1, diff1, acc[1]);
    63          acc[2] = _mm256_fmadd_ps(diff2, diff2, acc[2]);
    64          acc[3] = _mm256_fmadd_ps(diff3, diff3, acc[3]);
    65  
    66          n -= 32;
    67          a += 32;
    68          b += 32;
    69      }
    70  
    71      // Process 8 floats at a time
    72      while (n >= 8)
    73      {
    74          __m256 a_vec0 = _mm256_loadu_ps(a);
    75          __m256 b_vec0 = _mm256_loadu_ps(b);
    76          __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);
    77  
    78          acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);
    79  
    80          n -= 8;
    81          a += 8;
    82          b += 8;
    83      }
    84  
    85      // Tail
    86      while (n)
    87      {
    88          float diff = a[0] - b[0];
    89          float sq = diff * diff;
    90          sum += sq;
    91          n--;
    92          a++;
    93          b++;
    94      }
    95  
    96      // Reduce and store the result
    97      acc[0] = _mm256_add_ps(acc[1], acc[0]);
    98      acc[2] = _mm256_add_ps(acc[3], acc[2]);
    99      acc[0] = _mm256_add_ps(acc[2], acc[0]);
   100      __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
   101      __m256 t2 = _mm256_hadd_ps(t1, t1);
   102      __m128 t3 = _mm256_extractf128_ps(t2, 1);
   103      __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
   104      sum += _mm_cvtss_f32(t4);
   105  
   106      *res = sum;
   107  }