github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/recall_geo_spatial_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build integrationTestSlow && !race 13 // +build integrationTestSlow,!race 14 15 package hnsw 16 17 import ( 18 "context" 19 "fmt" 20 "math/rand" 21 "runtime" 22 "sort" 23 "sync" 24 "testing" 25 "time" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 30 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 31 "github.com/weaviate/weaviate/entities/cyclemanager" 32 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 33 ) 34 35 func TestRecallGeo(t *testing.T) { 36 size := 10000 37 queries := 100 38 efConstruction := 128 39 maxNeighbors := 64 40 41 vectors := make([][]float32, size) 42 queryVectors := make([][]float32, queries) 43 var vectorIndex *hnsw 44 45 t.Run("generate random vectors", func(t *testing.T) { 46 fmt.Printf("generating %d vectors", size) 47 for i := 0; i < size; i++ { 48 lat, lon := randLatLon() 49 vectors[i] = []float32{lat, lon} 50 } 51 fmt.Printf("done\n") 52 53 fmt.Printf("generating %d search queries", queries) 54 for i := 0; i < queries; i++ { 55 lat, lon := randLatLon() 56 queryVectors[i] = []float32{lat, lon} 57 } 58 fmt.Printf("done\n") 59 }) 60 61 t.Run("importing into hnsw", func(t *testing.T) { 62 fmt.Printf("importing into hnsw\n") 63 index, err := New(Config{ 64 RootPath: "doesnt-matter-as-committlogger-is-mocked-out", 65 ID: "recallbenchmark", 66 MakeCommitLoggerThunk: MakeNoopCommitLogger, 67 DistanceProvider: distancer.NewGeoProvider(), 68 VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) { 69 return vectors[int(id)], nil 70 }, 71 }, ent.UserConfig{ 72 MaxConnections: maxNeighbors, 73 EFConstruction: efConstruction, 74 }, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), 75 cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t)) 76 77 require.Nil(t, err) 78 vectorIndex = index 79 80 workerCount := runtime.GOMAXPROCS(0) 81 jobsForWorker := make([][][]float32, workerCount) 82 83 for i, vec := range vectors { 84 workerID := i % workerCount 85 jobsForWorker[workerID] = append(jobsForWorker[workerID], vec) 86 } 87 88 beforeImport := time.Now() 89 wg := &sync.WaitGroup{} 90 for workerID, jobs := range jobsForWorker { 91 wg.Add(1) 92 go func(workerID int, myJobs [][]float32) { 93 defer wg.Done() 94 for i, vec := range myJobs { 95 originalIndex := (i * workerCount) + workerID 96 err := vectorIndex.Add(uint64(originalIndex), vec) 97 require.Nil(t, err) 98 } 99 }(workerID, jobs) 100 } 101 102 wg.Wait() 103 fmt.Printf("import took %s\n", time.Since(beforeImport)) 104 }) 105 106 t.Run("with k=10", func(t *testing.T) { 107 k := 10 108 109 var relevant int 110 var retrieved int 111 112 var times time.Duration 113 114 for i := 0; i < queries; i++ { 115 controlList := bruteForce(vectors, queryVectors[i], k) 116 before := time.Now() 117 results, _, err := vectorIndex.knnSearchByVector(queryVectors[i], k, 800, nil) 118 times += time.Since(before) 119 120 require.Nil(t, err) 121 122 retrieved += k 123 relevant += matchesInLists(controlList, results) 124 } 125 126 recall := float32(relevant) / float32(retrieved) 127 fmt.Printf("recall is %f\n", recall) 128 fmt.Printf("avg search time for k=%d is %s\n", k, times/time.Duration(queries)) 129 assert.True(t, recall >= 0.99) 130 }) 131 132 t.Run("with max dist set", func(t *testing.T) { 133 distances := []float32{ 134 0.1, 135 1, 136 10, 137 100, 138 1000, 139 2000, 140 5000, 141 7500, 142 10000, 143 12500, 144 15000, 145 20000, 146 35000, 147 100000, // larger than the circumference of the earth, should contain all 148 } 149 150 for _, maxDist := range distances { 151 t.Run(fmt.Sprintf("with maxDist=%f", maxDist), func(t *testing.T) { 152 var relevant int 153 var retrieved int 154 155 var times time.Duration 156 157 for i := 0; i < queries; i++ { 158 controlList := bruteForceMaxDist(vectors, queryVectors[i], maxDist) 159 before := time.Now() 160 results, err := vectorIndex.KnnSearchByVectorMaxDist(queryVectors[i], maxDist, 800, nil) 161 times += time.Since(before) 162 require.Nil(t, err) 163 164 retrieved += len(results) 165 relevant += matchesInLists(controlList, results) 166 } 167 168 if relevant == 0 { 169 // skip, as we risk dividing by zero, if both relevant and retrieved 170 // are zero, however, we want to fail with a divide-by-zero if only 171 // retrieved is 0 and relevant was more than 0 172 return 173 } 174 recall := float32(relevant) / float32(retrieved) 175 fmt.Printf("recall is %f\n", recall) 176 fmt.Printf("avg search time for maxDist=%f is %s\n", maxDist, times/time.Duration(queries)) 177 assert.True(t, recall >= 0.99) 178 }) 179 } 180 }) 181 } 182 183 func matchesInLists(control []uint64, results []uint64) int { 184 desired := map[uint64]struct{}{} 185 for _, relevant := range control { 186 desired[relevant] = struct{}{} 187 } 188 189 var matches int 190 for _, candidate := range results { 191 _, ok := desired[candidate] 192 if ok { 193 matches++ 194 } 195 } 196 197 return matches 198 } 199 200 func bruteForce(vectors [][]float32, query []float32, k int) []uint64 { 201 type distanceAndIndex struct { 202 distance float32 203 index uint64 204 } 205 206 distances := make([]distanceAndIndex, len(vectors)) 207 208 distancer := distancer.NewGeoProvider().New(query) 209 for i, vec := range vectors { 210 dist, _, _ := distancer.Distance(vec) 211 distances[i] = distanceAndIndex{ 212 index: uint64(i), 213 distance: dist, 214 } 215 } 216 217 sort.Slice(distances, func(a, b int) bool { 218 return distances[a].distance < distances[b].distance 219 }) 220 221 if len(distances) < k { 222 k = len(distances) 223 } 224 225 out := make([]uint64, k) 226 for i := 0; i < k; i++ { 227 out[i] = distances[i].index 228 } 229 230 return out 231 } 232 233 func bruteForceMaxDist(vectors [][]float32, query []float32, maxDist float32) []uint64 { 234 type distanceAndIndex struct { 235 distance float32 236 index uint64 237 } 238 239 distances := make([]distanceAndIndex, len(vectors)) 240 241 distancer := distancer.NewGeoProvider().New(query) 242 for i, vec := range vectors { 243 dist, _, _ := distancer.Distance(vec) 244 distances[i] = distanceAndIndex{ 245 index: uint64(i), 246 distance: dist, 247 } 248 } 249 250 sort.Slice(distances, func(a, b int) bool { 251 return distances[a].distance < distances[b].distance 252 }) 253 254 out := make([]uint64, len(distances)) 255 i := 0 256 for _, elem := range distances { 257 if elem.distance > maxDist { 258 break 259 } 260 out[i] = distances[i].index 261 i++ 262 } 263 264 return out[:i] 265 } 266 267 func randLatLon() (float32, float32) { 268 maxLat := float32(90.0) 269 minLat := float32(-90.0) 270 maxLon := float32(180) 271 minLon := float32(-180) 272 273 lat := minLat + (maxLat-minLat)*rand.Float32() 274 lon := minLon + (maxLon-minLon)*rand.Float32() 275 return lat, lon 276 }