github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/search.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hnsw 13 14 import ( 15 "context" 16 "fmt" 17 "math" 18 "sync/atomic" 19 20 "github.com/pkg/errors" 21 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 22 "github.com/weaviate/weaviate/adapters/repos/db/priorityqueue" 23 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 24 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 25 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/visited" 26 "github.com/weaviate/weaviate/entities/storobj" 27 "github.com/weaviate/weaviate/usecases/floatcomp" 28 ) 29 30 func (h *hnsw) searchTimeEF(k int) int { 31 // load atomically, so we can get away with concurrent updates of the 32 // userconfig without having to set a lock each time we try to read - which 33 // can be so common that it would cause considerable overhead 34 ef := int(atomic.LoadInt64(&h.ef)) 35 if ef < 1 { 36 return h.autoEfFromK(k) 37 } 38 39 if ef < k { 40 ef = k 41 } 42 43 return ef 44 } 45 46 func (h *hnsw) autoEfFromK(k int) int { 47 factor := int(atomic.LoadInt64(&h.efFactor)) 48 min := int(atomic.LoadInt64(&h.efMin)) 49 max := int(atomic.LoadInt64(&h.efMax)) 50 51 ef := k * factor 52 if ef > max { 53 ef = max 54 } else if ef < min { 55 ef = min 56 } 57 if k > ef { 58 ef = k // otherwise results will get cut off early 59 } 60 61 return ef 62 } 63 64 func (h *hnsw) SearchByVector(vector []float32, k int, allowList helpers.AllowList) ([]uint64, []float32, error) { 65 h.compressActionLock.RLock() 66 defer h.compressActionLock.RUnlock() 67 68 vector = h.normalizeVec(vector) 69 flatSearchCutoff := int(atomic.LoadInt64(&h.flatSearchCutoff)) 70 if allowList != nil && !h.forbidFlat && allowList.Len() < flatSearchCutoff { 71 return h.flatSearch(vector, k, allowList) 72 } 73 return h.knnSearchByVector(vector, k, h.searchTimeEF(k), allowList) 74 } 75 76 // SearchByVectorDistance wraps SearchByVector, and calls it recursively until 77 // the search results contain all vector within the threshold specified by the 78 // target distance. 79 // 80 // The maxLimit param will place an upper bound on the number of search results 81 // returned. This is used in situations where the results of the method are all 82 // eventually turned into objects, for example, a Get query. If the caller just 83 // needs ids for sake of something like aggregation, a maxLimit of -1 can be 84 // passed in to truly obtain all results from the vector index. 85 func (h *hnsw) SearchByVectorDistance(vector []float32, targetDistance float32, maxLimit int64, 86 allowList helpers.AllowList, 87 ) ([]uint64, []float32, error) { 88 var ( 89 searchParams = newSearchByDistParams(maxLimit) 90 91 resultIDs []uint64 92 resultDist []float32 93 ) 94 95 recursiveSearch := func() (bool, error) { 96 shouldContinue := false 97 98 ids, dist, err := h.SearchByVector(vector, searchParams.totalLimit, allowList) 99 if err != nil { 100 return false, errors.Wrap(err, "vector search") 101 } 102 103 // ensures the indexers aren't out of range 104 offsetCap := searchParams.offsetCapacity(ids) 105 totalLimitCap := searchParams.totalLimitCapacity(ids) 106 107 ids, dist = ids[offsetCap:totalLimitCap], dist[offsetCap:totalLimitCap] 108 109 if len(ids) == 0 { 110 return false, nil 111 } 112 113 lastFound := dist[len(dist)-1] 114 shouldContinue = lastFound <= targetDistance 115 116 for i := range ids { 117 if aboveThresh := dist[i] <= targetDistance; aboveThresh || 118 floatcomp.InDelta(float64(dist[i]), float64(targetDistance), 1e-6) { 119 resultIDs = append(resultIDs, ids[i]) 120 resultDist = append(resultDist, dist[i]) 121 } else { 122 // as soon as we encounter a certainty which 123 // is below threshold, we can stop searching 124 break 125 } 126 } 127 128 return shouldContinue, nil 129 } 130 131 shouldContinue, err := recursiveSearch() 132 if err != nil { 133 return nil, nil, err 134 } 135 136 for shouldContinue { 137 searchParams.iterate() 138 if searchParams.maxLimitReached() { 139 h.logger. 140 WithField("action", "unlimited_vector_search"). 141 Warnf("maximum search limit of %d results has been reached", 142 searchParams.maximumSearchLimit) 143 break 144 } 145 146 shouldContinue, err = recursiveSearch() 147 if err != nil { 148 return nil, nil, err 149 } 150 } 151 152 return resultIDs, resultDist, nil 153 } 154 155 func (h *hnsw) shouldRescore() bool { 156 return h.compressed.Load() && !h.doNotRescore 157 } 158 159 func (h *hnsw) searchLayerByVector(queryVector []float32, 160 entrypoints *priorityqueue.Queue[any], ef int, level int, 161 allowList helpers.AllowList, 162 ) (*priorityqueue.Queue[any], error, 163 ) { 164 var compressorDistancer compressionhelpers.CompressorDistancer 165 if h.compressed.Load() { 166 var returnFn compressionhelpers.ReturnDistancerFn 167 compressorDistancer, returnFn = h.compressor.NewDistancer(queryVector) 168 defer returnFn() 169 } 170 return h.searchLayerByVectorWithDistancer(queryVector, entrypoints, ef, level, allowList, compressorDistancer) 171 } 172 173 func (h *hnsw) searchLayerByVectorWithDistancer(queryVector []float32, 174 entrypoints *priorityqueue.Queue[any], ef int, level int, 175 allowList helpers.AllowList, compressorDistancer compressionhelpers.CompressorDistancer) (*priorityqueue.Queue[any], error, 176 ) { 177 h.pools.visitedListsLock.RLock() 178 visited := h.pools.visitedLists.Borrow() 179 h.pools.visitedListsLock.RUnlock() 180 181 candidates := h.pools.pqCandidates.GetMin(ef) 182 results := h.pools.pqResults.GetMax(ef) 183 var floatDistancer distancer.Distancer 184 if h.compressed.Load() { 185 if compressorDistancer == nil { 186 var returnFn compressionhelpers.ReturnDistancerFn 187 compressorDistancer, returnFn = h.compressor.NewDistancer(queryVector) 188 defer returnFn() 189 } 190 } else { 191 floatDistancer = h.distancerProvider.New(queryVector) 192 } 193 194 h.insertViableEntrypointsAsCandidatesAndResults(entrypoints, candidates, 195 results, level, visited, allowList) 196 197 var worstResultDistance float32 198 var err error 199 if h.compressed.Load() { 200 worstResultDistance, err = h.currentWorstResultDistanceToByte(results, compressorDistancer) 201 } else { 202 worstResultDistance, err = h.currentWorstResultDistanceToFloat(results, floatDistancer) 203 } 204 if err != nil { 205 return nil, errors.Wrapf(err, "calculate distance of current last result") 206 } 207 connectionsReusable := make([]uint64, h.maximumConnectionsLayerZero) 208 209 for candidates.Len() > 0 { 210 var dist float32 211 candidate := candidates.Pop() 212 dist = candidate.Dist 213 214 if dist > worstResultDistance && results.Len() >= ef { 215 break 216 } 217 218 h.shardedNodeLocks.RLock(candidate.ID) 219 candidateNode := h.nodes[candidate.ID] 220 h.shardedNodeLocks.RUnlock(candidate.ID) 221 222 if candidateNode == nil { 223 // could have been a node that already had a tombstone attached and was 224 // just cleaned up while we were waiting for a read lock 225 continue 226 } 227 228 candidateNode.Lock() 229 if candidateNode.level < level { 230 // a node level could have been downgraded as part of a delete-reassign, 231 // but the connections pointing to it not yet cleaned up. In this case 232 // the node doesn't have any outgoing connections at this level and we 233 // must discard it. 234 candidateNode.Unlock() 235 continue 236 } 237 238 if len(candidateNode.connections[level]) > h.maximumConnectionsLayerZero { 239 // How is it possible that we could ever have more connections than the 240 // allowed maximum? It is not anymore, but there was a bug that allowed 241 // this to happen in versions prior to v1.12.0: 242 // https://github.com/weaviate/weaviate/issues/1868 243 // 244 // As a result the length of this slice is entirely unpredictable and we 245 // can no longer retrieve it from the pool. Instead we need to fallback 246 // to allocating a new slice. 247 // 248 // This was discovered as part of 249 // https://github.com/weaviate/weaviate/issues/1897 250 connectionsReusable = make([]uint64, len(candidateNode.connections[level])) 251 } else { 252 connectionsReusable = connectionsReusable[:len(candidateNode.connections[level])] 253 } 254 255 copy(connectionsReusable, candidateNode.connections[level]) 256 candidateNode.Unlock() 257 258 for _, neighborID := range connectionsReusable { 259 260 if ok := visited.Visited(neighborID); ok { 261 // skip if we've already visited this neighbor 262 continue 263 } 264 265 // make sure we never visit this neighbor again 266 visited.Visit(neighborID) 267 var distance float32 268 var ok bool 269 var err error 270 if h.compressed.Load() { 271 distance, ok, err = compressorDistancer.DistanceToNode(neighborID) 272 } else { 273 distance, ok, err = h.distanceToFloatNode(floatDistancer, neighborID) 274 } 275 if err != nil { 276 var e storobj.ErrNotFound 277 if errors.As(err, &e) { 278 h.handleDeletedNode(e.DocID) 279 continue 280 } else { 281 if err != nil { 282 return nil, errors.Wrap(err, "calculate distance between candidate and query") 283 } 284 } 285 } 286 287 if !ok { 288 // node was deleted in the underlying object store 289 continue 290 } 291 292 if distance < worstResultDistance || results.Len() < ef { 293 candidates.Insert(neighborID, distance) 294 if level == 0 && allowList != nil { 295 // we are on the lowest level containing the actual candidates and we 296 // have an allow list (i.e. the user has probably set some sort of a 297 // filter restricting this search further. As a result we have to 298 // ignore items not on the list 299 if !allowList.Contains(neighborID) { 300 continue 301 } 302 } 303 304 if h.hasTombstone(neighborID) { 305 continue 306 } 307 308 results.Insert(neighborID, distance) 309 310 if h.compressed.Load() { 311 h.compressor.Prefetch(candidates.Top().ID) 312 } else { 313 h.cache.Prefetch(candidates.Top().ID) 314 } 315 316 // +1 because we have added one node size calculating the len 317 if results.Len() > ef { 318 results.Pop() 319 } 320 321 if results.Len() > 0 { 322 worstResultDistance = results.Top().Dist 323 } 324 } 325 } 326 } 327 328 h.pools.pqCandidates.Put(candidates) 329 330 h.pools.visitedListsLock.RLock() 331 h.pools.visitedLists.Return(visited) 332 h.pools.visitedListsLock.RUnlock() 333 334 return results, nil 335 } 336 337 func (h *hnsw) insertViableEntrypointsAsCandidatesAndResults( 338 entrypoints, candidates, results *priorityqueue.Queue[any], level int, 339 visitedList visited.ListSet, allowList helpers.AllowList, 340 ) { 341 for entrypoints.Len() > 0 { 342 ep := entrypoints.Pop() 343 visitedList.Visit(ep.ID) 344 candidates.Insert(ep.ID, ep.Dist) 345 if level == 0 && allowList != nil { 346 // we are on the lowest level containing the actual candidates and we 347 // have an allow list (i.e. the user has probably set some sort of a 348 // filter restricting this search further. As a result we have to 349 // ignore items not on the list 350 if !allowList.Contains(ep.ID) { 351 continue 352 } 353 } 354 355 if h.hasTombstone(ep.ID) { 356 continue 357 } 358 359 results.Insert(ep.ID, ep.Dist) 360 } 361 } 362 363 func (h *hnsw) currentWorstResultDistanceToFloat(results *priorityqueue.Queue[any], 364 distancer distancer.Distancer, 365 ) (float32, error) { 366 if results.Len() > 0 { 367 id := results.Top().ID 368 369 d, ok, err := h.distanceToFloatNode(distancer, id) 370 if err != nil { 371 var e storobj.ErrNotFound 372 if errors.As(err, &e) { 373 h.handleDeletedNode(e.DocID) 374 } else { 375 if err != nil { 376 return 0, errors.Wrap(err, "calculated distance between worst result and query") 377 } 378 } 379 } 380 381 if !ok { 382 return math.MaxFloat32, nil 383 } 384 return d, nil 385 } else { 386 // if the entrypoint (which we received from a higher layer doesn't match 387 // the allow List the result list is empty. In this case we can just set 388 // the worstDistance to an arbitrarily large number, so that any 389 // (allowed) candidate will have a lower distance in comparison 390 return math.MaxFloat32, nil 391 } 392 } 393 394 func (h *hnsw) currentWorstResultDistanceToByte(results *priorityqueue.Queue[any], 395 distancer compressionhelpers.CompressorDistancer, 396 ) (float32, error) { 397 if results.Len() > 0 { 398 item := results.Top() 399 if item.Dist != 0 { 400 return item.Dist, nil 401 } 402 id := item.ID 403 d, ok, err := distancer.DistanceToNode(id) 404 if err != nil { 405 return 0, errors.Wrap(err, 406 "calculated distance between worst result and query") 407 } 408 409 if !ok { 410 return math.MaxFloat32, nil 411 } 412 return d, nil 413 } else { 414 // if the entrypoint (which we received from a higher layer doesn't match 415 // the allow List the result list is empty. In this case we can just set 416 // the worstDistance to an arbitrarily large number, so that any 417 // (allowed) candidate will have a lower distance in comparison 418 return math.MaxFloat32, nil 419 } 420 } 421 422 func (h *hnsw) distanceFromBytesToFloatNode(concreteDistancer compressionhelpers.CompressorDistancer, nodeID uint64) (float32, bool, error) { 423 slice := h.pools.tempVectors.Get(int(h.dims)) 424 defer h.pools.tempVectors.Put(slice) 425 vec, err := h.TempVectorForIDThunk(context.Background(), nodeID, slice) 426 if err != nil { 427 var e storobj.ErrNotFound 428 if errors.As(err, &e) { 429 h.handleDeletedNode(e.DocID) 430 return 0, false, nil 431 } else { 432 // not a typed error, we can recover from, return with err 433 return 0, false, errors.Wrapf(err, "get vector of docID %d", nodeID) 434 } 435 } 436 vec = h.normalizeVec(vec) 437 return concreteDistancer.DistanceToFloat(vec) 438 } 439 440 func (h *hnsw) distanceToFloatNode(distancer distancer.Distancer, 441 nodeID uint64, 442 ) (float32, bool, error) { 443 candidateVec, err := h.vectorForID(context.Background(), nodeID) 444 if err != nil { 445 return 0, false, err 446 } 447 448 dist, _, err := distancer.Distance(candidateVec) 449 if err != nil { 450 return 0, false, errors.Wrap(err, "calculate distance between candidate and query") 451 } 452 453 return dist, true, nil 454 } 455 456 // the underlying object seems to have been deleted, to recover from 457 // this situation let's add a tombstone to the deleted object, so it 458 // will be cleaned up and skip this candidate in the current search 459 func (h *hnsw) handleDeletedNode(docID uint64) { 460 if h.hasTombstone(docID) { 461 // nothing to do, this node already has a tombstone, it will be cleaned up 462 // in the next deletion cycle 463 return 464 } 465 466 h.addTombstone(docID) 467 h.logger.WithField("action", "attach_tombstone_to_deleted_node"). 468 WithField("node_id", docID). 469 Infof("found a deleted node (%d) without a tombstone, "+ 470 "tombstone was added", docID) 471 } 472 473 func (h *hnsw) knnSearchByVector(searchVec []float32, k int, 474 ef int, allowList helpers.AllowList, 475 ) ([]uint64, []float32, error) { 476 if h.isEmpty() { 477 return nil, nil, nil 478 } 479 480 if k < 0 { 481 return nil, nil, fmt.Errorf("k must be greater than zero") 482 } 483 484 h.RLock() 485 entryPointID := h.entryPointID 486 maxLayer := h.currentMaximumLayer 487 h.RUnlock() 488 489 entryPointDistance, ok, err := h.distBetweenNodeAndVec(entryPointID, searchVec) 490 if err != nil { 491 return nil, nil, errors.Wrap(err, "knn search: distance between entrypoint and query node") 492 } 493 494 if !ok { 495 return nil, nil, fmt.Errorf("entrypoint was deleted in the object store, " + 496 "it has been flagged for cleanup and should be fixed in the next cleanup cycle") 497 } 498 499 var compressorDistancer compressionhelpers.CompressorDistancer 500 if h.compressed.Load() { 501 var returnFn compressionhelpers.ReturnDistancerFn 502 compressorDistancer, returnFn = h.compressor.NewDistancer(searchVec) 503 defer returnFn() 504 } 505 // stop at layer 1, not 0! 506 for level := maxLayer; level >= 1; level-- { 507 eps := priorityqueue.NewMin[any](10) 508 eps.Insert(entryPointID, entryPointDistance) 509 510 res, err := h.searchLayerByVectorWithDistancer(searchVec, eps, 1, level, nil, compressorDistancer) 511 if err != nil { 512 return nil, nil, errors.Wrapf(err, "knn search: search layer at level %d", level) 513 } 514 515 // There might be situations where we did not find a better entrypoint at 516 // that particular level, so instead we're keeping whatever entrypoint we 517 // had before (i.e. either from a previous level or even the main 518 // entrypoint) 519 // 520 // If we do, however, have results, any candidate that's not nil (not 521 // deleted), and not under maintenance is a viable candidate 522 for res.Len() > 0 { 523 cand := res.Pop() 524 n := h.nodeByID(cand.ID) 525 if n == nil { 526 // we have found a node in results that is nil. This means it was 527 // deleted, but not cleaned up properly. Make sure to add a tombstone to 528 // this node, so it can be cleaned up in the next cycle. 529 if err := h.addTombstone(cand.ID); err != nil { 530 return nil, nil, err 531 } 532 533 // skip the nil node, as it does not make a valid entrypoint 534 continue 535 } 536 537 if !n.isUnderMaintenance() { 538 entryPointID = cand.ID 539 entryPointDistance = cand.Dist 540 break 541 } 542 543 // if we managed to go through the loop without finding a single 544 // suitable node, we simply stick with the original, i.e. the global 545 // entrypoint 546 } 547 548 h.pools.pqResults.Put(res) 549 } 550 551 eps := priorityqueue.NewMin[any](10) 552 eps.Insert(entryPointID, entryPointDistance) 553 res, err := h.searchLayerByVectorWithDistancer(searchVec, eps, ef, 0, allowList, compressorDistancer) 554 if err != nil { 555 return nil, nil, errors.Wrapf(err, "knn search: search layer at level %d", 0) 556 } 557 558 if h.shouldRescore() { 559 ids := make([]uint64, res.Len()) 560 i := len(ids) - 1 561 for res.Len() > 0 { 562 res := res.Pop() 563 ids[i] = res.ID 564 i-- 565 } 566 res.Reset() 567 for _, id := range ids { 568 dist, _, _ := h.distanceFromBytesToFloatNode(compressorDistancer, id) 569 res.Insert(id, dist) 570 if res.Len() > ef { 571 res.Pop() 572 } 573 } 574 575 } 576 577 for res.Len() > k { 578 res.Pop() 579 } 580 581 ids := make([]uint64, res.Len()) 582 dists := make([]float32, res.Len()) 583 584 // results is ordered in reverse, we need to flip the order before presenting 585 // to the user! 586 i := len(ids) - 1 587 for res.Len() > 0 { 588 res := res.Pop() 589 ids[i] = res.ID 590 dists[i] = res.Dist 591 i-- 592 } 593 h.pools.pqResults.Put(res) 594 return ids, dists, nil 595 } 596 597 func newSearchByDistParams(maxLimit int64) *searchByDistParams { 598 initialOffset := 0 599 initialLimit := DefaultSearchByDistInitialLimit 600 601 return &searchByDistParams{ 602 offset: initialOffset, 603 limit: initialLimit, 604 totalLimit: initialOffset + initialLimit, 605 maximumSearchLimit: maxLimit, 606 } 607 } 608 609 const ( 610 // DefaultSearchByDistInitialLimit : 611 // the initial limit of 100 here is an 612 // arbitrary decision, and can be tuned 613 // as needed 614 DefaultSearchByDistInitialLimit = 100 615 616 // DefaultSearchByDistLimitMultiplier : 617 // the decision to increase the limit in 618 // multiples of 10 here is an arbitrary 619 // decision, and can be tuned as needed 620 DefaultSearchByDistLimitMultiplier = 10 621 ) 622 623 type searchByDistParams struct { 624 offset int 625 limit int 626 totalLimit int 627 maximumSearchLimit int64 628 } 629 630 func (params *searchByDistParams) offsetCapacity(ids []uint64) int { 631 var offsetCap int 632 if params.offset < len(ids) { 633 offsetCap = params.offset 634 } else { 635 offsetCap = len(ids) 636 } 637 638 return offsetCap 639 } 640 641 func (params *searchByDistParams) totalLimitCapacity(ids []uint64) int { 642 var totalLimitCap int 643 if params.totalLimit < len(ids) { 644 totalLimitCap = params.totalLimit 645 } else { 646 totalLimitCap = len(ids) 647 } 648 649 return totalLimitCap 650 } 651 652 func (params *searchByDistParams) iterate() { 653 params.offset = params.totalLimit 654 params.limit *= DefaultSearchByDistLimitMultiplier 655 params.totalLimit = params.offset + params.limit 656 } 657 658 func (params *searchByDistParams) maxLimitReached() bool { 659 if params.maximumSearchLimit < 0 { 660 return false 661 } 662 663 return int64(params.totalLimit) > params.maximumSearchLimit 664 }