github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/hnsw_stress_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hnsw 13 14 import ( 15 "context" 16 "encoding/binary" 17 "fmt" 18 "io" 19 "log" 20 "math" 21 "math/rand" 22 "os" 23 "sync" 24 "testing" 25 "time" 26 27 "github.com/sirupsen/logrus" 28 enterrors "github.com/weaviate/weaviate/entities/errors" 29 30 "github.com/pkg/errors" 31 "github.com/stretchr/testify/require" 32 ) 33 34 const ( 35 vectorSize = 128 36 vectorsPerGoroutine = 100 37 parallelGoroutines = 100 38 parallelSearchGoroutines = 8 39 ) 40 41 func idVector(ctx context.Context, id uint64) ([]float32, error) { 42 vector := make([]float32, vectorSize) 43 for i := 0; i < vectorSize; i++ { 44 vector[i] = float32(id) 45 } 46 return vector, nil 47 } 48 49 func idVectorSize(size int) func(ctx context.Context, id uint64) ([]float32, error) { 50 return func(ctx context.Context, id uint64) ([]float32, error) { 51 vector := make([]float32, size) 52 for i := 0; i < size; i++ { 53 vector[i] = float32(id) 54 } 55 return vector, nil 56 } 57 } 58 59 func float32FromBytes(bytes []byte) float32 { 60 bits := binary.LittleEndian.Uint32(bytes) 61 float := math.Float32frombits(bits) 62 return float 63 } 64 65 func int32FromBytes(bytes []byte) int { 66 return int(binary.LittleEndian.Uint32(bytes)) 67 } 68 69 func TestHnswStress(t *testing.T) { 70 siftFile := "datasets/ann-benchmarks/siftsmall/siftsmall_base.fvecs" 71 siftFileQuery := "datasets/ann-benchmarks/siftsmall/sift_query.fvecs" 72 _, err2 := os.Stat(siftFileQuery) 73 if _, err := os.Stat(siftFile); err != nil || err2 != nil { 74 if !*download { 75 t.Skip(`Sift data needs to be present. 76 Run test with -download to automatically download the dataset. 77 Ex: go test -v -run TestHnswStress . -download 78 `) 79 } 80 downloadDatasetFile(t, siftFile) 81 } 82 vectors := readSiftFloat(siftFile, parallelGoroutines*vectorsPerGoroutine) 83 vectorsQuery := readSiftFloat(siftFile, parallelGoroutines*vectorsPerGoroutine) 84 85 t.Run("Insert and search and maybe delete", func(t *testing.T) { 86 for n := 0; n < 1; n++ { // increase if you don't want to reread SIFT for every run 87 wg := sync.WaitGroup{} 88 index := createEmptyHnswIndexForTests(t, idVector) 89 for k := 0; k < parallelGoroutines; k++ { 90 wg.Add(2) 91 goroutineIndex := k * vectorsPerGoroutine 92 go func() { 93 for i := 0; i < vectorsPerGoroutine; i++ { 94 95 err := index.Add(uint64(goroutineIndex+i), vectors[goroutineIndex+i]) 96 require.Nil(t, err) 97 } 98 wg.Done() 99 }() 100 101 go func() { 102 for i := 0; i < vectorsPerGoroutine; i++ { 103 for j := 0; j < 5; j++ { // try a couple of times to delete if found 104 _, dists, err := index.SearchByVector(vectors[goroutineIndex+i], 0, nil) 105 require.Nil(t, err) 106 107 if len(dists) > 0 && dists[0] == 0 { 108 err := index.Delete(uint64(goroutineIndex + i)) 109 require.Nil(t, err) 110 break 111 } else { 112 continue 113 } 114 } 115 } 116 wg.Done() 117 }() 118 } 119 wg.Wait() 120 } 121 }) 122 123 t.Run("Insert and delete", func(t *testing.T) { 124 for i := 0; i < 1; i++ { // increase if you don't want to reread SIFT for every run 125 wg := sync.WaitGroup{} 126 index := createEmptyHnswIndexForTests(t, idVector) 127 for k := 0; k < parallelGoroutines; k++ { 128 wg.Add(1) 129 goroutineIndex := k * vectorsPerGoroutine 130 go func() { 131 for i := 0; i < vectorsPerGoroutine; i++ { 132 133 err := index.Add(uint64(goroutineIndex+i), vectors[goroutineIndex+i]) 134 require.Nil(t, err) 135 err = index.Delete(uint64(goroutineIndex + i)) 136 require.Nil(t, err) 137 138 } 139 wg.Done() 140 }() 141 142 } 143 wg.Wait() 144 145 } 146 }) 147 148 t.Run("Concurrent search", func(t *testing.T) { 149 index := createEmptyHnswIndexForTests(t, idVector) 150 // add elements 151 for k, vec := range vectors { 152 err := index.Add(uint64(k), vec) 153 require.Nil(t, err) 154 } 155 156 vectorsPerGoroutineSearch := len(vectorsQuery) / parallelSearchGoroutines 157 wg := sync.WaitGroup{} 158 159 for i := 0; i < 10; i++ { // increase if you don't want to reread SIFT for every run 160 for k := 0; k < parallelSearchGoroutines; k++ { 161 wg.Add(1) 162 k := k 163 go func() { 164 goroutineIndex := k * vectorsPerGoroutineSearch 165 for j := 0; j < vectorsPerGoroutineSearch; j++ { 166 _, _, err := index.SearchByVector(vectors[goroutineIndex+j], 0, nil) 167 require.Nil(t, err) 168 169 } 170 wg.Done() 171 }() 172 } 173 } 174 wg.Wait() 175 }) 176 177 t.Run("Concurrent deletes", func(t *testing.T) { 178 for i := 0; i < 10; i++ { // increase if you don't want to reread SIFT for every run 179 wg := sync.WaitGroup{} 180 181 index := createEmptyHnswIndexForTests(t, idVector) 182 deleteIds := make([]uint64, 50) 183 for j := 0; j < len(deleteIds); j++ { 184 err := index.Add(uint64(j), vectors[j]) 185 require.Nil(t, err) 186 deleteIds[j] = uint64(j) 187 } 188 wg.Add(2) 189 190 go func() { 191 err := index.Delete(deleteIds[25:]...) 192 require.Nil(t, err) 193 wg.Done() 194 }() 195 go func() { 196 err := index.Delete(deleteIds[:24]...) 197 require.Nil(t, err) 198 wg.Done() 199 }() 200 201 wg.Wait() 202 203 time.Sleep(time.Microsecond * 100) 204 index.Lock() 205 require.NotNil(t, index.nodes[24]) 206 index.Unlock() 207 208 } 209 }) 210 211 t.Run("Random operations", func(t *testing.T) { 212 for i := 0; i < 1; i++ { // increase if you don't want to reread SIFT for every run 213 index := createEmptyHnswIndexForTests(t, idVector) 214 215 var inserted struct { 216 sync.Mutex 217 ids []uint64 218 set map[uint64]struct{} 219 } 220 inserted.set = make(map[uint64]struct{}) 221 222 claimUnusedID := func() (uint64, bool) { 223 inserted.Lock() 224 defer inserted.Unlock() 225 226 if len(inserted.ids) == len(vectors) { 227 return 0, false 228 } 229 230 try := 0 231 for { 232 id := uint64(rand.Intn(len(vectors))) 233 if _, ok := inserted.set[id]; !ok { 234 inserted.ids = append(inserted.ids, id) 235 inserted.set[id] = struct{}{} 236 return id, true 237 } 238 239 try++ 240 if try > 50 { 241 log.Printf("[WARN] tried %d times, retrying...\n", try) 242 } 243 } 244 } 245 246 getInsertedIDs := func(n int) []uint64 { 247 inserted.Lock() 248 defer inserted.Unlock() 249 250 if len(inserted.ids) < n { 251 return nil 252 } 253 254 if n > len(inserted.ids) { 255 n = len(inserted.ids) 256 } 257 258 ids := make([]uint64, n) 259 copy(ids, inserted.ids[:n]) 260 261 return ids 262 } 263 264 removeInsertedIDs := func(ids ...uint64) { 265 inserted.Lock() 266 defer inserted.Unlock() 267 268 for _, id := range ids { 269 delete(inserted.set, id) 270 for i, insertedID := range inserted.ids { 271 if insertedID == id { 272 inserted.ids = append(inserted.ids[:i], inserted.ids[i+1:]...) 273 break 274 } 275 } 276 } 277 } 278 279 ops := []func(){ 280 // Add 281 func() { 282 id, ok := claimUnusedID() 283 if !ok { 284 return 285 } 286 287 err := index.Add(id, vectors[id]) 288 require.Nil(t, err) 289 }, 290 // Delete 291 func() { 292 // delete 5% of the time 293 if rand.Int31()%20 == 0 { 294 return 295 } 296 297 ids := getInsertedIDs(rand.Intn(100) + 1) 298 299 err := index.Delete(ids...) 300 require.Nil(t, err) 301 302 removeInsertedIDs(ids...) 303 }, 304 // Search 305 func() { 306 // search 50% of the time 307 if rand.Int31()%2 == 0 { 308 return 309 } 310 311 id := rand.Intn(len(vectors)) 312 313 _, _, err := index.SearchByVector(vectors[id], 0, nil) 314 require.Nil(t, err) 315 }, 316 } 317 318 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) 319 defer cancel() 320 321 g, ctx := enterrors.NewErrorGroupWithContextWrapper(logrus.New(), ctx) 322 323 // run parallelGoroutines goroutines 324 for i := 0; i < parallelGoroutines; i++ { 325 g.Go(func() error { 326 for { 327 select { 328 case <-ctx.Done(): 329 return ctx.Err() 330 default: 331 ops[rand.Intn(len(ops))]() 332 } 333 } 334 }) 335 } 336 337 g.Wait() 338 } 339 }) 340 } 341 342 func readSiftFloat(file string, maxObjects int) [][]float32 { 343 var vectors [][]float32 344 345 f, err := os.Open(file) 346 if err != nil { 347 panic(errors.Wrap(err, "Could not open SIFT file")) 348 } 349 defer f.Close() 350 351 fi, err := f.Stat() 352 if err != nil { 353 panic(errors.Wrap(err, "Could not get SIFT file properties")) 354 } 355 fileSize := fi.Size() 356 if fileSize < 1000000 { 357 panic("The file is only " + fmt.Sprint(fileSize) + " bytes long. Did you forgot to install git lfs?") 358 } 359 360 // The sift data is a binary file containing floating point vectors 361 // For each entry, the first 4 bytes is the length of the vector (in number of floats, not in bytes) 362 // which is followed by the vector data with vector length * 4 bytes. 363 // |-length-vec1 (4bytes)-|-Vec1-data-(4*length-vector-1 bytes)-|-length-vec2 (4bytes)-|-Vec2-data-(4*length-vector-2 bytes)-| 364 // The vector length needs to be converted from bytes to int 365 // The vector data needs to be converted from bytes to float 366 // Note that the vector entries are of type float but are integer numbers eg 2.0 367 bytesPerF := 4 368 vectorBytes := make([]byte, bytesPerF+vectorSize*bytesPerF) 369 for i := 0; i >= 0; i++ { 370 _, err = f.Read(vectorBytes) 371 if err == io.EOF { 372 break 373 } else if err != nil { 374 panic(err) 375 } 376 if int32FromBytes(vectorBytes[0:bytesPerF]) != vectorSize { 377 panic("Each vector must have 128 entries.") 378 } 379 vectorFloat := make([]float32, 0, vectorSize) 380 for j := 0; j < vectorSize; j++ { 381 start := (j + 1) * bytesPerF // first bytesPerF are length of vector 382 vectorFloat = append(vectorFloat, float32FromBytes(vectorBytes[start:start+bytesPerF])) 383 } 384 385 vectors = append(vectors, vectorFloat) 386 387 if i >= maxObjects { 388 break 389 } 390 } 391 if len(vectors) < maxObjects { 392 panic("Could not load all elements.") 393 } 394 395 return vectors 396 }