github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hybrid 13 14 import ( 15 "context" 16 "fmt" 17 18 "github.com/sirupsen/logrus" 19 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" 20 "github.com/weaviate/weaviate/entities/additional" 21 "github.com/weaviate/weaviate/entities/autocut" 22 "github.com/weaviate/weaviate/entities/schema" 23 "github.com/weaviate/weaviate/entities/search" 24 "github.com/weaviate/weaviate/entities/searchparams" 25 "github.com/weaviate/weaviate/entities/storobj" 26 uc "github.com/weaviate/weaviate/usecases/schema" 27 ) 28 29 const DefaultLimit = 100 30 31 type Params struct { 32 *searchparams.HybridSearch 33 Keyword *searchparams.KeywordRanking 34 Class string 35 Autocut int 36 } 37 38 // Result facilitates the pairing of a search result with its internal doc id. 39 // 40 // This type is key in generalising hybrid search across different use cases. 41 // Some use cases require a full search result (Get{} queries) and others need 42 // only a doc id (Aggregate{}) which the search.Result type does not contain. 43 // It does now 44 45 type Results []*search.Result 46 47 // sparseSearchFunc is the signature of a closure which performs sparse search. 48 // Any package which wishes use hybrid search must provide this. The weights are 49 // used in calculating the final scores of the result set. 50 type sparseSearchFunc func() (results []*storobj.Object, weights []float32, err error) 51 52 // denseSearchFunc is the signature of a closure which performs dense search. 53 // A search vector argument is required to pass along to the vector index. 54 // Any package which wishes use hybrid search must provide this The weights are 55 // used in calculating the final scores of the result set. 56 type denseSearchFunc func(searchVector []float32) (results []*storobj.Object, weights []float32, err error) 57 58 // postProcFunc takes the results of the hybrid search and applies some transformation. 59 // This is optionally provided, and allows the caller to somehow change the nature of 60 // the result set. For example, Get{} queries sometimes require resolving references, 61 // which is implemented by doing the reference resolution within a postProcFunc closure. 62 type postProcFunc func(hybridResults []*search.Result) (postProcResults []search.Result, err error) 63 64 type modulesProvider interface { 65 VectorFromInput(ctx context.Context, 66 className, input, targetVector string) ([]float32, error) 67 } 68 69 type targetVectorParamHelper interface { 70 GetTargetVectorOrDefault(sch schema.Schema, className, targetVector string) (string, error) 71 } 72 73 // Search executes sparse and dense searches and combines the result sets using Reciprocal Rank Fusion 74 func Search(ctx context.Context, params *Params, logger logrus.FieldLogger, sparseSearch sparseSearchFunc, 75 denseSearch denseSearchFunc, postProc postProcFunc, modules modulesProvider, 76 schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper, 77 ) ([]*search.Result, error) { 78 var ( 79 found [][]*search.Result 80 weights []float64 81 names []string 82 ) 83 84 if params.Query != "" { 85 alpha := params.Alpha 86 87 if alpha < 1 { 88 res, err := processSparseSearch(sparseSearch()) 89 if err != nil { 90 return nil, err 91 } 92 93 found = append(found, res) 94 weights = append(weights, 1-alpha) 95 names = append(names, "keyword") 96 } 97 98 if alpha > 0 { 99 res, err := processDenseSearch(ctx, denseSearch, params, modules, schemaGetter, targetVectorParamHelper) 100 if err != nil { 101 return nil, err 102 } 103 104 found = append(found, res) 105 weights = append(weights, alpha) 106 names = append(names, "vector") 107 } 108 } else if params.Vector != nil { 109 // Perform a plain vector search, no keyword query provided 110 res, err := processDenseSearch(ctx, denseSearch, params, modules, schemaGetter, targetVectorParamHelper) 111 if err != nil { 112 return nil, err 113 } 114 115 found = append(found, res) 116 // weight is irrelevant here, we're doing vector search only 117 weights = append(weights, 1) 118 names = append(names, "vector") 119 } else if params.SubSearches != nil { 120 ss := params.SubSearches 121 122 // To catch error if ss is empty 123 _, err := decideSearchVector(ctx, params, modules, schemaGetter, targetVectorParamHelper) 124 if err != nil { 125 return nil, err 126 } 127 128 for _, subsearch := range ss.([]searchparams.WeightedSearchResult) { 129 res, name, weight, err := handleSubSearch(ctx, &subsearch, denseSearch, sparseSearch, params, modules, schemaGetter, targetVectorParamHelper) 130 if err != nil { 131 return nil, err 132 } 133 134 if res == nil { 135 continue 136 } 137 138 found = append(found, res) 139 weights = append(weights, weight) 140 names = append(names, name) 141 } 142 } else { 143 // This should not happen, as it should be caught at the validation level, 144 // but just in case it does, we catch it here. 145 return nil, fmt.Errorf("no query, search vector, or sub-searches provided") 146 } 147 if len(weights) != len(found) { 148 return nil, fmt.Errorf("length of weights and results do not match for hybrid search %v vs. %v", len(weights), len(found)) 149 } 150 151 var fused []*search.Result 152 if params.FusionAlgorithm == common_filters.HybridRankedFusion { 153 fused = FusionRanked(weights, found, names) 154 } else if params.FusionAlgorithm == common_filters.HybridRelativeScoreFusion { 155 fused = FusionRelativeScore(weights, found, names) 156 } else { 157 return nil, fmt.Errorf("unknown ranking algorithm %v for hybrid search", params.FusionAlgorithm) 158 } 159 160 if postProc != nil { 161 sr, err := postProc(fused) 162 if err != nil { 163 return nil, fmt.Errorf("hybrid search post-processing: %w", err) 164 } 165 newResults := make([]*search.Result, len(sr)) 166 for i := range sr { 167 if err != nil { 168 return nil, fmt.Errorf("hybrid search post-processing: %w", err) 169 } 170 newResults[i] = &sr[i] 171 } 172 fused = newResults 173 } 174 if params.Autocut > 0 { 175 scores := make([]float32, len(fused)) 176 for i := range fused { 177 scores[i] = fused[i].Score 178 } 179 cutOff := autocut.Autocut(scores, params.Autocut) 180 fused = fused[:cutOff] 181 } 182 return fused, nil 183 } 184 185 func processSparseSearch(results []*storobj.Object, scores []float32, err error) ([]*search.Result, error) { 186 if err != nil { 187 return nil, fmt.Errorf("sparse search: %w", err) 188 } 189 190 out := make([]*search.Result, len(results)) 191 for i, obj := range results { 192 sr := obj.SearchResultWithScore(additional.Properties{}, scores[i]) 193 sr.SecondarySortValue = sr.Score 194 out[i] = &sr 195 } 196 return out, nil 197 } 198 199 func processDenseSearch(ctx context.Context, 200 denseSearch denseSearchFunc, params *Params, modules modulesProvider, 201 schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper, 202 ) ([]*search.Result, error) { 203 vector, err := decideSearchVector(ctx, params, modules, schemaGetter, targetVectorParamHelper) 204 if err != nil { 205 return nil, err 206 } 207 208 res, dists, err := denseSearch(vector) 209 if err != nil { 210 return nil, fmt.Errorf("dense search: %w", err) 211 } 212 213 out := make([]*search.Result, len(res)) 214 for i, obj := range res { 215 sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) 216 sr.SecondarySortValue = 1 - sr.Dist 217 out[i] = &sr 218 } 219 return out, nil 220 } 221 222 func handleSubSearch(ctx context.Context, 223 subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, sparseSearch sparseSearchFunc, 224 params *Params, modules modulesProvider, 225 schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper, 226 ) ([]*search.Result, string, float64, error) { 227 switch subsearch.Type { 228 case "bm25": 229 fallthrough 230 case "sparseSearch": 231 return sparseSubSearch(subsearch, params, sparseSearch) 232 case "nearText": 233 return nearTextSubSearch(ctx, subsearch, denseSearch, params, modules, schemaGetter, targetVectorParamHelper) 234 case "nearVector": 235 return nearVectorSubSearch(subsearch, denseSearch) 236 default: 237 return nil, "unknown", 0, fmt.Errorf("unknown hybrid search type %q", subsearch.Type) 238 } 239 } 240 241 func sparseSubSearch(subsearch *searchparams.WeightedSearchResult, params *Params, sparseSearch sparseSearchFunc) ([]*search.Result, string, float64, error) { 242 sp := subsearch.SearchParams.(searchparams.KeywordRanking) 243 params.Keyword = &sp 244 245 res, dists, err := sparseSearch() 246 if err != nil { 247 return nil, "", 0, fmt.Errorf("sparse subsearch: %w", err) 248 } 249 250 out := make([]*search.Result, len(res)) 251 for i, obj := range res { 252 sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) 253 sr.SecondarySortValue = sr.Score 254 out[i] = &sr 255 } 256 257 return out, "bm25f", subsearch.Weight, nil 258 } 259 260 func nearTextSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, 261 params *Params, modules modulesProvider, 262 schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper, 263 ) ([]*search.Result, string, float64, error) { 264 sp := subsearch.SearchParams.(searchparams.NearTextParams) 265 if modules == nil || schemaGetter == nil || targetVectorParamHelper == nil { 266 return nil, "", 0, nil 267 } 268 269 targetVector := getTargetVector(params.TargetVectors) 270 targetVector, err := targetVectorParamHelper.GetTargetVectorOrDefault(schemaGetter.GetSchemaSkipAuth(), 271 params.Class, targetVector) 272 if err != nil { 273 return nil, "", 0, err 274 } 275 276 vector, err := vectorFromModuleInput(ctx, params.Class, sp.Values[0], targetVector, modules) 277 if err != nil { 278 return nil, "", 0, err 279 } 280 281 res, dists, err := denseSearch(vector) 282 if err != nil { 283 return nil, "", 0, err 284 } 285 286 out := make([]*search.Result, len(res)) 287 for i, obj := range res { 288 sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) 289 sr.SecondarySortValue = 1 - sr.Dist 290 out[i] = &sr 291 } 292 293 return out, "vector,nearText", subsearch.Weight, nil 294 } 295 296 func nearVectorSubSearch(subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc) ([]*search.Result, string, float64, error) { 297 sp := subsearch.SearchParams.(searchparams.NearVector) 298 299 res, dists, err := denseSearch(sp.Vector) 300 if err != nil { 301 return nil, "", 0, err 302 } 303 304 out := make([]*search.Result, len(res)) 305 for i, obj := range res { 306 sr := obj.SearchResultWithDist(additional.Properties{}, dists[i]) 307 sr.SecondarySortValue = 1 - sr.Dist 308 out[i] = &sr 309 } 310 311 return out, "vector,nearVector", subsearch.Weight, nil 312 } 313 314 func decideSearchVector(ctx context.Context, 315 params *Params, modules modulesProvider, 316 schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper, 317 ) ([]float32, error) { 318 var ( 319 vector []float32 320 err error 321 ) 322 323 if params.Vector != nil && len(params.Vector) != 0 { 324 vector = params.Vector 325 } else { 326 if modules != nil && schemaGetter != nil && targetVectorParamHelper != nil { 327 targetVector := getTargetVector(params.TargetVectors) 328 targetVector, err = targetVectorParamHelper.GetTargetVectorOrDefault(schemaGetter.GetSchemaSkipAuth(), 329 params.Class, targetVector) 330 if err != nil { 331 return nil, err 332 } 333 vector, err = vectorFromModuleInput(ctx, params.Class, params.Query, targetVector, modules) 334 if err != nil { 335 return nil, err 336 } 337 } 338 } 339 340 return vector, nil 341 } 342 343 func vectorFromModuleInput(ctx context.Context, class, input, targetVector string, modules modulesProvider) ([]float32, error) { 344 vector, err := modules.VectorFromInput(ctx, class, input, targetVector) 345 if err != nil { 346 return nil, fmt.Errorf("get vector input from modules provider: %w", err) 347 } 348 return vector, nil 349 } 350 351 func getTargetVector(targetVectors []string) string { 352 if len(targetVectors) == 1 { 353 return targetVectors[0] 354 } 355 return "" 356 }