// Source: github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/c/dot_avx256_amd64.c
//                           _       _
// __      _____  __ ___   ___  __ _| |_ ___
// \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
//  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
//   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
//
//  Copyright © 2016 - 2023 Weaviate B.V. All rights reserved.
//
//  CONTACT: hello@weaviate.io
//

#include <immintrin.h>

// dot_256 computes the dot product of the float vectors a and b using
// 256-bit AVX loads and fused multiply-add.
//
//   a, b - input vectors, each holding at least *len floats; no alignment
//          is required (unaligned loads are used)
//   res  - output: receives the dot product
//   len  - pointer to the number of elements to process
//
// A length of zero (or a negative length) stores 0 in *res without reading
// from a or b.
void dot_256(float *a, float *b, float *res, long *len)
{
    // Keep the full width of *len; narrowing to int would truncate very
    // large lengths.
    long n = *len;
    float sum = 0;

    // fast path for small dimensions
    if (n < 8)
    {
        // A pre-checked loop: a do/while here would read a[0]/b[0] for an
        // empty vector and then never terminate once n went negative.
        while (n > 0)
        {
            sum += a[0] * b[0];
            n--;
            a++;
            b++;
        }

        *res = sum;
        return;
    }

    // Create 4 registers to store the results
    __m256 acc[4];
    acc[0] = _mm256_setzero_ps();
    acc[1] = _mm256_setzero_ps();
    acc[2] = _mm256_setzero_ps();
    acc[3] = _mm256_setzero_ps();

    while (n >= 32)
    {
        // Unroll loop for 32 floats; each accumulator handles 8 of them
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 a_vec1 = _mm256_loadu_ps(a + 8);
        __m256 a_vec2 = _mm256_loadu_ps(a + 16);
        __m256 a_vec3 = _mm256_loadu_ps(a + 24);

        __m256 b_vec0 = _mm256_loadu_ps(b);
        __m256 b_vec1 = _mm256_loadu_ps(b + 8);
        __m256 b_vec2 = _mm256_loadu_ps(b + 16);
        __m256 b_vec3 = _mm256_loadu_ps(b + 24);

        acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);
        acc[1] = _mm256_fmadd_ps(a_vec1, b_vec1, acc[1]);
        acc[2] = _mm256_fmadd_ps(a_vec2, b_vec2, acc[2]);
        acc[3] = _mm256_fmadd_ps(a_vec3, b_vec3, acc[3]);

        n -= 32;
        a += 32;
        b += 32;
    }

    // Process 8 floats at a time
    while (n >= 8)
    {
        __m256 a_vec0 = _mm256_loadu_ps(a);
        __m256 b_vec0 = _mm256_loadu_ps(b);

        acc[0] = _mm256_fmadd_ps(a_vec0, b_vec0, acc[0]);

        n -= 8;
        a += 8;
        b += 8;
    }

    // Tail: fewer than 8 elements remain, handled in scalar code
    while (n > 0)
    {
        sum += a[0] * b[0];
        n--;
        a++;
        b++;
    }

    // Reduce and store the result: first fold the four accumulators into one
    acc[0] = _mm256_add_ps(acc[1], acc[0]);
    acc[2] = _mm256_add_ps(acc[3], acc[2]);
    acc[0] = _mm256_add_ps(acc[2], acc[0]);

    // Two horizontal adds leave the total of each 128-bit half in that
    // half's lowest element
    __m256 t1 = _mm256_hadd_ps(acc[0], acc[0]);
    __m256 t2 = _mm256_hadd_ps(t1, t1);

    // Add the upper half to the lower half and take the lowest element
    __m128 t3 = _mm256_extractf128_ps(t2, 1);
    __m128 t4 = _mm_add_ps(_mm256_castps256_ps128(t2), t3);
    sum += _mm_cvtss_f32(t4);

    *res = sum;
}