github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hybrid
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  
    18  	"github.com/sirupsen/logrus"
    19  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
    20  	"github.com/weaviate/weaviate/entities/additional"
    21  	"github.com/weaviate/weaviate/entities/autocut"
    22  	"github.com/weaviate/weaviate/entities/schema"
    23  	"github.com/weaviate/weaviate/entities/search"
    24  	"github.com/weaviate/weaviate/entities/searchparams"
    25  	"github.com/weaviate/weaviate/entities/storobj"
    26  	uc "github.com/weaviate/weaviate/usecases/schema"
    27  )
    28  
    29  const DefaultLimit = 100
    30  
    31  type Params struct {
    32  	*searchparams.HybridSearch
    33  	Keyword *searchparams.KeywordRanking
    34  	Class   string
    35  	Autocut int
    36  }
    37  
    38  // Result facilitates the pairing of a search result with its internal doc id.
    39  //
    40  // This type is key in generalising hybrid search across different use cases.
    41  // Some use cases require a full search result (Get{} queries) and others need
    42  // only a doc id (Aggregate{}) which the search.Result type does not contain.
    43  // It does now
    44  
    45  type Results []*search.Result
    46  
    47  // sparseSearchFunc is the signature of a closure which performs sparse search.
    48  // Any package which wishes use hybrid search must provide this. The weights are
    49  // used in calculating the final scores of the result set.
    50  type sparseSearchFunc func() (results []*storobj.Object, weights []float32, err error)
    51  
    52  // denseSearchFunc is the signature of a closure which performs dense search.
    53  // A search vector argument is required to pass along to the vector index.
    54  // Any package which wishes use hybrid search must provide this The weights are
    55  // used in calculating the final scores of the result set.
    56  type denseSearchFunc func(searchVector []float32) (results []*storobj.Object, weights []float32, err error)
    57  
    58  // postProcFunc takes the results of the hybrid search and applies some transformation.
    59  // This is optionally provided, and allows the caller to somehow change the nature of
    60  // the result set. For example, Get{} queries sometimes require resolving references,
    61  // which is implemented by doing the reference resolution within a postProcFunc closure.
    62  type postProcFunc func(hybridResults []*search.Result) (postProcResults []search.Result, err error)
    63  
    64  type modulesProvider interface {
    65  	VectorFromInput(ctx context.Context,
    66  		className, input, targetVector string) ([]float32, error)
    67  }
    68  
    69  type targetVectorParamHelper interface {
    70  	GetTargetVectorOrDefault(sch schema.Schema, className, targetVector string) (string, error)
    71  }
    72  
    73  // Search executes sparse and dense searches and combines the result sets using Reciprocal Rank Fusion
    74  func Search(ctx context.Context, params *Params, logger logrus.FieldLogger, sparseSearch sparseSearchFunc,
    75  	denseSearch denseSearchFunc, postProc postProcFunc, modules modulesProvider,
    76  	schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper,
    77  ) ([]*search.Result, error) {
    78  	var (
    79  		found   [][]*search.Result
    80  		weights []float64
    81  		names   []string
    82  	)
    83  
    84  	if params.Query != "" {
    85  		alpha := params.Alpha
    86  
    87  		if alpha < 1 {
    88  			res, err := processSparseSearch(sparseSearch())
    89  			if err != nil {
    90  				return nil, err
    91  			}
    92  
    93  			found = append(found, res)
    94  			weights = append(weights, 1-alpha)
    95  			names = append(names, "keyword")
    96  		}
    97  
    98  		if alpha > 0 {
    99  			res, err := processDenseSearch(ctx, denseSearch, params, modules, schemaGetter, targetVectorParamHelper)
   100  			if err != nil {
   101  				return nil, err
   102  			}
   103  
   104  			found = append(found, res)
   105  			weights = append(weights, alpha)
   106  			names = append(names, "vector")
   107  		}
   108  	} else if params.Vector != nil {
   109  		// Perform a plain vector search, no keyword query provided
   110  		res, err := processDenseSearch(ctx, denseSearch, params, modules, schemaGetter, targetVectorParamHelper)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  
   115  		found = append(found, res)
   116  		// weight is irrelevant here, we're doing vector search only
   117  		weights = append(weights, 1)
   118  		names = append(names, "vector")
   119  	} else if params.SubSearches != nil {
   120  		ss := params.SubSearches
   121  
   122  		// To catch error if ss is empty
   123  		_, err := decideSearchVector(ctx, params, modules, schemaGetter, targetVectorParamHelper)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  
   128  		for _, subsearch := range ss.([]searchparams.WeightedSearchResult) {
   129  			res, name, weight, err := handleSubSearch(ctx, &subsearch, denseSearch, sparseSearch, params, modules, schemaGetter, targetVectorParamHelper)
   130  			if err != nil {
   131  				return nil, err
   132  			}
   133  
   134  			if res == nil {
   135  				continue
   136  			}
   137  
   138  			found = append(found, res)
   139  			weights = append(weights, weight)
   140  			names = append(names, name)
   141  		}
   142  	} else {
   143  		// This should not happen, as it should be caught at the validation level,
   144  		// but just in case it does, we catch it here.
   145  		return nil, fmt.Errorf("no query, search vector, or sub-searches provided")
   146  	}
   147  	if len(weights) != len(found) {
   148  		return nil, fmt.Errorf("length of weights and results do not match for hybrid search %v vs. %v", len(weights), len(found))
   149  	}
   150  
   151  	var fused []*search.Result
   152  	if params.FusionAlgorithm == common_filters.HybridRankedFusion {
   153  		fused = FusionRanked(weights, found, names)
   154  	} else if params.FusionAlgorithm == common_filters.HybridRelativeScoreFusion {
   155  		fused = FusionRelativeScore(weights, found, names)
   156  	} else {
   157  		return nil, fmt.Errorf("unknown ranking algorithm %v for hybrid search", params.FusionAlgorithm)
   158  	}
   159  
   160  	if postProc != nil {
   161  		sr, err := postProc(fused)
   162  		if err != nil {
   163  			return nil, fmt.Errorf("hybrid search post-processing: %w", err)
   164  		}
   165  		newResults := make([]*search.Result, len(sr))
   166  		for i := range sr {
   167  			if err != nil {
   168  				return nil, fmt.Errorf("hybrid search post-processing: %w", err)
   169  			}
   170  			newResults[i] = &sr[i]
   171  		}
   172  		fused = newResults
   173  	}
   174  	if params.Autocut > 0 {
   175  		scores := make([]float32, len(fused))
   176  		for i := range fused {
   177  			scores[i] = fused[i].Score
   178  		}
   179  		cutOff := autocut.Autocut(scores, params.Autocut)
   180  		fused = fused[:cutOff]
   181  	}
   182  	return fused, nil
   183  }
   184  
   185  func processSparseSearch(results []*storobj.Object, scores []float32, err error) ([]*search.Result, error) {
   186  	if err != nil {
   187  		return nil, fmt.Errorf("sparse search: %w", err)
   188  	}
   189  
   190  	out := make([]*search.Result, len(results))
   191  	for i, obj := range results {
   192  		sr := obj.SearchResultWithScore(additional.Properties{}, scores[i])
   193  		sr.SecondarySortValue = sr.Score
   194  		out[i] = &sr
   195  	}
   196  	return out, nil
   197  }
   198  
   199  func processDenseSearch(ctx context.Context,
   200  	denseSearch denseSearchFunc, params *Params, modules modulesProvider,
   201  	schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper,
   202  ) ([]*search.Result, error) {
   203  	vector, err := decideSearchVector(ctx, params, modules, schemaGetter, targetVectorParamHelper)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	res, dists, err := denseSearch(vector)
   209  	if err != nil {
   210  		return nil, fmt.Errorf("dense search: %w", err)
   211  	}
   212  
   213  	out := make([]*search.Result, len(res))
   214  	for i, obj := range res {
   215  		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
   216  		sr.SecondarySortValue = 1 - sr.Dist
   217  		out[i] = &sr
   218  	}
   219  	return out, nil
   220  }
   221  
   222  func handleSubSearch(ctx context.Context,
   223  	subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc, sparseSearch sparseSearchFunc,
   224  	params *Params, modules modulesProvider,
   225  	schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper,
   226  ) ([]*search.Result, string, float64, error) {
   227  	switch subsearch.Type {
   228  	case "bm25":
   229  		fallthrough
   230  	case "sparseSearch":
   231  		return sparseSubSearch(subsearch, params, sparseSearch)
   232  	case "nearText":
   233  		return nearTextSubSearch(ctx, subsearch, denseSearch, params, modules, schemaGetter, targetVectorParamHelper)
   234  	case "nearVector":
   235  		return nearVectorSubSearch(subsearch, denseSearch)
   236  	default:
   237  		return nil, "unknown", 0, fmt.Errorf("unknown hybrid search type %q", subsearch.Type)
   238  	}
   239  }
   240  
   241  func sparseSubSearch(subsearch *searchparams.WeightedSearchResult, params *Params, sparseSearch sparseSearchFunc) ([]*search.Result, string, float64, error) {
   242  	sp := subsearch.SearchParams.(searchparams.KeywordRanking)
   243  	params.Keyword = &sp
   244  
   245  	res, dists, err := sparseSearch()
   246  	if err != nil {
   247  		return nil, "", 0, fmt.Errorf("sparse subsearch: %w", err)
   248  	}
   249  
   250  	out := make([]*search.Result, len(res))
   251  	for i, obj := range res {
   252  		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
   253  		sr.SecondarySortValue = sr.Score
   254  		out[i] = &sr
   255  	}
   256  
   257  	return out, "bm25f", subsearch.Weight, nil
   258  }
   259  
   260  func nearTextSubSearch(ctx context.Context, subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc,
   261  	params *Params, modules modulesProvider,
   262  	schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper,
   263  ) ([]*search.Result, string, float64, error) {
   264  	sp := subsearch.SearchParams.(searchparams.NearTextParams)
   265  	if modules == nil || schemaGetter == nil || targetVectorParamHelper == nil {
   266  		return nil, "", 0, nil
   267  	}
   268  
   269  	targetVector := getTargetVector(params.TargetVectors)
   270  	targetVector, err := targetVectorParamHelper.GetTargetVectorOrDefault(schemaGetter.GetSchemaSkipAuth(),
   271  		params.Class, targetVector)
   272  	if err != nil {
   273  		return nil, "", 0, err
   274  	}
   275  
   276  	vector, err := vectorFromModuleInput(ctx, params.Class, sp.Values[0], targetVector, modules)
   277  	if err != nil {
   278  		return nil, "", 0, err
   279  	}
   280  
   281  	res, dists, err := denseSearch(vector)
   282  	if err != nil {
   283  		return nil, "", 0, err
   284  	}
   285  
   286  	out := make([]*search.Result, len(res))
   287  	for i, obj := range res {
   288  		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
   289  		sr.SecondarySortValue = 1 - sr.Dist
   290  		out[i] = &sr
   291  	}
   292  
   293  	return out, "vector,nearText", subsearch.Weight, nil
   294  }
   295  
   296  func nearVectorSubSearch(subsearch *searchparams.WeightedSearchResult, denseSearch denseSearchFunc) ([]*search.Result, string, float64, error) {
   297  	sp := subsearch.SearchParams.(searchparams.NearVector)
   298  
   299  	res, dists, err := denseSearch(sp.Vector)
   300  	if err != nil {
   301  		return nil, "", 0, err
   302  	}
   303  
   304  	out := make([]*search.Result, len(res))
   305  	for i, obj := range res {
   306  		sr := obj.SearchResultWithDist(additional.Properties{}, dists[i])
   307  		sr.SecondarySortValue = 1 - sr.Dist
   308  		out[i] = &sr
   309  	}
   310  
   311  	return out, "vector,nearVector", subsearch.Weight, nil
   312  }
   313  
   314  func decideSearchVector(ctx context.Context,
   315  	params *Params, modules modulesProvider,
   316  	schemaGetter uc.SchemaGetter, targetVectorParamHelper targetVectorParamHelper,
   317  ) ([]float32, error) {
   318  	var (
   319  		vector []float32
   320  		err    error
   321  	)
   322  
   323  	if params.Vector != nil && len(params.Vector) != 0 {
   324  		vector = params.Vector
   325  	} else {
   326  		if modules != nil && schemaGetter != nil && targetVectorParamHelper != nil {
   327  			targetVector := getTargetVector(params.TargetVectors)
   328  			targetVector, err = targetVectorParamHelper.GetTargetVectorOrDefault(schemaGetter.GetSchemaSkipAuth(),
   329  				params.Class, targetVector)
   330  			if err != nil {
   331  				return nil, err
   332  			}
   333  			vector, err = vectorFromModuleInput(ctx, params.Class, params.Query, targetVector, modules)
   334  			if err != nil {
   335  				return nil, err
   336  			}
   337  		}
   338  	}
   339  
   340  	return vector, nil
   341  }
   342  
   343  func vectorFromModuleInput(ctx context.Context, class, input, targetVector string, modules modulesProvider) ([]float32, error) {
   344  	vector, err := modules.VectorFromInput(ctx, class, input, targetVector)
   345  	if err != nil {
   346  		return nil, fmt.Errorf("get vector input from modules provider: %w", err)
   347  	}
   348  	return vector, nil
   349  }
   350  
   351  func getTargetVector(targetVectors []string) string {
   352  	if len(targetVectors) == 1 {
   353  		return targetVectors[0]
   354  	}
   355  	return ""
   356  }