github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/inverted/bm25_searcher.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package inverted 13 14 import ( 15 "context" 16 "encoding/binary" 17 "fmt" 18 "math" 19 "sort" 20 "strconv" 21 "strings" 22 "sync" 23 24 enterrors "github.com/weaviate/weaviate/entities/errors" 25 26 "github.com/pkg/errors" 27 "github.com/sirupsen/logrus" 28 "github.com/weaviate/sroar" 29 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 30 "github.com/weaviate/weaviate/adapters/repos/db/inverted/stopwords" 31 "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" 32 "github.com/weaviate/weaviate/adapters/repos/db/priorityqueue" 33 "github.com/weaviate/weaviate/adapters/repos/db/propertyspecific" 34 "github.com/weaviate/weaviate/entities/inverted" 35 "github.com/weaviate/weaviate/entities/models" 36 "github.com/weaviate/weaviate/entities/schema" 37 "github.com/weaviate/weaviate/entities/searchparams" 38 "github.com/weaviate/weaviate/entities/storobj" 39 ) 40 41 type BM25Searcher struct { 42 config schema.BM25Config 43 store *lsmkv.Store 44 schema schema.Schema 45 classSearcher ClassSearcher // to allow recursive searches on ref-props 46 propIndices propertyspecific.Indices 47 propLenTracker propLengthRetriever 48 logger logrus.FieldLogger 49 shardVersion uint16 50 } 51 52 type propLengthRetriever interface { 53 PropertyMean(prop string) (float32, error) 54 } 55 56 func NewBM25Searcher(config schema.BM25Config, store *lsmkv.Store, 57 schema schema.Schema, propIndices propertyspecific.Indices, 58 classSearcher ClassSearcher, propLenTracker propLengthRetriever, 59 logger logrus.FieldLogger, shardVersion uint16, 60 ) *BM25Searcher { 61 return &BM25Searcher{ 62 config: config, 63 store: store, 64 schema: schema, 65 propIndices: propIndices, 66 classSearcher: classSearcher, 67 propLenTracker: propLenTracker, 68 logger: logger.WithField("action", "bm25_search"), 69 shardVersion: shardVersion, 70 } 71 } 72 73 func (b *BM25Searcher) BM25F(ctx context.Context, filterDocIds helpers.AllowList, 74 className schema.ClassName, limit int, keywordRanking searchparams.KeywordRanking, 75 ) ([]*storobj.Object, []float32, error) { 76 // WEAVIATE-471 - If a property is not searchable, return an error 77 for _, property := range keywordRanking.Properties { 78 if !PropertyHasSearchableIndex(b.schema.Objects, string(className), property) { 79 return nil, nil, inverted.NewMissingSearchableIndexError(property) 80 } 81 } 82 class, err := schema.GetClassByName(b.schema.Objects, string(className)) 83 if err != nil { 84 return nil, nil, err 85 } 86 87 objs, scores, err := b.wand(ctx, filterDocIds, class, keywordRanking, limit) 88 if err != nil { 89 return nil, nil, errors.Wrap(err, "wand") 90 } 91 92 return objs, scores, nil 93 } 94 95 func (b *BM25Searcher) GetPropertyLengthTracker() *JsonPropertyLengthTracker { 96 return b.propLenTracker.(*JsonPropertyLengthTracker) 97 } 98 99 func (b *BM25Searcher) wand( 100 ctx context.Context, filterDocIds helpers.AllowList, class *models.Class, params searchparams.KeywordRanking, limit int, 101 ) ([]*storobj.Object, []float32, error) { 102 N := float64(b.store.Bucket(helpers.ObjectsBucketLSM).Count()) 103 104 var stopWordDetector *stopwords.Detector 105 if class.InvertedIndexConfig != nil && class.InvertedIndexConfig.Stopwords != nil { 106 var err error 107 stopWordDetector, err = stopwords.NewDetectorFromConfig(*(class.InvertedIndexConfig.Stopwords)) 108 if err != nil { 109 return nil, nil, err 110 } 111 } 112 113 // There are currently cases, for different tokenization: 114 // word, lowercase, whitespace and field. 115 // Query is tokenized and respective properties are then searched for the search terms, 116 // results at the end are combined using WAND 117 118 queryTermsByTokenization := map[string][]string{} 119 duplicateBoostsByTokenization := map[string][]int{} 120 propNamesByTokenization := map[string][]string{} 121 propertyBoosts := make(map[string]float32, len(params.Properties)) 122 123 for _, tokenization := range helpers.Tokenizations { 124 queryTerms, dupBoosts := helpers.TokenizeAndCountDuplicates(tokenization, params.Query) 125 queryTermsByTokenization[tokenization] = queryTerms 126 duplicateBoostsByTokenization[tokenization] = dupBoosts 127 128 // stopword filtering for word tokenization 129 if tokenization == models.PropertyTokenizationWord { 130 queryTerms, dupBoosts = b.removeStopwordsFromQueryTerms(queryTermsByTokenization[tokenization], 131 duplicateBoostsByTokenization[tokenization], stopWordDetector) 132 queryTermsByTokenization[tokenization] = queryTerms 133 duplicateBoostsByTokenization[tokenization] = dupBoosts 134 } 135 136 propNamesByTokenization[tokenization] = make([]string, 0) 137 } 138 139 averagePropLength := 0. 140 for _, propertyWithBoost := range params.Properties { 141 property := propertyWithBoost 142 propBoost := 1 143 if strings.Contains(propertyWithBoost, "^") { 144 property = strings.Split(propertyWithBoost, "^")[0] 145 boostStr := strings.Split(propertyWithBoost, "^")[1] 146 propBoost, _ = strconv.Atoi(boostStr) 147 } 148 propertyBoosts[property] = float32(propBoost) 149 150 propMean, err := b.GetPropertyLengthTracker().PropertyMean(property) 151 if err != nil { 152 return nil, nil, err 153 } 154 averagePropLength += float64(propMean) 155 156 prop, err := schema.GetPropertyByName(class, property) 157 if err != nil { 158 return nil, nil, err 159 } 160 161 switch dt, _ := schema.AsPrimitive(prop.DataType); dt { 162 case schema.DataTypeText, schema.DataTypeTextArray: 163 if _, exists := propNamesByTokenization[prop.Tokenization]; !exists { 164 return nil, nil, fmt.Errorf("cannot handle tokenization '%v' of property '%s'", 165 prop.Tokenization, prop.Name) 166 } 167 propNamesByTokenization[prop.Tokenization] = append(propNamesByTokenization[prop.Tokenization], property) 168 default: 169 return nil, nil, fmt.Errorf("cannot handle datatype '%v' of property '%s'", dt, prop.Name) 170 } 171 } 172 173 averagePropLength = averagePropLength / float64(len(params.Properties)) 174 175 // 100 is a reasonable expected capacity for the total number of terms to query. 176 results := make(terms, 0, 100) 177 indices := make([]map[uint64]int, 0, 100) 178 179 eg := enterrors.NewErrorGroupWrapper(b.logger) 180 eg.SetLimit(_NUMCPU) 181 182 var resultsLock sync.Mutex 183 184 for _, tokenization := range helpers.Tokenizations { 185 propNames := propNamesByTokenization[tokenization] 186 if len(propNames) > 0 { 187 queryTerms, duplicateBoosts := helpers.TokenizeAndCountDuplicates(tokenization, params.Query) 188 189 // stopword filtering for word tokenization 190 if tokenization == models.PropertyTokenizationWord { 191 queryTerms, duplicateBoosts = b.removeStopwordsFromQueryTerms( 192 queryTerms, duplicateBoosts, stopWordDetector) 193 } 194 195 for i := range queryTerms { 196 j := i 197 198 eg.Go(func() (err error) { 199 termResult, docIndices, termErr := b.createTerm(N, filterDocIds, queryTerms[j], propNames, 200 propertyBoosts, duplicateBoosts[j], params.AdditionalExplanations) 201 if termErr != nil { 202 err = termErr 203 return 204 } 205 resultsLock.Lock() 206 results = append(results, termResult) 207 indices = append(indices, docIndices) 208 resultsLock.Unlock() 209 return 210 }, "query_term", queryTerms[j], "prop_names", propNames, "has_filter", filterDocIds != nil) 211 } 212 } 213 } 214 215 if err := eg.Wait(); err != nil { 216 return nil, nil, err 217 } 218 // all results. Sum up the length of the results from all terms to get an upper bound of how many results there are 219 if limit == 0 { 220 for _, ind := range indices { 221 limit += len(ind) 222 } 223 } 224 225 // the results are needed in the original order to be able to locate frequency/property length for the top-results 226 resultsOriginalOrder := make(terms, len(results)) 227 copy(resultsOriginalOrder, results) 228 229 topKHeap := b.getTopKHeap(limit, results, averagePropLength) 230 return b.getTopKObjects(topKHeap, resultsOriginalOrder, indices, params.AdditionalExplanations) 231 } 232 233 func (b *BM25Searcher) removeStopwordsFromQueryTerms(queryTerms []string, 234 duplicateBoost []int, detector *stopwords.Detector, 235 ) ([]string, []int) { 236 if detector == nil || len(queryTerms) == 0 { 237 return queryTerms, duplicateBoost 238 } 239 240 i := 0 241 WordLoop: 242 for { 243 if i == len(queryTerms) { 244 return queryTerms, duplicateBoost 245 } 246 queryTerm := queryTerms[i] 247 if detector.IsStopword(queryTerm) { 248 queryTerms[i] = queryTerms[len(queryTerms)-1] 249 queryTerms = queryTerms[:len(queryTerms)-1] 250 duplicateBoost[i] = duplicateBoost[len(duplicateBoost)-1] 251 duplicateBoost = duplicateBoost[:len(duplicateBoost)-1] 252 253 continue WordLoop 254 } 255 256 i++ 257 } 258 } 259 260 func (b *BM25Searcher) getTopKObjects(topKHeap *priorityqueue.Queue[any], 261 results terms, indices []map[uint64]int, additionalExplanations bool, 262 ) ([]*storobj.Object, []float32, error) { 263 objectsBucket := b.store.Bucket(helpers.ObjectsBucketLSM) 264 if objectsBucket == nil { 265 return nil, nil, errors.Errorf("objects bucket not found") 266 } 267 268 objects := make([]*storobj.Object, 0, topKHeap.Len()) 269 scores := make([]float32, 0, topKHeap.Len()) 270 271 buf := make([]byte, 8) 272 for topKHeap.Len() > 0 { 273 res := topKHeap.Pop() 274 binary.LittleEndian.PutUint64(buf, res.ID) 275 objectByte, err := objectsBucket.GetBySecondary(0, buf) 276 if err != nil { 277 return nil, nil, err 278 } 279 280 // If there is a crash and WAL recovery, the inverted index may have objects that are not in the objects bucket. 281 // This is an issue that needs to be fixed, but for now we need to reduce the huge amount of log messages that 282 // are generated by this issue. Logging the first time we encounter a missing object in a query still resulted 283 // in a huge amount of log messages and it will happen on all queries, so we not log at all for now. 284 // The user has already been alerted about ppossible data loss when the WAL recovery happened. 285 // TODO: consider deleting these entries from the inverted index and alerting the user 286 if len(objectByte) == 0 { 287 continue 288 } 289 290 obj, err := storobj.FromBinary(objectByte) 291 if err != nil { 292 return nil, nil, err 293 } 294 295 if additionalExplanations { 296 // add score explanation 297 if obj.AdditionalProperties() == nil { 298 obj.Object.Additional = make(map[string]interface{}) 299 } 300 for j, result := range results { 301 if termIndex, ok := indices[j][res.ID]; ok { 302 queryTerm := result.queryTerm 303 if len(result.data) <= termIndex { 304 b.logger.Warnf( 305 "Skipping object explanation in BM25: term index %v is out of range for query term %v, length %d, id %v", 306 termIndex, queryTerm, len(result.data), res.ID) 307 continue 308 } 309 obj.Object.Additional["BM25F_"+queryTerm+"_frequency"] = result.data[termIndex].frequency 310 obj.Object.Additional["BM25F_"+queryTerm+"_propLength"] = result.data[termIndex].propLength 311 } 312 } 313 } 314 objects = append(objects, obj) 315 scores = append(scores, res.Dist) 316 317 } 318 return objects, scores, nil 319 } 320 321 func (b *BM25Searcher) getTopKHeap(limit int, results terms, averagePropLength float64, 322 ) *priorityqueue.Queue[any] { 323 topKHeap := priorityqueue.NewMin[any](limit) 324 worstDist := float64(-10000) // tf score can be negative 325 sort.Sort(results) 326 for { 327 if results.completelyExhausted() || results.pivot(worstDist) { 328 return topKHeap 329 } 330 331 id, score := results.scoreNext(averagePropLength, b.config) 332 333 if topKHeap.Len() < limit || topKHeap.Top().Dist < float32(score) { 334 topKHeap.Insert(id, float32(score)) 335 for topKHeap.Len() > limit { 336 topKHeap.Pop() 337 } 338 // only update the worst distance when the queue is full, otherwise results can be missing if the first 339 // entry that is checked already has a very high score 340 if topKHeap.Len() >= limit { 341 worstDist = float64(topKHeap.Top().Dist) 342 } 343 } 344 } 345 } 346 347 func (b *BM25Searcher) createTerm(N float64, filterDocIds helpers.AllowList, query string, 348 propertyNames []string, propertyBoosts map[string]float32, duplicateTextBoost int, 349 additionalExplanations bool, 350 ) (term, map[uint64]int, error) { 351 termResult := term{queryTerm: query} 352 filteredDocIDs := sroar.NewBitmap() // to build the global n if there is a filter 353 354 allMsAndProps := make(AllMapPairsAndPropName, 0, len(propertyNames)) 355 for _, propName := range propertyNames { 356 357 bucket := b.store.Bucket(helpers.BucketSearchableFromPropNameLSM(propName)) 358 if bucket == nil { 359 return termResult, nil, fmt.Errorf("could not find bucket for property %v", propName) 360 } 361 preM, err := bucket.MapList([]byte(query)) 362 if err != nil { 363 return termResult, nil, err 364 } 365 366 var m []lsmkv.MapPair 367 if filterDocIds != nil { 368 m = make([]lsmkv.MapPair, 0, len(preM)) 369 for _, val := range preM { 370 docID := binary.BigEndian.Uint64(val.Key) 371 if filterDocIds.Contains(docID) { 372 m = append(m, val) 373 } else { 374 filteredDocIDs.Set(docID) 375 } 376 } 377 } else { 378 m = preM 379 } 380 if len(m) == 0 { 381 continue 382 } 383 384 allMsAndProps = append(allMsAndProps, MapPairsAndPropName{MapPairs: m, propname: propName}) 385 } 386 387 // sort ascending, this code has two effects 388 // 1) We can skip writing the indices from the last property to the map (see next comment). Therefore, having the 389 // biggest property at the end will save us most writes on average 390 // 2) For the first property all entries are new, and we can create the map with the respective size. When choosing 391 // the second-biggest entry as the first property we save additional allocations later 392 sort.Sort(allMsAndProps) 393 if len(allMsAndProps) > 2 { 394 allMsAndProps[len(allMsAndProps)-2], allMsAndProps[0] = allMsAndProps[0], allMsAndProps[len(allMsAndProps)-2] 395 } 396 397 var docMapPairs []docPointerWithScore = nil 398 var docMapPairsIndices map[uint64]int = nil 399 for i, mAndProps := range allMsAndProps { 400 m := mAndProps.MapPairs 401 propName := mAndProps.propname 402 403 // The indices are needed for two things: 404 // a) combining the results of different properties 405 // b) Retrieve additional information that helps to understand the results when debugging. The retrieval is done 406 // in a later step, after it is clear which objects are the most relevant 407 // 408 // When b) is not needed the results from the last property do not need to be added to the index-map as there 409 // won't be any follow-up combinations. 410 includeIndicesForLastElement := false 411 if additionalExplanations || i < len(allMsAndProps)-1 { 412 includeIndicesForLastElement = true 413 } 414 415 // only create maps/slices if we know how many entries there are 416 if docMapPairs == nil { 417 docMapPairs = make([]docPointerWithScore, 0, len(m)) 418 docMapPairsIndices = make(map[uint64]int, len(m)) 419 for k, val := range m { 420 if len(val.Value) < 8 { 421 b.logger.Warnf("Skipping pair in BM25: MapPair.Value should be 8 bytes long, but is %d.", len(val.Value)) 422 continue 423 } 424 freqBits := binary.LittleEndian.Uint32(val.Value[0:4]) 425 propLenBits := binary.LittleEndian.Uint32(val.Value[4:8]) 426 docMapPairs = append(docMapPairs, 427 docPointerWithScore{ 428 id: binary.BigEndian.Uint64(val.Key), 429 frequency: math.Float32frombits(freqBits) * propertyBoosts[propName], 430 propLength: math.Float32frombits(propLenBits), 431 }) 432 if includeIndicesForLastElement { 433 docMapPairsIndices[binary.BigEndian.Uint64(val.Key)] = k 434 } 435 } 436 } else { 437 for _, val := range m { 438 if len(val.Value) < 8 { 439 b.logger.Warnf("Skipping pair in BM25: MapPair.Value should be 8 bytes long, but is %d.", len(val.Value)) 440 continue 441 } 442 key := binary.BigEndian.Uint64(val.Key) 443 ind, ok := docMapPairsIndices[key] 444 freqBits := binary.LittleEndian.Uint32(val.Value[0:4]) 445 propLenBits := binary.LittleEndian.Uint32(val.Value[4:8]) 446 if ok { 447 if ind >= len(docMapPairs) { 448 // the index is not valid anymore, but the key is still in the map 449 b.logger.Warnf("Skipping pair in BM25: Index %d is out of range for key %d, length %d.", ind, key, len(docMapPairs)) 450 continue 451 } 452 if ind < len(docMapPairs) && docMapPairs[ind].id != key { 453 b.logger.Warnf("Skipping pair in BM25: id at %d in doc map pairs, %d, differs from current key, %d", ind, docMapPairs[ind].id, key) 454 continue 455 } 456 457 docMapPairs[ind].propLength += math.Float32frombits(propLenBits) 458 docMapPairs[ind].frequency += math.Float32frombits(freqBits) * propertyBoosts[propName] 459 } else { 460 docMapPairs = append(docMapPairs, 461 docPointerWithScore{ 462 id: binary.BigEndian.Uint64(val.Key), 463 frequency: math.Float32frombits(freqBits) * propertyBoosts[propName], 464 propLength: math.Float32frombits(propLenBits), 465 }) 466 if includeIndicesForLastElement { 467 docMapPairsIndices[binary.BigEndian.Uint64(val.Key)] = len(docMapPairs) - 1 // current last entry 468 } 469 } 470 } 471 } 472 } 473 if docMapPairs == nil { 474 termResult.exhausted = true 475 return termResult, docMapPairsIndices, nil 476 } 477 termResult.data = docMapPairs 478 479 n := float64(len(docMapPairs)) 480 if filterDocIds != nil { 481 n += float64(filteredDocIDs.GetCardinality()) 482 } 483 termResult.idf = math.Log(float64(1)+(N-n+0.5)/(n+0.5)) * float64(duplicateTextBoost) 484 485 // catch special case where there are no results and would panic termResult.data[0].id 486 // related to #4125 487 if len(termResult.data) == 0 { 488 termResult.posPointer = 0 489 termResult.idPointer = 0 490 termResult.exhausted = true 491 return termResult, docMapPairsIndices, nil 492 } 493 494 termResult.posPointer = 0 495 termResult.idPointer = termResult.data[0].id 496 return termResult, docMapPairsIndices, nil 497 } 498 499 type term struct { 500 // doubles as max impact (with tf=1, the max impact would be 1*idf), if there 501 // is a boost for a queryTerm, simply apply it here once 502 idf float64 503 504 idPointer uint64 505 posPointer uint64 506 data []docPointerWithScore 507 exhausted bool 508 queryTerm string 509 } 510 511 func (t *term) scoreAndAdvance(averagePropLength float64, config schema.BM25Config) (uint64, float64) { 512 id := t.idPointer 513 pair := t.data[t.posPointer] 514 freq := float64(pair.frequency) 515 tf := freq / (freq + config.K1*(1-config.B+config.B*float64(pair.propLength)/averagePropLength)) 516 517 // advance 518 t.posPointer++ 519 if t.posPointer >= uint64(len(t.data)) { 520 t.exhausted = true 521 } else { 522 t.idPointer = t.data[t.posPointer].id 523 } 524 525 return id, tf * t.idf 526 } 527 528 func (t *term) advanceAtLeast(minID uint64) { 529 for t.idPointer < minID { 530 t.posPointer++ 531 if t.posPointer >= uint64(len(t.data)) { 532 t.exhausted = true 533 return 534 } 535 t.idPointer = t.data[t.posPointer].id 536 } 537 } 538 539 type terms []term 540 541 func (t terms) completelyExhausted() bool { 542 for i := range t { 543 if !t[i].exhausted { 544 return false 545 } 546 } 547 return true 548 } 549 550 func (t terms) pivot(minScore float64) bool { 551 minID, pivotPoint, abort := t.findMinID(minScore) 552 if abort { 553 return true 554 } 555 if pivotPoint == 0 { 556 return false 557 } 558 559 t.advanceAllAtLeast(minID) 560 sort.Sort(t) 561 return false 562 } 563 564 func (t terms) advanceAllAtLeast(minID uint64) { 565 for i := range t { 566 t[i].advanceAtLeast(minID) 567 } 568 } 569 570 func (t terms) findMinID(minScore float64) (uint64, int, bool) { 571 cumScore := float64(0) 572 573 for i, term := range t { 574 if term.exhausted { 575 continue 576 } 577 cumScore += term.idf 578 if cumScore >= minScore { 579 return term.idPointer, i, false 580 } 581 } 582 583 return 0, 0, true 584 } 585 586 func (t terms) findFirstNonExhausted() (int, bool) { 587 for i := range t { 588 if !t[i].exhausted { 589 return i, true 590 } 591 } 592 593 return -1, false 594 } 595 596 func (t terms) scoreNext(averagePropLength float64, config schema.BM25Config) (uint64, float64) { 597 pos, ok := t.findFirstNonExhausted() 598 if !ok { 599 // done, nothing left to score 600 return 0, 0 601 } 602 603 id := t[pos].idPointer 604 var cumScore float64 605 for i := pos; i < len(t); i++ { 606 if t[i].idPointer != id || t[i].exhausted { 607 continue 608 } 609 _, score := t[i].scoreAndAdvance(averagePropLength, config) 610 cumScore += score 611 } 612 613 sort.Sort(t) // pointer was advanced in scoreAndAdvance 614 615 return id, cumScore 616 } 617 618 // provide sort interface 619 func (t terms) Len() int { 620 return len(t) 621 } 622 623 func (t terms) Less(i, j int) bool { 624 return t[i].idPointer < t[j].idPointer 625 } 626 627 func (t terms) Swap(i, j int) { 628 t[i], t[j] = t[j], t[i] 629 } 630 631 type MapPairsAndPropName struct { 632 propname string 633 MapPairs []lsmkv.MapPair 634 } 635 636 type AllMapPairsAndPropName []MapPairsAndPropName 637 638 // provide sort interface 639 func (m AllMapPairsAndPropName) Len() int { 640 return len(m) 641 } 642 643 func (m AllMapPairsAndPropName) Less(i, j int) bool { 644 return len(m[i].MapPairs) < len(m[j].MapPairs) 645 } 646 647 func (m AllMapPairsAndPropName) Swap(i, j int) { 648 m[i], m[j] = m[j], m[i] 649 } 650 651 func PropertyHasSearchableIndex(schemaDefinition *models.Schema, className, tentativePropertyName string) bool { 652 propertyName := strings.Split(tentativePropertyName, "^")[0] 653 c, err := schema.GetClassByName(schemaDefinition, string(className)) 654 if err != nil { 655 return false 656 } 657 p, err := schema.GetPropertyByName(c, propertyName) 658 if err != nil { 659 return false 660 } 661 return HasSearchableIndex(p) 662 }