github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/asm/dot.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build ignore 13 // +build ignore 14 15 package main 16 17 import ( 18 . "github.com/mmcloughlin/avo/build" 19 . "github.com/mmcloughlin/avo/operand" 20 . "github.com/mmcloughlin/avo/reg" 21 ) 22 23 var unroll = 4 24 25 // inspired by the avo example which is itself inspired by the PeachPy example. 26 // Adjusted unrolling that fits our use cases better. 27 func main() { 28 TEXT("Dot", NOSPLIT, "func(x, y []float32) float32") 29 x := Mem{Base: Load(Param("x").Base(), GP64())} 30 y := Mem{Base: Load(Param("y").Base(), GP64())} 31 n := Load(Param("x").Len(), GP64()) 32 33 acc := make([]VecVirtual, unroll) 34 for i := 0; i < unroll; i++ { 35 acc[i] = YMM() 36 } 37 38 for i := 0; i < unroll; i++ { 39 VXORPS(acc[i], acc[i], acc[i]) 40 } 41 42 blockitems := 8 * unroll 43 blocksize := 4 * blockitems 44 Label("blockloop") 45 CMPQ(n, U32(blockitems)) 46 JL(LabelRef("tail")) 47 48 // Load x. 49 xs := make([]VecVirtual, unroll) 50 for i := 0; i < unroll; i++ { 51 xs[i] = YMM() 52 } 53 54 for i := 0; i < unroll; i++ { 55 VMOVUPS(x.Offset(32*i), xs[i]) 56 } 57 58 // The actual FMA. 59 for i := 0; i < unroll; i++ { 60 VFMADD231PS(y.Offset(32*i), xs[i], acc[i]) 61 } 62 63 ADDQ(U32(blocksize), x.Base) 64 ADDQ(U32(blocksize), y.Base) 65 SUBQ(U32(blockitems), n) 66 JMP(LabelRef("blockloop")) 67 68 // Process any trailing entries. 69 Label("tail") 70 tail := XMM() 71 VXORPS(tail, tail, tail) 72 73 Label("tailloop") 74 CMPQ(n, U32(0)) 75 JE(LabelRef("reduce")) 76 77 xt := XMM() 78 VMOVSS(x, xt) 79 VFMADD231SS(y, xt, tail) 80 81 ADDQ(U32(4), x.Base) 82 ADDQ(U32(4), y.Base) 83 DECQ(n) 84 JMP(LabelRef("tailloop")) 85 86 // Reduce the lanes to one. 87 Label("reduce") 88 if unroll != 4 { 89 // we have hard-coded the reduction for this specific unrolling as it 90 // allows us to do 0+1 and 2+3 and only then have a multiplication which 91 // touches both. 92 panic("addition is hard-coded") 93 } 94 95 // Manual reduction 96 VADDPS(acc[0], acc[1], acc[0]) 97 VADDPS(acc[2], acc[3], acc[2]) 98 VADDPS(acc[0], acc[2], acc[0]) 99 100 result := acc[0].AsX() 101 top := XMM() 102 VEXTRACTF128(U8(1), acc[0], top) 103 VADDPS(result, top, result) 104 VADDPS(result, tail, result) 105 VHADDPS(result, result, result) 106 VHADDPS(result, result, result) 107 Store(result, ReturnIndex(0)) 108 109 RET() 110 111 Generate() 112 }