// Source: github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/dot_avx512_amd64.c
//
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

#include <immintrin.h>

// dot_512 computes sum(a[i] * b[i]) for i in [0, *len) and stores the
// single-precision result in *res.
//
// All arguments, including the length, are passed by pointer — presumably so
// the function can be translated into assembly callable from Go (the file
// lives under a Go package's c/ directory); TODO confirm against the build.
//
//   a, b  input vectors with at least *len readable floats each. Only
//         unaligned loads are used, so no particular alignment is required.
//   res   receives the dot product.
//   len   number of elements; a value <= 0 yields 0 and reads nothing.
//
// Strategy: 128 floats per iteration with eight 512-bit FMA accumulators
// (AVX-512F), then 32 and 8 floats per iteration with 256-bit FMA, then a
// scalar tail. Independent accumulators avoid serialising on one FMA result.
// The summation order differs from a naive loop, so the last bits of the
// result may differ from a sequential scalar sum.
void dot_512(float *a, float *b, float *res, long *len)
{
    // Keep the full width of *len; narrowing to int would truncate very
    // large lengths.
    long n = *len;
    float sum = 0;

    // Fast path for short vectors. This must be a pre-tested loop: a
    // do/while would read a[0]/b[0] for n == 0 and then run off the end of
    // both buffers because n would go negative.
    if (n < 8)
    {
        while (n > 0)
        {
            sum += a[0] * b[0];
            n--;
            a++;
            b++;
        }

        *res = sum;
        return;
    }

    // Four independent 256-bit accumulators for the AVX/FMA loops below.
    __m256 acc[4];
    acc[0] = _mm256_setzero_ps();
    acc[1] = _mm256_setzero_ps();
    acc[2] = _mm256_setzero_ps();
    acc[3] = _mm256_setzero_ps();

    if (n >= 128)
    {
        // Eight independent 512-bit accumulators.
        __m512 acc5[8];
        acc5[0] = _mm512_setzero_ps();
        acc5[1] = _mm512_setzero_ps();
        acc5[2] = _mm512_setzero_ps();
        acc5[3] = _mm512_setzero_ps();
        acc5[4] = _mm512_setzero_ps();
        acc5[5] = _mm512_setzero_ps();
        acc5[6] = _mm512_setzero_ps();
        acc5[7] = _mm512_setzero_ps();

        // Process 128 floats (8 vectors of 16) per iteration.
        do
        {
            __m512 a_vec0 = _mm512_loadu_ps(a);
            __m512 a_vec1 = _mm512_loadu_ps(a + 16);
            __m512 a_vec2 = _mm512_loadu_ps(a + 32);
            __m512 a_vec3 = _mm512_loadu_ps(a + 48);
            __m512 a_vec4 = _mm512_loadu_ps(a + 64);
            __m512 a_vec5 = _mm512_loadu_ps(a + 80);
            __m512 a_vec6 = _mm512_loadu_ps(a + 96);
            __m512 a_vec7 = _mm512_loadu_ps(a + 112);

            __m512 b_vec0 = _mm512_loadu_ps(b);
            __m512 b_vec1 = _mm512_loadu_ps(b + 16);
            __m512 b_vec2 = _mm512_loadu_ps(b + 32);
            __m512 b_vec3 = _mm512_loadu_ps(b + 48);
            __m512 b_vec4 = _mm512_loadu_ps(b + 64);
            __m512 b_vec5 = _mm512_loadu_ps(b + 80);
            __m512 b_vec6 = _mm512_loadu_ps(b + 96);
            __m512 b_vec7 = _mm512_loadu_ps(b + 112);

            acc5[0] = _mm512_fmadd_ps(a_vec0, b_vec0, acc5[0]);
            acc5[1] = _mm512_fmadd_ps(a_vec1, b_vec1, acc5[1]);
            acc5[2] = _mm512_fmadd_ps(a_vec2, b_vec2, acc5[2]);
            acc5[3] = _mm512_fmadd_ps(a_vec3, b_vec3, acc5[3]);
            acc5[4] = _mm512_fmadd_ps(a_vec4, b_vec4, acc5[4]);
            acc5[5] = _mm512_fmadd_ps(a_vec5, b_vec5, acc5[5]);
            acc5[6] = _mm512_fmadd_ps(a_vec6, b_vec6, acc5[6]);
            acc5[7] = _mm512_fmadd_ps(a_vec7, b_vec7, acc5[7]);

            n -= 128;
            a += 128;
            b += 128;
        } while (n >= 128);

        // Pairwise tree reduction of the eight accumulators into acc5[0].
        acc5[0] = _mm512_add_ps(acc5[1], acc5[0]);
        acc5[2] = _mm512_add_ps(acc5[3], acc5[2]);
        acc5[4] = _mm512_add_ps(acc5[5], acc5[4]);
        acc5[6] = _mm512_add_ps(acc5[7], acc5[6]);
        acc5[0] = _mm512_add_ps(acc5[2], acc5[0]);
        acc5[4] = _mm512_add_ps(acc5[6], acc5[4]);
        acc5[0] = _mm512_add_ps(acc5[4], acc5[0]);

        // Fold the 512-bit partial sum into acc[0] as two 256-bit halves.
        // The upper half is extracted through a double-precision view because
        // _mm512_extractf64x4_pd needs only AVX-512F, while
        // _mm512_extractf32x8_ps would additionally require AVX-512DQ.
        __m256 low = _mm512_castps512_ps256(acc5[0]);
        __m256 high = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(acc5[0]), 1));

        acc[0] = _mm256_add_ps(low, acc[0]);
        acc[0] = _mm256_add_ps(high, acc[0]);

        // If n is now 0, none of the loops below execute and control reaches
        // the common reduction directly, giving the same result as a
        // dedicated early return.
    }

    // Process 32 floats (4 vectors of 8) per iteration.
    while (n >= 32)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 a_vec1 = _mm256_loadu_ps(a + 8);
        __m256 a_vec2 = _mm256_loadu_ps(a + 16);
        __m256 a_vec3 = _mm256_loadu_ps(a + 24);

        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 b_vec1 = _mm256_loadu_ps(b + 8);
        __m256 b_vec2 = _mm256_loadu_ps(b + 16);
        __m256 b_vec3 = _mm256_loadu_ps(b + 24);

        acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);
        acc[1] = _mm256_fmadd_ps(a_vec1, b_vec1, acc[1]);
        acc[2] = _mm256_fmadd_ps(a_vec2, b_vec2, acc[2]);
        acc[3] = _mm256_fmadd_ps(a_vec3, b_vec3, acc[3]);

        n -= 32;
        a += 32;
        b += 32;
    }

    // Process 8 floats at a time.
    while (n >= 8)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 b_vec0 = _mm256_loadu_ps(b);

        acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);

        n -= 8;
        a += 8;
        b += 8;
    }

    // Scalar tail for the remaining 0..7 elements.
    while (n > 0)
    {
        sum += a[0] * b[0];
        n--;
        a++;
        b++;
    }

    // Reduce the four 256-bit accumulators to one, then sum its 8 lanes:
    // two hadds leave each 128-bit lane's total in every element of that
    // lane, and adding the two lanes yields the full sum in element 0.
    acc[0] = _mm256_add_ps(acc[1], acc[0]);
    acc[2] = _mm256_add_ps(acc[3], acc[2]);
    acc[0] = _mm256_add_ps(acc[2], acc[0]);
    __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
    __m256 t2 = _mm256_hadd_ps(t1, t1);
    __m128 t3 = _mm256_extractf128_ps(t2, 1);
    __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
    sum += _mm_cvtss_f32(t4);

    *res = sum;
}