github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/asm/l2.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build ignore 13 // +build ignore 14 15 package main 16 17 import ( 18 . "github.com/mmcloughlin/avo/build" 19 . "github.com/mmcloughlin/avo/operand" 20 . "github.com/mmcloughlin/avo/reg" 21 ) 22 23 var unroll = 4 24 25 func main() { 26 TEXT("L2", NOSPLIT, "func(x, y []float32) float32") 27 x := Mem{Base: Load(Param("x").Base(), GP64())} 28 y := Mem{Base: Load(Param("y").Base(), GP64())} 29 n := Load(Param("x").Len(), GP64()) 30 31 acc := make([]VecVirtual, unroll) 32 diff := make([]VecVirtual, unroll) 33 for i := 0; i < unroll; i++ { 34 acc[i] = YMM() 35 diff[i] = YMM() 36 } 37 38 for i := 0; i < unroll; i++ { 39 VXORPS(acc[i], acc[i], acc[i]) 40 VXORPS(diff[i], diff[i], diff[i]) 41 } 42 43 blockitems := 8 * unroll 44 blocksize := 4 * blockitems 45 Label("blockloop") 46 CMPQ(n, U32(blockitems)) 47 JL(LabelRef("tail")) 48 49 // Load x. 50 xs := make([]VecVirtual, unroll) 51 for i := 0; i < unroll; i++ { 52 xs[i] = YMM() 53 } 54 55 for i := 0; i < unroll; i++ { 56 VMOVUPS(x.Offset(32*i), xs[i]) 57 } 58 59 for i := 0; i < unroll; i++ { 60 VSUBPS(y.Offset(32*i), xs[i], diff[i]) 61 } 62 63 for i := 0; i < unroll; i++ { 64 VFMADD231PS(diff[i], diff[i], acc[i]) 65 } 66 67 ADDQ(U32(blocksize), x.Base) 68 ADDQ(U32(blocksize), y.Base) 69 SUBQ(U32(blockitems), n) 70 JMP(LabelRef("blockloop")) 71 72 // Process any trailing entries. 73 Label("tail") 74 tail := XMM() 75 VXORPS(tail, tail, tail) 76 77 Label("tailloop") 78 CMPQ(n, U32(0)) 79 JE(LabelRef("reduce")) 80 81 xt := XMM() 82 VMOVSS(x, xt) 83 84 difft := XMM() 85 VSUBSS(y, xt, difft) 86 87 VFMADD231SS(difft, difft, tail) 88 89 ADDQ(U32(4), x.Base) 90 ADDQ(U32(4), y.Base) 91 DECQ(n) 92 JMP(LabelRef("tailloop")) 93 94 // Reduce the lanes to one. 95 Label("reduce") 96 if unroll != 4 { 97 // we have hard-coded the reduction for this specific unrolling as it 98 // allows us to do 0+1 and 2+3 and only then have a multiplication which 99 // touches both. 100 panic("addition is hard-coded") 101 } 102 103 // Manual reduction 104 VADDPS(acc[0], acc[1], acc[0]) 105 VADDPS(acc[2], acc[3], acc[2]) 106 VADDPS(acc[0], acc[2], acc[0]) 107 108 result := acc[0].AsX() 109 top := XMM() 110 VEXTRACTF128(U8(1), acc[0], top) 111 VADDPS(result, top, result) 112 VADDPS(result, tail, result) 113 VHADDPS(result, result, result) 114 VHADDPS(result, result, result) 115 Store(result, ReturnIndex(0)) 116 117 RET() 118 119 Generate() 120 }