github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/asm/dot.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build ignore
    13  // +build ignore
    14  
    15  package main
    16  
    17  import (
    18  	. "github.com/mmcloughlin/avo/build"
    19  	. "github.com/mmcloughlin/avo/operand"
    20  	. "github.com/mmcloughlin/avo/reg"
    21  )
    22  
    23  var unroll = 4
    24  
    25  // inspired by the avo example which is itself inspired by the PeachPy example.
    26  // Adjusted unrolling that fits our use cases better.
    27  func main() {
    28  	TEXT("Dot", NOSPLIT, "func(x, y []float32) float32")
    29  	x := Mem{Base: Load(Param("x").Base(), GP64())}
    30  	y := Mem{Base: Load(Param("y").Base(), GP64())}
    31  	n := Load(Param("x").Len(), GP64())
    32  
    33  	acc := make([]VecVirtual, unroll)
    34  	for i := 0; i < unroll; i++ {
    35  		acc[i] = YMM()
    36  	}
    37  
    38  	for i := 0; i < unroll; i++ {
    39  		VXORPS(acc[i], acc[i], acc[i])
    40  	}
    41  
    42  	blockitems := 8 * unroll
    43  	blocksize := 4 * blockitems
    44  	Label("blockloop")
    45  	CMPQ(n, U32(blockitems))
    46  	JL(LabelRef("tail"))
    47  
    48  	// Load x.
    49  	xs := make([]VecVirtual, unroll)
    50  	for i := 0; i < unroll; i++ {
    51  		xs[i] = YMM()
    52  	}
    53  
    54  	for i := 0; i < unroll; i++ {
    55  		VMOVUPS(x.Offset(32*i), xs[i])
    56  	}
    57  
    58  	// The actual FMA.
    59  	for i := 0; i < unroll; i++ {
    60  		VFMADD231PS(y.Offset(32*i), xs[i], acc[i])
    61  	}
    62  
    63  	ADDQ(U32(blocksize), x.Base)
    64  	ADDQ(U32(blocksize), y.Base)
    65  	SUBQ(U32(blockitems), n)
    66  	JMP(LabelRef("blockloop"))
    67  
    68  	// Process any trailing entries.
    69  	Label("tail")
    70  	tail := XMM()
    71  	VXORPS(tail, tail, tail)
    72  
    73  	Label("tailloop")
    74  	CMPQ(n, U32(0))
    75  	JE(LabelRef("reduce"))
    76  
    77  	xt := XMM()
    78  	VMOVSS(x, xt)
    79  	VFMADD231SS(y, xt, tail)
    80  
    81  	ADDQ(U32(4), x.Base)
    82  	ADDQ(U32(4), y.Base)
    83  	DECQ(n)
    84  	JMP(LabelRef("tailloop"))
    85  
    86  	// Reduce the lanes to one.
    87  	Label("reduce")
    88  	if unroll != 4 {
    89  		// we have hard-coded the reduction for this specific unrolling as it
    90  		// allows us to do 0+1 and 2+3 and only then have a multiplication which
    91  		// touches both.
    92  		panic("addition is hard-coded")
    93  	}
    94  
    95  	// Manual reduction
    96  	VADDPS(acc[0], acc[1], acc[0])
    97  	VADDPS(acc[2], acc[3], acc[2])
    98  	VADDPS(acc[0], acc[2], acc[0])
    99  
   100  	result := acc[0].AsX()
   101  	top := XMM()
   102  	VEXTRACTF128(U8(1), acc[0], top)
   103  	VADDPS(result, top, result)
   104  	VADDPS(result, tail, result)
   105  	VHADDPS(result, result, result)
   106  	VHADDPS(result, result, result)
   107  	Store(result, ReturnIndex(0))
   108  
   109  	RET()
   110  
   111  	Generate()
   112  }