github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/aggregator/vector_search.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package aggregator 13 14 import ( 15 "context" 16 "fmt" 17 18 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 19 "github.com/weaviate/weaviate/adapters/repos/db/inverted" 20 "github.com/weaviate/weaviate/entities/additional" 21 "github.com/weaviate/weaviate/entities/storobj" 22 ) 23 24 func (a *Aggregator) vectorSearch(allow helpers.AllowList, vec []float32) ([]uint64, []float32, error) { 25 if a.params.ObjectLimit != nil { 26 return a.searchByVector(vec, a.params.ObjectLimit, allow) 27 } 28 29 return a.searchByVectorDistance(vec, allow) 30 } 31 32 func (a *Aggregator) searchByVector(searchVector []float32, limit *int, ids helpers.AllowList) ([]uint64, []float32, error) { 33 idsFound, dists, err := a.vectorIndex.SearchByVector(searchVector, *limit, ids) 34 if err != nil { 35 return idsFound, nil, err 36 } 37 38 if a.params.Certainty > 0 { 39 targetDist := float32(1-a.params.Certainty) * 2 40 41 i := 0 42 for _, dist := range dists { 43 if dist > targetDist { 44 break 45 } 46 i++ 47 } 48 49 return idsFound[:i], dists, nil 50 51 } 52 return idsFound, dists, nil 53 } 54 55 func (a *Aggregator) searchByVectorDistance(searchVector []float32, ids helpers.AllowList) ([]uint64, []float32, error) { 56 if a.params.Certainty <= 0 { 57 return nil, nil, fmt.Errorf("must provide certainty or objectLimit with vector search") 58 } 59 60 targetDist := float32(1-a.params.Certainty) * 2 61 idsFound, dists, err := a.vectorIndex.SearchByVectorDistance(searchVector, targetDist, -1, ids) 62 if err != nil { 63 return nil, nil, fmt.Errorf("aggregate search by vector: %w", err) 64 } 65 66 return idsFound, dists, nil 67 } 68 69 func (a *Aggregator) objectVectorSearch(searchVector []float32, 70 allowList helpers.AllowList, 71 ) ([]*storobj.Object, []float32, error) { 72 ids, dists, err := a.vectorSearch(allowList, searchVector) 73 if err != nil { 74 return nil, nil, err 75 } 76 77 bucket := a.store.Bucket(helpers.ObjectsBucketLSM) 78 objs, err := storobj.ObjectsByDocID(bucket, ids, additional.Properties{}) 79 if err != nil { 80 return nil, nil, fmt.Errorf("get objects by doc id: %w", err) 81 } 82 return objs, dists, nil 83 } 84 85 func (a *Aggregator) buildAllowList(ctx context.Context) (helpers.AllowList, error) { 86 var ( 87 allow helpers.AllowList 88 err error 89 ) 90 91 if a.params.Filters != nil { 92 s := a.getSchema.GetSchemaSkipAuth() 93 allow, err = inverted.NewSearcher(a.logger, a.store, s, nil, 94 a.classSearcher, a.stopwords, a.shardVersion, a.isFallbackToSearchable, 95 a.tenant, a.nestedCrossRefLimit, a.bitmapFactory). 96 DocIDs(ctx, a.params.Filters, additional.Properties{}, 97 a.params.ClassName) 98 if err != nil { 99 return nil, fmt.Errorf("retrieve doc IDs from searcher: %w", err) 100 } 101 } 102 103 return allow, nil 104 }