github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/asm/l2.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build ignore
    13  // +build ignore
    14  
    15  package main
    16  
    17  import (
    18  	. "github.com/mmcloughlin/avo/build"
    19  	. "github.com/mmcloughlin/avo/operand"
    20  	. "github.com/mmcloughlin/avo/reg"
    21  )
    22  
    23  var unroll = 4
    24  
    25  func main() {
    26  	TEXT("L2", NOSPLIT, "func(x, y []float32) float32")
    27  	x := Mem{Base: Load(Param("x").Base(), GP64())}
    28  	y := Mem{Base: Load(Param("y").Base(), GP64())}
    29  	n := Load(Param("x").Len(), GP64())
    30  
    31  	acc := make([]VecVirtual, unroll)
    32  	diff := make([]VecVirtual, unroll)
    33  	for i := 0; i < unroll; i++ {
    34  		acc[i] = YMM()
    35  		diff[i] = YMM()
    36  	}
    37  
    38  	for i := 0; i < unroll; i++ {
    39  		VXORPS(acc[i], acc[i], acc[i])
    40  		VXORPS(diff[i], diff[i], diff[i])
    41  	}
    42  
    43  	blockitems := 8 * unroll
    44  	blocksize := 4 * blockitems
    45  	Label("blockloop")
    46  	CMPQ(n, U32(blockitems))
    47  	JL(LabelRef("tail"))
    48  
    49  	// Load x.
    50  	xs := make([]VecVirtual, unroll)
    51  	for i := 0; i < unroll; i++ {
    52  		xs[i] = YMM()
    53  	}
    54  
    55  	for i := 0; i < unroll; i++ {
    56  		VMOVUPS(x.Offset(32*i), xs[i])
    57  	}
    58  
    59  	for i := 0; i < unroll; i++ {
    60  		VSUBPS(y.Offset(32*i), xs[i], diff[i])
    61  	}
    62  
    63  	for i := 0; i < unroll; i++ {
    64  		VFMADD231PS(diff[i], diff[i], acc[i])
    65  	}
    66  
    67  	ADDQ(U32(blocksize), x.Base)
    68  	ADDQ(U32(blocksize), y.Base)
    69  	SUBQ(U32(blockitems), n)
    70  	JMP(LabelRef("blockloop"))
    71  
    72  	// Process any trailing entries.
    73  	Label("tail")
    74  	tail := XMM()
    75  	VXORPS(tail, tail, tail)
    76  
    77  	Label("tailloop")
    78  	CMPQ(n, U32(0))
    79  	JE(LabelRef("reduce"))
    80  
    81  	xt := XMM()
    82  	VMOVSS(x, xt)
    83  
    84  	difft := XMM()
    85  	VSUBSS(y, xt, difft)
    86  
    87  	VFMADD231SS(difft, difft, tail)
    88  
    89  	ADDQ(U32(4), x.Base)
    90  	ADDQ(U32(4), y.Base)
    91  	DECQ(n)
    92  	JMP(LabelRef("tailloop"))
    93  
    94  	// Reduce the lanes to one.
    95  	Label("reduce")
    96  	if unroll != 4 {
    97  		// we have hard-coded the reduction for this specific unrolling as it
    98  		// allows us to do 0+1 and 2+3 and only then have a multiplication which
    99  		// touches both.
   100  		panic("addition is hard-coded")
   101  	}
   102  
   103  	// Manual reduction
   104  	VADDPS(acc[0], acc[1], acc[0])
   105  	VADDPS(acc[2], acc[3], acc[2])
   106  	VADDPS(acc[0], acc[2], acc[0])
   107  
   108  	result := acc[0].AsX()
   109  	top := XMM()
   110  	VEXTRACTF128(U8(1), acc[0], top)
   111  	VADDPS(result, top, result)
   112  	VADDPS(result, tail, result)
   113  	VHADDPS(result, result, result)
   114  	VHADDPS(result, result, result)
   115  	Store(result, ReturnIndex(0))
   116  
   117  	RET()
   118  
   119  	Generate()
   120  }