github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/index_queue_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package db 13 14 import ( 15 "context" 16 "fmt" 17 "math/rand" 18 "os" 19 "sort" 20 "sync" 21 "sync/atomic" 22 "testing" 23 "time" 24 25 "github.com/sirupsen/logrus" 26 "github.com/sirupsen/logrus/hooks/test" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 29 "github.com/weaviate/weaviate/adapters/repos/db/indexcheckpoint" 30 "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" 31 "github.com/weaviate/weaviate/adapters/repos/db/vector/common" 32 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw" 33 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 34 "github.com/weaviate/weaviate/entities/cyclemanager" 35 "github.com/weaviate/weaviate/entities/storagestate" 36 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 37 ) 38 39 func startWorker(t testing.TB, retryInterval ...time.Duration) chan job { 40 t.Helper() 41 ch := make(chan job) 42 t.Cleanup(func() { 43 close(ch) 44 }) 45 46 itv := time.Millisecond 47 if len(retryInterval) > 0 { 48 itv = retryInterval[0] 49 } 50 51 go func() { 52 logger := logrus.New() 53 logger.Level = logrus.ErrorLevel 54 asyncWorker(ch, logger, itv) 55 }() 56 57 return ch 58 } 59 60 func newCheckpointManager(t testing.TB) *indexcheckpoint.Checkpoints { 61 t.Helper() 62 63 return newCheckpointManagerWithDir(t, t.TempDir()) 64 } 65 66 func newCheckpointManagerWithDir(t testing.TB, dir string) *indexcheckpoint.Checkpoints { 67 t.Helper() 68 69 c, err := indexcheckpoint.New(dir, logrus.New()) 70 require.NoError(t, err) 71 72 return c 73 } 74 75 func pushVector(t testing.TB, ctx context.Context, q *IndexQueue, id uint64, vector []float32) { 76 err := q.Push(ctx, vectorDescriptor{ 77 id: id, 78 vector: vector, 79 }) 80 require.NoError(t, err) 81 } 82 83 func randVector(dim int) []float32 { 84 vec := make([]float32, dim) 85 for i := range vec { 86 vec[i] = rand.Float32() 87 } 88 89 return vec 90 } 91 92 func TestIndexQueue(t *testing.T) { 93 ctx, cancel := context.WithCancel(context.Background()) 94 t.Cleanup(cancel) 95 96 os.Setenv("ASYNC_INDEXING", "true") 97 defer os.Unsetenv("ASYNC_INDEXING") 98 99 writeIDs := func(q *IndexQueue, from, to uint64) { 100 vectors := make([]vectorDescriptor, 0, to-from) 101 for i := from; i < to; i++ { 102 vectors = append(vectors, vectorDescriptor{ 103 id: i, 104 vector: []float32{1, 2, 3}, 105 }) 106 } 107 err := q.Push(ctx, vectors...) 108 require.NoError(t, err) 109 } 110 111 getLastUpdate := func(q *IndexQueue) time.Time { 112 fi, err := os.Stat(q.checkpoints.Filename()) 113 require.NoError(t, err) 114 return fi.ModTime() 115 } 116 117 waitForUpdate := func(q *IndexQueue) func(timeout ...time.Duration) bool { 118 lastUpdate := getLastUpdate(q) 119 120 return func(timeout ...time.Duration) bool { 121 start := time.Now() 122 123 if len(timeout) == 0 { 124 timeout = []time.Duration{500 * time.Millisecond} 125 } 126 for { 127 cur := getLastUpdate(q) 128 if cur.Equal(lastUpdate) { 129 if time.Since(start) > timeout[0] { 130 return false 131 } 132 time.Sleep(5 * time.Millisecond) 133 continue 134 } 135 136 lastUpdate = cur 137 return true 138 } 139 } 140 } 141 142 t.Run("pushes to indexer if batch is full", func(t *testing.T) { 143 var idx mockBatchIndexer 144 idsCh := make(chan []uint64, 1) 145 idx.addBatchFn = func(ids []uint64, vector [][]float32) error { 146 idsCh <- ids 147 return nil 148 } 149 150 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 151 BatchSize: 2, 152 }) 153 require.NoError(t, err) 154 defer q.Close() 155 156 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 157 select { 158 case <-idsCh: 159 t.Fatal("should not have been called") 160 case <-time.After(100 * time.Millisecond): 161 } 162 163 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 164 ids := <-idsCh 165 166 require.Equal(t, []uint64{1, 2}, ids) 167 }) 168 169 t.Run("doesn't index if batch is not null", func(t *testing.T) { 170 var idx mockBatchIndexer 171 called := make(chan struct{}) 172 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 173 called <- struct{}{} 174 return nil 175 } 176 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 177 BatchSize: 100, 178 IndexInterval: time.Microsecond, 179 }) 180 require.NoError(t, err) 181 defer q.Close() 182 183 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 184 select { 185 case <-called: 186 t.Fatal("should not have been called") 187 case <-time.After(100 * time.Millisecond): 188 } 189 190 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 191 192 select { 193 case <-called: 194 t.Fatal("should not have been called") 195 case <-time.After(100 * time.Millisecond): 196 } 197 }) 198 199 t.Run("retry on indexing error", func(t *testing.T) { 200 var idx mockBatchIndexer 201 i := int32(0) 202 called := make(chan struct{}) 203 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 204 if atomic.AddInt32(&i, 1) < 3 { 205 return fmt.Errorf("indexing error: %d", i) 206 } 207 208 close(called) 209 210 return nil 211 } 212 213 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 214 BatchSize: 1, 215 }) 216 require.NoError(t, err) 217 defer q.Close() 218 219 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 220 <-called 221 }) 222 223 t.Run("merges results from queries", func(t *testing.T) { 224 var idx mockBatchIndexer 225 called := make(chan struct{}) 226 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 227 close(called) 228 return nil 229 } 230 231 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 232 BatchSize: 3, 233 IndexInterval: 100 * time.Millisecond, 234 }) 235 require.NoError(t, err) 236 defer q.Close() 237 238 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 239 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 240 pushVector(t, ctx, q, 3, []float32{7, 8, 9}) 241 pushVector(t, ctx, q, 4, []float32{1, 2, 3}) 242 243 <-called 244 245 time.Sleep(500 * time.Millisecond) 246 res, _, err := q.SearchByVector([]float32{1, 2, 3}, 2, nil) 247 require.NoError(t, err) 248 require.Equal(t, []uint64{1, 4}, res) 249 }) 250 251 t.Run("search with empty index", func(t *testing.T) { 252 var idx mockBatchIndexer 253 254 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 255 BatchSize: 6, 256 }) 257 require.NoError(t, err) 258 defer q.Close() 259 260 for i := 0; i < 10; i++ { 261 pushVector(t, ctx, q, uint64(i+1), []float32{float32(i) + 1, float32(i) + 2, float32(i) + 3}) 262 } 263 264 res, _, err := q.SearchByVector([]float32{1, 2, 3}, 2, nil) 265 require.NoError(t, err) 266 require.Equal(t, []uint64{1, 2}, res) 267 }) 268 269 t.Run("queue size", func(t *testing.T) { 270 var idx mockBatchIndexer 271 closeCh := make(chan struct{}) 272 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 273 <-closeCh 274 return nil 275 } 276 277 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 278 BatchSize: 5, 279 }) 280 require.NoError(t, err) 281 defer q.Close() 282 283 for i := uint64(0); i < 101; i++ { 284 pushVector(t, ctx, q, i+1, []float32{1, 2, 3}) 285 } 286 287 time.Sleep(100 * time.Millisecond) 288 require.EqualValues(t, 101, q.Size()) 289 close(closeCh) 290 }) 291 292 t.Run("deletion", func(t *testing.T) { 293 var idx mockBatchIndexer 294 var count int32 295 indexingDone := make(chan struct{}) 296 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 297 if atomic.AddInt32(&count, 1) == 5 { 298 close(indexingDone) 299 } 300 301 return nil 302 } 303 304 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 305 BatchSize: 4, 306 IndexInterval: 100 * time.Millisecond, 307 }) 308 require.NoError(t, err) 309 defer q.Close() 310 311 for i := uint64(0); i < 20; i++ { 312 pushVector(t, ctx, q, i, []float32{1, 2, 3}) 313 } 314 315 err = q.Delete(5, 10, 15) 316 require.NoError(t, err) 317 318 wait := waitForUpdate(q) 319 <-indexingDone 320 321 // wait for the checkpoint file to be written to disk 322 wait() 323 324 // check what has been indexed 325 require.Equal(t, []uint64{0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19}, idx.IDs()) 326 327 // the "deleted" mask should be empty 328 q.queue.deleted.Lock() 329 require.Empty(t, q.queue.deleted.m) 330 q.queue.deleted.Unlock() 331 332 // now delete something that's already indexed 333 err = q.Delete(0, 4, 8) 334 require.NoError(t, err) 335 336 // the "deleted" mask should still be empty 337 q.queue.deleted.Lock() 338 require.Empty(t, q.queue.deleted.m) 339 q.queue.deleted.Unlock() 340 341 // check what's in the index 342 require.Equal(t, []uint64{1, 2, 3, 6, 7, 9, 11, 12, 13, 14, 16, 17, 18, 19}, idx.IDs()) 343 344 // delete something that's not indexed yet 345 err = q.Delete(20, 21, 22) 346 require.NoError(t, err) 347 348 // the "deleted" mask should contain the deleted ids 349 q.queue.deleted.Lock() 350 var ids []int 351 for id := range q.queue.deleted.m { 352 ids = append(ids, int(id)) 353 } 354 q.queue.deleted.Unlock() 355 sort.Ints(ids) 356 require.Equal(t, []int{20, 21, 22}, ids) 357 }) 358 359 t.Run("brute force upper limit", func(t *testing.T) { 360 var idx mockBatchIndexer 361 362 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 363 BatchSize: 1000, 364 BruteForceSearchLimit: 2, 365 }) 366 require.NoError(t, err) 367 defer q.Close() 368 369 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 370 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 371 pushVector(t, ctx, q, 3, []float32{7, 8, 9}) 372 pushVector(t, ctx, q, 4, []float32{1, 2, 3}) 373 374 res, _, err := q.SearchByVector([]float32{7, 8, 9}, 2, nil) 375 require.NoError(t, err) 376 // despite having 4 vectors in the queue 377 // only the first two are used for brute force search 378 require.Equal(t, []uint64{2, 1}, res) 379 }) 380 381 t.Run("stores a safe checkpoint", func(t *testing.T) { 382 var idx mockBatchIndexer 383 384 dir := t.TempDir() 385 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManagerWithDir(t, dir), IndexQueueOptions{ 386 BatchSize: 5, 387 IndexInterval: time.Hour, 388 }) 389 require.NoError(t, err) 390 defer q.Close() 391 392 wait := waitForUpdate(q) 393 writeIDs(q, 5, 7) // [5, 6] 394 writeIDs(q, 9, 13) // [5, 6, 9, 10, 11], [12] 395 writeIDs(q, 0, 5) // [5, 6, 9, 10, 11], [12, 0, 1, 2, 3], [4] 396 time.Sleep(100 * time.Millisecond) 397 before, exists, err := q.checkpoints.Get("1", "") 398 require.NoError(t, err) 399 require.False(t, exists) 400 q.pushToWorkers(-1, false) 401 // the checkpoint should be: 0, then 0 402 // the cursor should not be updated 403 wait(100 * time.Millisecond) 404 after, exists, err := q.checkpoints.Get("1", "") 405 require.NoError(t, err) 406 require.True(t, exists) 407 require.Equal(t, before, after) 408 409 writeIDs(q, 15, 25) // [4, 15, 16, 17, 18], [19, 20, 21, 22, 23], [24] 410 writeIDs(q, 30, 40) // [4, 15, 16, 17, 18], [19, 20, 21, 22, 23], [24, 30, 31, 32, 33], [34, 35, 36, 37, 38], [39] 411 time.Sleep(100 * time.Millisecond) 412 // the checkpoint should be: 0, then 4, then 14, then 29 413 q.pushToWorkers(-1, false) 414 // 0 415 wait() 416 // 4 417 wait() 418 // 14 419 wait() 420 // 29 421 wait() 422 v, exists, err := q.checkpoints.Get("1", "") 423 require.NoError(t, err) 424 require.True(t, exists) 425 require.Equal(t, 29, int(v)) 426 }) 427 428 t.Run("stale vectors", func(t *testing.T) { 429 var idx mockBatchIndexer 430 closeCh := make(chan struct{}) 431 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 432 close(closeCh) 433 return nil 434 } 435 436 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 437 BatchSize: 5, 438 StaleTimeout: 100 * time.Millisecond, 439 IndexInterval: 10 * time.Millisecond, 440 }) 441 require.NoError(t, err) 442 defer q.Close() 443 444 for i := uint64(0); i < 3; i++ { 445 pushVector(t, ctx, q, i+1, []float32{1, 2, 3}) 446 } 447 448 select { 449 case <-closeCh: 450 case <-time.After(500 * time.Millisecond): 451 t.Fatal("should have been indexed after 100ms") 452 } 453 454 require.EqualValues(t, []uint64{1, 2, 3}, idx.IDs()) 455 }) 456 457 t.Run("updates the shard state", func(t *testing.T) { 458 var idx mockBatchIndexer 459 indexed := make(chan struct{}) 460 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 461 close(indexed) 462 return nil 463 } 464 465 updated := make(chan string) 466 shard := mockShard{ 467 compareAndSwapStatusFn: func(old, new string) (storagestate.Status, error) { 468 updated <- new 469 return storagestate.Status(new), nil 470 }, 471 } 472 473 q, err := NewIndexQueue("1", "", &shard, &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 474 BatchSize: 2, 475 IndexInterval: 100 * time.Millisecond, 476 }) 477 require.NoError(t, err) 478 defer q.Close() 479 480 for i := uint64(0); i < 2; i++ { 481 pushVector(t, ctx, q, i+1, []float32{1, 2, 3}) 482 } 483 484 select { 485 case newState := <-updated: 486 require.Equal(t, storagestate.StatusIndexing.String(), newState) 487 case <-time.After(200 * time.Millisecond): 488 t.Fatal("shard state should have been updated after 100ms") 489 } 490 491 select { 492 case <-indexed: 493 case <-time.After(200 * time.Millisecond): 494 t.Fatal("should have been indexed after 100ms") 495 } 496 497 select { 498 case newState := <-updated: 499 require.Equal(t, storagestate.StatusReady.String(), newState) 500 case <-time.After(200 * time.Millisecond): 501 t.Fatal("shard state should have been updated after 100ms") 502 } 503 }) 504 505 t.Run("close waits for indexing to be done", func(t *testing.T) { 506 var idx mockBatchIndexer 507 var count int 508 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 509 <-time.After(10 * time.Millisecond) 510 count++ 511 return nil 512 } 513 514 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 515 BatchSize: 5, 516 }) 517 require.NoError(t, err) 518 defer q.Close() 519 520 for i := uint64(0); i < 100; i++ { 521 pushVector(t, ctx, q, i+1, []float32{1, 2, 3}) 522 } 523 524 q.pushToWorkers(-1, false) 525 q.Close() 526 527 require.EqualValues(t, 20, count) 528 }) 529 530 t.Run("cos: normalized the query vector", func(t *testing.T) { 531 var idx mockBatchIndexer 532 idx.distancerProvider = distancer.NewCosineDistanceProvider() 533 534 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 535 BatchSize: 7, // 300 is not divisible by 7 536 IndexInterval: 100 * time.Second, 537 }) 538 require.NoError(t, err) 539 defer q.Close() 540 541 for i := uint64(0); i < 300; i++ { 542 pushVector(t, ctx, q, i+1, randVector(1536)) 543 } 544 545 q.pushToWorkers(-1, false) 546 547 _, distances, err := q.SearchByVector(randVector(1536), 10, nil) 548 require.NoError(t, err) 549 550 // all distances should be between 0 and 1 551 for _, dist := range distances { 552 require.True(t, dist >= 0 && dist <= 1) 553 } 554 }) 555 556 t.Run("pause/resume indexing", func(t *testing.T) { 557 var idx mockBatchIndexer 558 called := make(chan struct{}) 559 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 560 called <- struct{}{} 561 // simulate work 562 <-time.After(100 * time.Millisecond) 563 return nil 564 } 565 566 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 567 BatchSize: 2, 568 IndexInterval: 10 * time.Millisecond, 569 }) 570 require.NoError(t, err) 571 defer q.Close() 572 573 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 574 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 575 576 // batch indexed 577 <-called 578 579 // pause indexing: this will block until the batch is indexed 580 q.pauseIndexing() 581 582 // add more vectors 583 pushVector(t, ctx, q, 3, []float32{7, 8, 9}) 584 pushVector(t, ctx, q, 4, []float32{1, 2, 3}) 585 586 // wait enough time to make sure the indexing is not happening 587 <-time.After(200 * time.Millisecond) 588 589 select { 590 case <-called: 591 t.Fatal("should not have been called") 592 default: 593 } 594 595 // resume indexing 596 q.resumeIndexing() 597 598 // wait for the indexing to be done 599 <-called 600 }) 601 602 t.Run("compression", func(t *testing.T) { 603 var idx mockBatchIndexer 604 called := make(chan struct{}) 605 idx.shouldCompress = true 606 idx.threshold = 4 607 idx.alreadyIndexed = 6 608 609 release := make(chan struct{}) 610 idx.onCompressionTurnedOn = func(callback func()) error { 611 go func() { 612 <-release 613 callback() 614 }() 615 616 close(called) 617 return nil 618 } 619 620 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 621 BatchSize: 2, 622 IndexInterval: 10 * time.Millisecond, 623 }) 624 require.NoError(t, err) 625 defer q.Close() 626 627 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 628 pushVector(t, ctx, q, 2, []float32{4, 5, 6}) 629 630 // compression requested 631 <-called 632 633 // indexing should be paused 634 require.True(t, q.paused.Load()) 635 636 // release the compression 637 idx.compressed = true 638 close(release) 639 640 // indexing should be resumed eventually 641 time.Sleep(100 * time.Millisecond) 642 require.False(t, q.paused.Load()) 643 644 indexed := make(chan struct{}) 645 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 646 close(indexed) 647 return nil 648 } 649 650 // add more vectors 651 pushVector(t, ctx, q, 3, []float32{7, 8, 9}) 652 pushVector(t, ctx, q, 4, []float32{1, 2, 3}) 653 654 // indexing should happen 655 <-indexed 656 }) 657 658 t.Run("compression does not occur at the indexing if async is enabled", func(t *testing.T) { 659 vectors := [][]float32{{0, 1, 3, 4, 5, 6}, {0, 1, 3, 4, 5, 6}, {0, 1, 3, 4, 5, 6}} 660 distancer := distancer.NewL2SquaredProvider() 661 uc := ent.UserConfig{} 662 uc.MaxConnections = 112 663 uc.EFConstruction = 112 664 uc.EF = 10 665 uc.VectorCacheMaxObjects = 10e12 666 index, _ := hnsw.New( 667 hnsw.Config{ 668 RootPath: t.TempDir(), 669 ID: "recallbenchmark", 670 MakeCommitLoggerThunk: hnsw.MakeNoopCommitLogger, 671 DistanceProvider: distancer, 672 VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) { 673 return vectors[int(id)], nil 674 }, 675 TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) { 676 copy(container.Slice, vectors[int(id)]) 677 return container.Slice, nil 678 }, 679 }, uc, 680 cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), newDummyStore(t)) 681 defer index.Shutdown(context.Background()) 682 683 q, err := NewIndexQueue("1", "", new(mockShard), index, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 684 BatchSize: 2, 685 IndexInterval: 10 * time.Millisecond, 686 }) 687 require.NoError(t, err) 688 defer q.Close() 689 690 uc.PQ = ent.PQConfig{Enabled: true, Encoder: ent.PQEncoder{Type: "please break...", Distribution: "normal"}} 691 err = index.UpdateUserConfig(uc, func() {}) 692 require.Nil(t, err) 693 }) 694 695 t.Run("sending batch with deleted ids to worker", func(t *testing.T) { 696 var idx mockBatchIndexer 697 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 698 t.Fatal("should not have been called") 699 return nil 700 } 701 702 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 703 BatchSize: 2, 704 IndexInterval: 100 * time.Second, 705 }) 706 require.NoError(t, err) 707 defer q.Close() 708 709 pushVector(t, ctx, q, 0, []float32{1, 2, 3}) 710 pushVector(t, ctx, q, 1, []float32{1, 2, 3}) 711 712 err = q.Delete(0, 1) 713 require.NoError(t, err) 714 715 q.pushToWorkers(-1, true) 716 }) 717 718 t.Run("release twice", func(t *testing.T) { 719 var idx mockBatchIndexer 720 721 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(t), newCheckpointManager(t), IndexQueueOptions{ 722 BatchSize: 10, 723 IndexInterval: time.Hour, // do not index automatically 724 }) 725 require.NoError(t, err) 726 727 for i := uint64(0); i < 35; i++ { 728 pushVector(t, ctx, q, i+1, []float32{1, 2, 3}) 729 } 730 731 chunks := q.queue.borrowChunks(10) 732 require.Equal(t, 3, len(chunks)) 733 734 // release once 735 for _, chunk := range chunks { 736 q.queue.releaseChunk(chunk) 737 } 738 739 // release again 740 for _, chunk := range chunks { 741 q.queue.releaseChunk(chunk) 742 } 743 }) 744 } 745 746 func BenchmarkPush(b *testing.B) { 747 var idx mockBatchIndexer 748 749 idx.addBatchFn = func(id []uint64, vector [][]float32) error { 750 time.Sleep(1 * time.Second) 751 return nil 752 } 753 754 q, err := NewIndexQueue("1", "", new(mockShard), &idx, startWorker(b), newCheckpointManager(b), IndexQueueOptions{ 755 BatchSize: 1000, 756 IndexInterval: 1 * time.Millisecond, 757 }) 758 require.NoError(b, err) 759 defer q.Close() 760 761 vecs := make([]vectorDescriptor, 100) 762 for j := range vecs { 763 vecs[j] = vectorDescriptor{ 764 id: uint64(j), 765 vector: []float32{1, 2, 3}, 766 } 767 } 768 769 b.ResetTimer() 770 for i := 0; i < b.N; i++ { 771 for j := 0; j < 100; j++ { 772 err = q.Push(context.Background(), vecs...) 773 require.NoError(b, err) 774 } 775 } 776 } 777 778 type mockShard struct { 779 compareAndSwapStatusFn func(old, new string) (storagestate.Status, error) 780 } 781 782 func (m *mockShard) compareAndSwapStatus(old, new string) (storagestate.Status, error) { 783 if m.compareAndSwapStatusFn == nil { 784 return storagestate.Status(new), nil 785 } 786 787 return m.compareAndSwapStatusFn(old, new) 788 } 789 790 type mockBatchIndexer struct { 791 sync.Mutex 792 addBatchFn func(id []uint64, vector [][]float32) error 793 vectors map[uint64][]float32 794 containsNodeFn func(id uint64) bool 795 deleteFn func(ids ...uint64) error 796 distancerProvider distancer.Provider 797 shouldCompress bool 798 threshold int 799 compressed bool 800 alreadyIndexed uint64 801 onCompressionTurnedOn func(func()) error 802 } 803 804 func (m *mockBatchIndexer) AddBatch(ctx context.Context, ids []uint64, vector [][]float32) (err error) { 805 m.Lock() 806 defer m.Unlock() 807 808 if m.addBatchFn != nil { 809 err = m.addBatchFn(ids, vector) 810 } 811 812 if m.vectors == nil { 813 m.vectors = make(map[uint64][]float32) 814 } 815 816 for i, id := range ids { 817 m.vectors[id] = vector[i] 818 } 819 820 return 821 } 822 823 func (m *mockBatchIndexer) SearchByVector(vector []float32, k int, allowList helpers.AllowList) ([]uint64, []float32, error) { 824 m.Lock() 825 defer m.Unlock() 826 827 results := newPqMaxPool(k).GetMax(k) 828 829 if m.DistancerProvider().Type() == "cosine-dot" { 830 vector = distancer.Normalize(vector) 831 } 832 833 for id, v := range m.vectors { 834 // skip filtered data 835 if allowList != nil && allowList.Contains(id) { 836 continue 837 } 838 839 if m.DistancerProvider().Type() == "cosine-dot" { 840 v = distancer.Normalize(v) 841 } 842 843 dist, _, err := m.DistanceBetweenVectors(vector, v) 844 if err != nil { 845 return nil, nil, err 846 } 847 848 if results.Len() < k || dist < results.Top().Dist { 849 results.Insert(id, dist) 850 for results.Len() > k { 851 results.Pop() 852 } 853 } 854 } 855 var ids []uint64 856 var distances []float32 857 858 for i := k - 1; i >= 0; i-- { 859 if results.Len() == 0 { 860 break 861 } 862 element := results.Pop() 863 ids = append(ids, element.ID) 864 distances = append(distances, element.Dist) 865 } 866 867 return ids, distances, nil 868 } 869 870 func (m *mockBatchIndexer) SearchByVectorDistance(vector []float32, maxDistance float32, maxLimit int64, allowList helpers.AllowList) ([]uint64, []float32, error) { 871 m.Lock() 872 defer m.Unlock() 873 874 results := newPqMaxPool(int(maxLimit)).GetMax(int(maxLimit)) 875 876 if m.DistancerProvider().Type() == "cosine-dot" { 877 vector = distancer.Normalize(vector) 878 } 879 880 for id, v := range m.vectors { 881 // skip filtered data 882 if allowList != nil && allowList.Contains(id) { 883 continue 884 } 885 886 if m.DistancerProvider().Type() == "cosine-dot" { 887 v = distancer.Normalize(v) 888 } 889 890 dist, _, err := m.DistanceBetweenVectors(vector, v) 891 if err != nil { 892 return nil, nil, err 893 } 894 895 if dist > maxDistance { 896 continue 897 } 898 899 if results.Len() < int(maxLimit) || dist < results.Top().Dist { 900 results.Insert(id, dist) 901 for results.Len() > int(maxLimit) { 902 results.Pop() 903 } 904 } 905 } 906 var ids []uint64 907 var distances []float32 908 909 for i := maxLimit - 1; i >= 0; i-- { 910 if results.Len() == 0 { 911 break 912 } 913 element := results.Pop() 914 ids = append(ids, element.ID) 915 distances = append(distances, element.Dist) 916 } 917 918 return ids, distances, nil 919 } 920 921 func (m *mockBatchIndexer) DistanceBetweenVectors(x, y []float32) (float32, bool, error) { 922 res := float32(0) 923 for i := range x { 924 diff := x[i] - y[i] 925 res += diff * diff 926 } 927 return res, true, nil 928 } 929 930 func (m *mockBatchIndexer) ContainsNode(id uint64) bool { 931 m.Lock() 932 defer m.Unlock() 933 if m.containsNodeFn != nil { 934 return m.containsNodeFn(id) 935 } 936 937 _, ok := m.vectors[id] 938 return ok 939 } 940 941 func (m *mockBatchIndexer) Delete(ids ...uint64) error { 942 m.Lock() 943 defer m.Unlock() 944 if m.deleteFn != nil { 945 return m.deleteFn(ids...) 946 } 947 948 for _, id := range ids { 949 delete(m.vectors, id) 950 } 951 952 return nil 953 } 954 955 func (m *mockBatchIndexer) IDs() []uint64 { 956 m.Lock() 957 defer m.Unlock() 958 959 ids := make([]uint64, 0, len(m.vectors)) 960 for id := range m.vectors { 961 ids = append(ids, id) 962 } 963 964 sort.Slice(ids, func(i, j int) bool { 965 return ids[i] < ids[j] 966 }) 967 968 return ids 969 } 970 971 func (m *mockBatchIndexer) DistancerProvider() distancer.Provider { 972 if m.distancerProvider == nil { 973 return distancer.NewL2SquaredProvider() 974 } 975 976 return m.distancerProvider 977 } 978 979 func (m *mockBatchIndexer) ShouldCompress() (bool, int) { 980 return m.shouldCompress, m.threshold 981 } 982 983 func (m *mockBatchIndexer) Compressed() bool { 984 return m.compressed 985 } 986 987 func (m *mockBatchIndexer) AlreadyIndexed() uint64 { 988 return m.alreadyIndexed 989 } 990 991 func (m *mockBatchIndexer) TurnOnCompression(callback func()) error { 992 if m.onCompressionTurnedOn != nil { 993 return m.onCompressionTurnedOn(callback) 994 } 995 996 return nil 997 } 998 999 func newDummyStore(t *testing.T) *lsmkv.Store { 1000 logger, _ := test.NewNullLogger() 1001 storeDir := t.TempDir() 1002 store, err := lsmkv.New(storeDir, storeDir, logger, nil, 1003 cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop()) 1004 require.Nil(t, err) 1005 return store 1006 }