// Source: github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/l2_avx256_amd64.c
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

#include <immintrin.h>

// Enable AVX2 + FMA code generation for this function on GCC/Clang even when
// the flags are not set globally; on other compilers the macro is a no-op.
#if defined(__GNUC__) || defined(__clang__)
#define L2_256_TARGET __attribute__((target("avx2,fma")))
#else
#define L2_256_TARGET
#endif

/**
 * Computes the squared L2 distance between two float vectors, i.e. the sum
 * over i of (a[i] - b[i])^2. No square root is taken.
 *
 * @param a   first vector; must hold at least *len floats (no alignment
 *            requirement, unaligned loads are used)
 * @param b   second vector; same requirements as a
 * @param res output; receives the accumulated sum of squared differences
 * @param len pointer to the number of elements to process; a value <= 0
 *            yields a result of 0 without touching a or b
 */
L2_256_TARGET
void l2_256(float *a, float *b, float *res, long *len)
{
    // Keep the full width of *len instead of narrowing it to int.
    long n = *len;
    float sum = 0;

    // Empty (or nonsensical negative) length: nothing to read, distance is 0.
    if (n <= 0)
    {
        *res = 0;
        return;
    }

    // Fast path for small dimensions: plain scalar accumulation.
    if (n < 8)
    {
        for (long i = 0; i < n; i++)
        {
            float diff = a[i] - b[i];
            sum += diff * diff;
        }

        *res = sum;
        return;
    }

    // Four independent accumulators so the unrolled FMAs do not depend on
    // each other.
    __m256 acc[4];
    acc[0] = _mm256_setzero_ps();
    acc[1] = _mm256_setzero_ps();
    acc[2] = _mm256_setzero_ps();
    acc[3] = _mm256_setzero_ps();

    // Main loop, unrolled to 32 floats (4 x 8 lanes) per iteration.
    while (n >= 32)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 a_vec1 = _mm256_loadu_ps(a + 8);
        __m256 a_vec2 = _mm256_loadu_ps(a + 16);
        __m256 a_vec3 = _mm256_loadu_ps(a + 24);

        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 b_vec1 = _mm256_loadu_ps(b + 8);
        __m256 b_vec2 = _mm256_loadu_ps(b + 16);
        __m256 b_vec3 = _mm256_loadu_ps(b + 24);

        __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);
        __m256 diff1 = _mm256_sub_ps(a_vec1, b_vec1);
        __m256 diff2 = _mm256_sub_ps(a_vec2, b_vec2);
        __m256 diff3 = _mm256_sub_ps(a_vec3, b_vec3);

        // acc += diff * diff
        acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);
        acc[1] = _mm256_fmadd_ps(diff1, diff1, acc[1]);
        acc[2] = _mm256_fmadd_ps(diff2, diff2, acc[2]);
        acc[3] = _mm256_fmadd_ps(diff3, diff3, acc[3]);

        n -= 32;
        a += 32;
        b += 32;
    }

    // Process the remaining full groups of 8 floats.
    while (n >= 8)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);

        acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);

        n -= 8;
        a += 8;
        b += 8;
    }

    // Scalar tail for the last (n < 8) elements.
    while (n > 0)
    {
        float diff = a[0] - b[0];
        sum += diff * diff;
        n--;
        a++;
        b++;
    }

    // Fold the four accumulators into one vector.
    acc[0] = _mm256_add_ps(acc[1], acc[0]);
    acc[2] = _mm256_add_ps(acc[3], acc[2]);
    acc[0] = _mm256_add_ps(acc[2], acc[0]);

    // Horizontal reduction: two hadd passes leave each 128-bit half holding
    // the sum of its own four lanes; adding the upper half to the lower half
    // then gives the total in element 0.
    __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
    __m256 t2 = _mm256_hadd_ps(t1, t1);
    __m128 t3 = _mm256_extractf128_ps(t2, 1);
    __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
    sum += _mm_cvtss_f32(t4);

    *res = sum;
}