github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/recall_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build benchmarkRecall 13 // +build benchmarkRecall 14 15 package hnsw 16 17 import ( 18 "context" 19 "encoding/json" 20 "fmt" 21 "io/ioutil" 22 "runtime" 23 "sync" 24 "testing" 25 "time" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 30 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 31 ) 32 33 func TestRecall(t *testing.T) { 34 efConstruction := 256 35 ef := 256 36 maxNeighbors := 64 37 38 var vectors [][]float32 39 var queries [][]float32 40 var truths [][]uint64 41 var vectorIndex *hnsw 42 43 t.Run("generate random vectors", func(t *testing.T) { 44 vectorsJSON, err := ioutil.ReadFile("recall_vectors.json") 45 require.Nil(t, err) 46 err = json.Unmarshal(vectorsJSON, &vectors) 47 require.Nil(t, err) 48 49 queriesJSON, err := ioutil.ReadFile("recall_queries.json") 50 require.Nil(t, err) 51 err = json.Unmarshal(queriesJSON, &queries) 52 require.Nil(t, err) 53 54 truthsJSON, err := ioutil.ReadFile("recall_truths.json") 55 require.Nil(t, err) 56 err = json.Unmarshal(truthsJSON, &truths) 57 require.Nil(t, err) 58 }) 59 60 t.Run("importing into hnsw", func(t *testing.T) { 61 fmt.Printf("importing into hnsw\n") 62 63 index, err := New(Config{ 64 RootPath: "doesnt-matter-as-committlogger-is-mocked-out", 65 ID: "recallbenchmark", 66 MakeCommitLoggerThunk: MakeNoopCommitLogger, 67 DistanceProvider: distancer.NewCosineDistanceProvider(), 68 VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) { 69 return vectors[int(id)], nil 70 }, 71 }, UserConfig{ 72 MaxConnections: maxNeighbors, 73 EFConstruction: efConstruction, 74 EF: ef, 75 }, testinghelpers.NewDummyStore(t)) 76 require.Nil(t, err) 77 vectorIndex = index 78 79 workerCount := runtime.GOMAXPROCS(0) 80 jobsForWorker := make([][][]float32, workerCount) 81 82 before := time.Now() 83 for i, vec := range vectors { 84 workerID := i % workerCount 85 jobsForWorker[workerID] = append(jobsForWorker[workerID], vec) 86 } 87 88 wg := &sync.WaitGroup{} 89 for workerID, jobs := range jobsForWorker { 90 wg.Add(1) 91 go func(workerID int, myJobs [][]float32) { 92 defer wg.Done() 93 for i, vec := range myJobs { 94 originalIndex := (i * workerCount) + workerID 95 err := vectorIndex.Add(uint64(originalIndex), vec) 96 require.Nil(t, err) 97 } 98 }(workerID, jobs) 99 } 100 101 wg.Wait() 102 fmt.Printf("importing took %s\n", time.Since(before)) 103 }) 104 105 t.Run("inspect a query", func(t *testing.T) { 106 k := 20 107 108 hasDuplicates := 0 109 110 for _, vec := range queries { 111 results, _, err := vectorIndex.SearchByVector(vec, k, nil) 112 require.Nil(t, err) 113 if containsDuplicates(results) { 114 hasDuplicates++ 115 panic("stop") 116 } 117 } 118 119 fmt.Printf("%d out of %d searches contained duplicates", hasDuplicates, len(queries)) 120 }) 121 122 t.Run("with k=10", func(t *testing.T) { 123 k := 10 124 125 var relevant int 126 var retrieved int 127 128 for i := 0; i < len(queries); i++ { 129 results, _, err := vectorIndex.SearchByVector(queries[i], k, nil) 130 require.Nil(t, err) 131 132 retrieved += k 133 relevant += matchesInLists(truths[i], results) 134 } 135 136 recall := float32(relevant) / float32(retrieved) 137 fmt.Printf("recall is %f\n", recall) 138 assert.True(t, recall >= 0.99) 139 }) 140 } 141 142 func matchesInLists(control []uint64, results []uint64) int { 143 desired := map[uint64]struct{}{} 144 for _, relevant := range control { 145 desired[relevant] = struct{}{} 146 } 147 148 var matches int 149 for _, candidate := range results { 150 _, ok := desired[candidate] 151 if ok { 152 matches++ 153 } 154 } 155 156 return matches 157 } 158 159 func containsDuplicates(in []uint64) bool { 160 seen := map[uint64]struct{}{} 161 162 for _, value := range in { 163 if _, ok := seen[value]; ok { 164 return true 165 } 166 seen[value] = struct{}{} 167 } 168 169 return false 170 }