github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/distancer/asm/l2_stub_arm64.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package asm
    13  
    14  import (
    15  	"reflect"
    16  	"unsafe"
    17  )
    18  
    19  // To generate the asm code, run:
    20  //   go install github.com/gorse-io/goat@v0.1.0
    21  //   go generate
    22  
    23  //go:generate goat ../c/l2_arm64.c -O3 -e="-mfpu=neon-fp-armv8" -e="-mfloat-abi=hard" -e="--target=arm64" -e="-march=armv8-a+simd+fp"
    24  
    25  // L2 calculates the L2 distance between two vectors
    26  // using SIMD instructions when possible.
    27  // Vector lengths < 16 are handled by the Go implementation
    28  // because the overhead of using reflection is too high.
    29  func L2(x []float32, y []float32) float32 {
    30  	switch len(x) {
    31  	case 2:
    32  		return l22(x, y)
    33  	case 4:
    34  		return l24(x, y)
    35  	case 6:
    36  		// manually inlined l26(x, y)
    37  		diff := x[5] - y[5]
    38  		sum := diff * diff
    39  
    40  		diff = x[4] - y[4]
    41  		sum += diff * diff
    42  
    43  		return l24(x, y) + sum
    44  	case 8:
    45  		// manually inlined l28(x, y)
    46  		diff := x[7] - y[7]
    47  		sum := diff * diff
    48  
    49  		diff = x[6] - y[6]
    50  		sum += diff * diff
    51  
    52  		diff = x[5] - y[5]
    53  		sum += diff * diff
    54  
    55  		diff = x[4] - y[4]
    56  		sum += diff * diff
    57  
    58  		return l24(x, y) + sum
    59  	case 10:
    60  		return l210(x, y)
    61  	case 12:
    62  		return l212(x, y)
    63  	}
    64  
    65  	// deal with odd lengths and lengths 13, 14, 15
    66  	if len(x) < 16 {
    67  		var sum float32
    68  
    69  		for i := range x {
    70  			diff := x[i] - y[i]
    71  			sum += diff * diff
    72  		}
    73  
    74  		return sum
    75  	}
    76  
    77  	var res float32
    78  
    79  	// The C function expects pointers to the underlying array, not slices.
    80  	hdrx := (*reflect.SliceHeader)(unsafe.Pointer(&x))
    81  	hdry := (*reflect.SliceHeader)(unsafe.Pointer(&y))
    82  
    83  	l := len(x)
    84  	l2(
    85  		// The slice header contains the address of the underlying array.
    86  		// We only need to cast it to a pointer.
    87  		unsafe.Pointer(hdrx.Data),
    88  		unsafe.Pointer(hdry.Data),
    89  		// The C function expects pointers to the result and the length of the arrays.
    90  		unsafe.Pointer(&res),
    91  		unsafe.Pointer(&l))
    92  
    93  	return res
    94  }