github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/testinghelpers/helpers.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package testinghelpers
    13  
    14  import (
    15  	"encoding/binary"
    16  	"encoding/gob"
    17  	"fmt"
    18  	"io"
    19  	"math"
    20  	"math/rand"
    21  	"os"
    22  	"sort"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/pkg/errors"
    27  	"github.com/sirupsen/logrus/hooks/test"
    28  	"github.com/stretchr/testify/require"
    29  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    30  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    31  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    32  	"github.com/weaviate/weaviate/entities/cyclemanager"
    33  )
    34  
    35  type DistanceFunction func([]float32, []float32) float32
    36  
    37  func getRandomSeed() *rand.Rand {
    38  	return rand.New(rand.NewSource(time.Now().UnixNano()))
    39  }
    40  
    41  func int32FromBytes(bytes []byte) int {
    42  	return int(binary.LittleEndian.Uint32(bytes))
    43  }
    44  
    45  func float32FromBytes(bytes []byte) float32 {
    46  	bits := binary.LittleEndian.Uint32(bytes)
    47  	float := math.Float32frombits(bits)
    48  	return float
    49  }
    50  
    51  func readSiftFloat(file string, maxObjects int, vectorLengthFloat int) [][]float32 {
    52  	f, err := os.Open(file)
    53  	if err != nil {
    54  		panic(errors.Wrap(err, "Could not open SIFT file"))
    55  	}
    56  	defer f.Close()
    57  
    58  	fi, err := f.Stat()
    59  	if err != nil {
    60  		panic(errors.Wrap(err, "Could not get SIFT file properties"))
    61  	}
    62  	fileSize := fi.Size()
    63  	if fileSize < 1000000 {
    64  		panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forgot to install git lfs?")
    65  	}
    66  
    67  	// The sift data is a binary file containing floating point vectors
    68  	// For each entry, the first 4 bytes is the length of the vector (in number of floats, not in bytes)
    69  	// which is followed by the vector data with vector length * 4 bytes.
    70  	// |-length-vec1 (4bytes)-|-Vec1-data-(4*length-vector-1 bytes)-|-length-vec2 (4bytes)-|-Vec2-data-(4*length-vector-2 bytes)-|
    71  	// The vector length needs to be converted from bytes to int
    72  	// The vector data needs to be converted from bytes to float
    73  	// Note that the vector entries are of type float but are integer numbers eg 2.0
    74  	bytesPerF := 4
    75  	objects := make([][]float32, maxObjects)
    76  	vectorBytes := make([]byte, bytesPerF+vectorLengthFloat*bytesPerF)
    77  	for i := 0; i >= 0; i++ {
    78  		_, err = f.Read(vectorBytes)
    79  		if err == io.EOF {
    80  			break
    81  		} else if err != nil {
    82  			panic(err)
    83  		}
    84  		if int32FromBytes(vectorBytes[0:bytesPerF]) != vectorLengthFloat {
    85  			panic("Each vector must have 128 entries.")
    86  		}
    87  		vectorFloat := []float32{}
    88  		for j := 0; j < vectorLengthFloat; j++ {
    89  			start := (j + 1) * bytesPerF // first bytesPerF are length of vector
    90  			vectorFloat = append(vectorFloat, float32FromBytes(vectorBytes[start:start+bytesPerF]))
    91  		}
    92  		objects[i] = vectorFloat
    93  
    94  		if i >= maxObjects-1 {
    95  			break
    96  		}
    97  	}
    98  
    99  	return objects
   100  }
   101  
   102  func ReadSiftVecsFrom(path string, size int, dimensions int) [][]float32 {
   103  	fmt.Printf("generating %d vectors...", size)
   104  	vectors := readSiftFloat(path, size, dimensions)
   105  	fmt.Printf(" done\n")
   106  	return vectors
   107  }
   108  
   109  func RandomVecs(size int, queriesSize int, dimensions int) ([][]float32, [][]float32) {
   110  	fmt.Printf("generating %d vectors...\n", size+queriesSize)
   111  	r := getRandomSeed()
   112  	vectors := make([][]float32, 0, size)
   113  	queries := make([][]float32, 0, queriesSize)
   114  	for i := 0; i < size; i++ {
   115  		vectors = append(vectors, genVector(r, dimensions))
   116  	}
   117  	for i := 0; i < queriesSize; i++ {
   118  		queries = append(queries, genVector(r, dimensions))
   119  	}
   120  	return vectors, queries
   121  }
   122  
   123  func genVector(r *rand.Rand, dimensions int) []float32 {
   124  	vector := make([]float32, 0, dimensions)
   125  	for i := 0; i < dimensions; i++ {
   126  		// Some distances like dot could produce negative values when the vectors have negative values
   127  		// This change will not affect anything when using a distance like l2, but will cover some bugs
   128  		// when using distances like dot
   129  		vector = append(vector, r.Float32()*2-1)
   130  	}
   131  	return vector
   132  }
   133  
   134  func Normalize(vectors [][]float32) {
   135  	for i := range vectors {
   136  		vectors[i] = distancer.Normalize(vectors[i])
   137  	}
   138  }
   139  
   140  func ReadVecs(size int, queriesSize int, dimensions int, db string, path ...string) ([][]float32, [][]float32) {
   141  	fmt.Printf("generating %d vectors...", size+queriesSize)
   142  	uri := db
   143  	if len(path) > 0 {
   144  		uri = fmt.Sprintf("%s/%s", path[0], uri)
   145  	}
   146  	vectors := readSiftFloat(fmt.Sprintf("%s/%s_base.fvecs", uri, db), size, dimensions)
   147  	queries := readSiftFloat(fmt.Sprintf("%s/%s_query.fvecs", uri, db), queriesSize, dimensions)
   148  	fmt.Printf(" done\n")
   149  	return vectors, queries
   150  }
   151  
   152  func ReadQueries(queriesSize int) [][]float32 {
   153  	fmt.Printf("generating %d vectors...", queriesSize)
   154  	queries := readSiftFloat("sift/sift_query.fvecs", queriesSize, 128)
   155  	fmt.Printf(" done\n")
   156  	return queries
   157  }
   158  
   159  func BruteForce(vectors [][]float32, query []float32, k int, distance DistanceFunction) ([]uint64, []float32) {
   160  	type distanceAndIndex struct {
   161  		distance float32
   162  		index    uint64
   163  	}
   164  
   165  	distances := make([]distanceAndIndex, len(vectors))
   166  
   167  	compressionhelpers.Concurrently(uint64(len(vectors)), func(i uint64) {
   168  		dist := distance(query, vectors[i])
   169  		distances[i] = distanceAndIndex{
   170  			index:    uint64(i),
   171  			distance: dist,
   172  		}
   173  	})
   174  
   175  	sort.Slice(distances, func(a, b int) bool {
   176  		return distances[a].distance < distances[b].distance
   177  	})
   178  
   179  	if len(distances) < k {
   180  		k = len(distances)
   181  	}
   182  
   183  	out := make([]uint64, k)
   184  	dists := make([]float32, k)
   185  	for i := 0; i < k; i++ {
   186  		out[i] = distances[i].index
   187  		dists[i] = distances[i].distance
   188  	}
   189  
   190  	return out, dists
   191  }
   192  
   193  func BuildTruths(queriesSize int, vectorsSize int, queries [][]float32, vectors [][]float32, k int, distance DistanceFunction, path ...string) [][]uint64 {
   194  	uri := "sift/sift_truths%d.%d.gob"
   195  	if len(path) > 0 {
   196  		uri = fmt.Sprintf("%s/%s", path[0], uri)
   197  	}
   198  	fileName := fmt.Sprintf(uri, k, vectorsSize)
   199  	truths := make([][]uint64, queriesSize)
   200  
   201  	if _, err := os.Stat(fileName); err == nil {
   202  		return loadTruths(fileName, queriesSize, k)
   203  	}
   204  
   205  	compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
   206  		truths[i], _ = BruteForce(vectors, queries[i], k, distance)
   207  	})
   208  
   209  	f, err := os.Create(fileName)
   210  	if err != nil {
   211  		panic(errors.Wrap(err, "Could not open file"))
   212  	}
   213  
   214  	defer f.Close()
   215  	enc := gob.NewEncoder(f)
   216  	err = enc.Encode(truths)
   217  	if err != nil {
   218  		panic(errors.Wrap(err, "Could not encode truths"))
   219  	}
   220  	return truths
   221  }
   222  
   223  func loadTruths(fileName string, queriesSize int, k int) [][]uint64 {
   224  	f, err := os.Open(fileName)
   225  	if err != nil {
   226  		panic(errors.Wrap(err, "Could not open truths file"))
   227  	}
   228  	defer f.Close()
   229  
   230  	truths := make([][]uint64, queriesSize)
   231  	cDec := gob.NewDecoder(f)
   232  	err = cDec.Decode(&truths)
   233  	if err != nil {
   234  		panic(errors.Wrap(err, "Could not decode truths"))
   235  	}
   236  	return truths
   237  }
   238  
   239  func MatchesInLists(control []uint64, results []uint64) uint64 {
   240  	desired := map[uint64]struct{}{}
   241  	for _, relevant := range control {
   242  		desired[relevant] = struct{}{}
   243  	}
   244  
   245  	var matches uint64
   246  	for _, candidate := range results {
   247  		_, ok := desired[candidate]
   248  		if ok {
   249  			matches++
   250  		}
   251  	}
   252  
   253  	return matches
   254  }
   255  
   256  func NewDummyStore(t testing.TB) *lsmkv.Store {
   257  	logger, _ := test.NewNullLogger()
   258  	storeDir := t.TempDir()
   259  	store, err := lsmkv.New(storeDir, storeDir, logger, nil,
   260  		cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop())
   261  	require.Nil(t, err)
   262  	return store
   263  }