//                           _       _
//  __      _____  __ ___   ___  __ _| |_ ___
//  \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//   \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//    \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

#include <immintrin.h>

// l2_512 computes the squared L2 (Euclidean) distance between two float
// vectors a and b of *len elements each and writes the result to *res.
// Results are written through a pointer (rather than returned) because this
// function is invoked from generated Go assembly trampolines.
//
// Strategy, chosen by dimension count n:
//   n < 8    : plain scalar loop
//   n >= 128 : 8-way unrolled AVX-512 loop, 128 floats per iteration
//   remainder: 4-way unrolled AVX2 loop (32 floats), then an 8-float AVX2
//              loop, then a scalar tail
//
// NOTE(review): a and b are assumed non-overlapping or identical-read-only;
// the function never writes through them, so aliasing is harmless.
void l2_512(float *a, float *b, float *res, long *len)
{
    int n = *len;
    float sum = 0;

    // Guard against empty (or negative) lengths. Without this, the scalar
    // do-while below would read a[0]/b[0] out of bounds and then spin on a
    // negative counter (signed underflow is undefined behavior).
    if (n <= 0)
    {
        *res = 0;
        return;
    }

    // fast path for small dimensions
    if (n < 8)
    {
        do
        {
            float diff = a[0] - b[0];
            float sq = diff * diff;
            sum += sq;
            n--;
            a++;
            b++;
        } while (n);

        *res = sum;
        return;
    }

    // Four independent 256-bit accumulators: breaks the FMA dependency chain
    // so consecutive fused multiply-adds can issue without stalling.
    __m256 acc[4];
    acc[0] = _mm256_setzero_ps();
    acc[1] = _mm256_setzero_ps();
    acc[2] = _mm256_setzero_ps();
    acc[3] = _mm256_setzero_ps();

    if (n >= 128)
    {
        // Eight independent 512-bit accumulators for the AVX-512 main loop.
        __m512 acc5[8];
        acc5[0] = _mm512_setzero_ps();
        acc5[1] = _mm512_setzero_ps();
        acc5[2] = _mm512_setzero_ps();
        acc5[3] = _mm512_setzero_ps();
        acc5[4] = _mm512_setzero_ps();
        acc5[5] = _mm512_setzero_ps();
        acc5[6] = _mm512_setzero_ps();
        acc5[7] = _mm512_setzero_ps();

        // Process 128 floats (8 x 16) at a time. Unaligned loads: the Go
        // runtime gives no alignment guarantee for slice backing arrays.
        do
        {
            __m512 a_vec0 = _mm512_loadu_ps(a);
            __m512 a_vec1 = _mm512_loadu_ps(a + 16);
            __m512 a_vec2 = _mm512_loadu_ps(a + 32);
            __m512 a_vec3 = _mm512_loadu_ps(a + 48);
            __m512 a_vec4 = _mm512_loadu_ps(a + 64);
            __m512 a_vec5 = _mm512_loadu_ps(a + 80);
            __m512 a_vec6 = _mm512_loadu_ps(a + 96);
            __m512 a_vec7 = _mm512_loadu_ps(a + 112);

            __m512 b_vec0 = _mm512_loadu_ps(b);
            __m512 b_vec1 = _mm512_loadu_ps(b + 16);
            __m512 b_vec2 = _mm512_loadu_ps(b + 32);
            __m512 b_vec3 = _mm512_loadu_ps(b + 48);
            __m512 b_vec4 = _mm512_loadu_ps(b + 64);
            __m512 b_vec5 = _mm512_loadu_ps(b + 80);
            __m512 b_vec6 = _mm512_loadu_ps(b + 96);
            __m512 b_vec7 = _mm512_loadu_ps(b + 112);

            __m512 diff0 = _mm512_sub_ps(a_vec0, b_vec0);
            __m512 diff1 = _mm512_sub_ps(a_vec1, b_vec1);
            __m512 diff2 = _mm512_sub_ps(a_vec2, b_vec2);
            __m512 diff3 = _mm512_sub_ps(a_vec3, b_vec3);
            __m512 diff4 = _mm512_sub_ps(a_vec4, b_vec4);
            __m512 diff5 = _mm512_sub_ps(a_vec5, b_vec5);
            __m512 diff6 = _mm512_sub_ps(a_vec6, b_vec6);
            __m512 diff7 = _mm512_sub_ps(a_vec7, b_vec7);

            // acc += diff * diff (one fused multiply-add per lane group)
            acc5[0] = _mm512_fmadd_ps(diff0, diff0, acc5[0]);
            acc5[1] = _mm512_fmadd_ps(diff1, diff1, acc5[1]);
            acc5[2] = _mm512_fmadd_ps(diff2, diff2, acc5[2]);
            acc5[3] = _mm512_fmadd_ps(diff3, diff3, acc5[3]);
            acc5[4] = _mm512_fmadd_ps(diff4, diff4, acc5[4]);
            acc5[5] = _mm512_fmadd_ps(diff5, diff5, acc5[5]);
            acc5[6] = _mm512_fmadd_ps(diff6, diff6, acc5[6]);
            acc5[7] = _mm512_fmadd_ps(diff7, diff7, acc5[7]);

            n -= 128;
            a += 128;
            b += 128;
        } while (n >= 128);

        // Tree-reduce the eight 512-bit accumulators into acc5[0].
        acc5[0] = _mm512_add_ps(acc5[1], acc5[0]);
        acc5[2] = _mm512_add_ps(acc5[3], acc5[2]);
        acc5[4] = _mm512_add_ps(acc5[5], acc5[4]);
        acc5[6] = _mm512_add_ps(acc5[7], acc5[6]);
        acc5[0] = _mm512_add_ps(acc5[2], acc5[0]);
        acc5[4] = _mm512_add_ps(acc5[6], acc5[4]);
        acc5[0] = _mm512_add_ps(acc5[4], acc5[0]);

        // Split the 512-bit sum into its low and high 256-bit halves. The
        // pd cast round-trip lets us use _mm512_extractf64x4_pd, which only
        // needs AVX-512F (the float32 extract would require AVX-512DQ).
        __m256 low = _mm512_castps512_ps256(acc5[0]);
        __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(acc5[0]), 1));

        acc[0] = _mm256_add_ps(low, acc[0]);
        acc[0] = _mm256_add_ps(high, acc[0]);

        if (!n)
        {
            // Dimension was a multiple of 128: reduce and store right away.
            acc[0] = _mm256_add_ps(acc[1], acc[0]);
            acc[2] = _mm256_add_ps(acc[3], acc[2]);
            acc[0] = _mm256_add_ps(acc[2], acc[0]);

            // Horizontal sum of the 8 lanes of acc[0].
            __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
            __m256 t2 = _mm256_hadd_ps(t1, t1);
            __m128 t3 = _mm256_extractf128_ps(t2, 1);
            __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
            sum += _mm_cvtss_f32(t4);

            *res = sum;
            return;
        }
    }

    // Unrolled AVX2 loop: 32 floats (4 x 8) per iteration.
    while (n >= 32)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 a_vec1 = _mm256_loadu_ps(a + 8);
        __m256 a_vec2 = _mm256_loadu_ps(a + 16);
        __m256 a_vec3 = _mm256_loadu_ps(a + 24);

        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 b_vec1 = _mm256_loadu_ps(b + 8);
        __m256 b_vec2 = _mm256_loadu_ps(b + 16);
        __m256 b_vec3 = _mm256_loadu_ps(b + 24);

        __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);
        __m256 diff1 = _mm256_sub_ps(a_vec1, b_vec1);
        __m256 diff2 = _mm256_sub_ps(a_vec2, b_vec2);
        __m256 diff3 = _mm256_sub_ps(a_vec3, b_vec3);

        acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);
        acc[1] = _mm256_fmadd_ps(diff1, diff1, acc[1]);
        acc[2] = _mm256_fmadd_ps(diff2, diff2, acc[2]);
        acc[3] = _mm256_fmadd_ps(diff3, diff3, acc[3]);

        n -= 32;
        a += 32;
        b += 32;
    }

    // Process 8 floats at a time
    while (n >= 8)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 diff0 = _mm256_sub_ps(a_vec0, b_vec0);

        acc[0] = _mm256_fmadd_ps(diff0, diff0, acc[0]);

        n -= 8;
        a += 8;
        b += 8;
    }

    // Scalar tail for the last n (< 8) elements.
    while (n)
    {
        float diff = a[0] - b[0];
        float sq = diff * diff;
        sum += sq;
        n--;
        a++;
        b++;
    }

    // Reduce the four 256-bit accumulators plus the scalar tail sum, and
    // store the result.
    acc[0] = _mm256_add_ps(acc[1], acc[0]);
    acc[2] = _mm256_add_ps(acc[3], acc[2]);
    acc[0] = _mm256_add_ps(acc[2], acc[0]);
    __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
    __m256 t2 = _mm256_hadd_ps(t1, t1);
    __m128 t3 = _mm256_extractf128_ps(t2, 1);
    __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
    sum += _mm_cvtss_f32(t4);

    *res = sum;
}