github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/hybrid_fusion.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hybrid
    13  
    14  import (
    15  	"fmt"
    16  	"sort"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/weaviate/weaviate/entities/search"
    20  )
    21  
    22  func FusionRanked(weights []float64, resultSets [][]*search.Result, setNames []string) []*search.Result {
    23  	combinedResults := map[strfmt.UUID]*search.Result{}
    24  	for resultSetIndex, resultSet := range resultSets {
    25  		for i, res := range resultSet {
    26  			if res.DocID == nil {
    27  				panic("doc id is nil")
    28  			}
    29  			tempResult := res
    30  			docId := tempResult.ID
    31  			score := weights[resultSetIndex] / float64(i+60) // TODO replace 60 with a class configured variable in the schema
    32  
    33  			if tempResult.AdditionalProperties == nil {
    34  				tempResult.AdditionalProperties = map[string]interface{}{}
    35  			}
    36  
    37  			// Get previous results from the map, if any
    38  			previousResult, ok := combinedResults[docId]
    39  			if ok {
    40  				tempResult.AdditionalProperties["explainScore"] = fmt.Sprintf(
    41  					"%v\nHybrid (Result Set %v) Document %v contributed %v to the score",
    42  					previousResult.AdditionalProperties["explainScore"], setNames[resultSetIndex], tempResult.ID, score)
    43  				score += float64(previousResult.Score)
    44  			} else {
    45  				tempResult.AdditionalProperties["explainScore"] = fmt.Sprintf(
    46  					"%v\nHybrid (Result Set %v) Document %v contributed %v to the score",
    47  					tempResult.ExplainScore, setNames[resultSetIndex], tempResult.ID, score)
    48  			}
    49  			tempResult.AdditionalProperties["rank_score"] = score
    50  			tempResult.AdditionalProperties["score"] = score
    51  
    52  			tempResult.Score = float32(score)
    53  			combinedResults[docId] = tempResult
    54  		}
    55  	}
    56  
    57  	// Sort the results
    58  	var (
    59  		concat = make([]*search.Result, len(combinedResults))
    60  		i      = 0
    61  	)
    62  	for _, res := range combinedResults {
    63  		res.ExplainScore = res.AdditionalProperties["explainScore"].(string)
    64  		concat[i] = res
    65  		i++
    66  	}
    67  
    68  	sort.Slice(concat, func(i, j int) bool {
    69  		if concat[j].Score == concat[i].Score {
    70  			return concat[i].SecondarySortValue > concat[j].SecondarySortValue
    71  		}
    72  		return float64(concat[i].Score) > float64(concat[j].Score)
    73  	})
    74  	return concat
    75  }
    76  
    77  // FusionRelativeScore uses the relative differences in the scores from keyword and vector search to combine the
    78  // results. This method retains more information than ranked fusion and should result in better results.
    79  //
    80  // The scores from each result are normalized between 0 and 1, e.g. the maximum score becomes 1 and the minimum 0 and the
    81  // other scores are in between, keeping their relative distance to the other scores.
    82  // Example:
    83  //
    84  //	Input score = [1, 8, 6, 11] => [0, 0.7, 0.5, 1]
    85  //
    86  // The normalized scores are then combined using their respective weight and the combined scores are sorted
    87  func FusionRelativeScore(weights []float64, resultSets [][]*search.Result, names []string) []*search.Result {
    88  	if len(resultSets[0]) == 0 && (len(resultSets) == 1 || len(resultSets[1]) == 0) {
    89  		return []*search.Result{}
    90  	}
    91  
    92  	var maximum []float32
    93  	var minimum []float32
    94  
    95  	for i := range resultSets {
    96  		if len(resultSets[i]) > 0 {
    97  			maximum = append(maximum, resultSets[i][0].SecondarySortValue)
    98  			minimum = append(minimum, resultSets[i][0].SecondarySortValue)
    99  		} else { // dummy values so the indices match
   100  			maximum = append(maximum, 0)
   101  			minimum = append(minimum, 0)
   102  		}
   103  		for _, res := range resultSets[i] {
   104  			if res.SecondarySortValue > maximum[i] {
   105  				maximum[i] = res.SecondarySortValue
   106  			}
   107  
   108  			if res.SecondarySortValue < minimum[i] {
   109  				minimum[i] = res.SecondarySortValue
   110  			}
   111  		}
   112  	}
   113  
   114  	// normalize scores between 0 and 1 and sum up the normalized scores from different sources
   115  	// pre-allocate map, at this stage we do not know how many total, combined results there are, but it is at least the
   116  	// length of the longer input list
   117  	numResults := len(resultSets[0])
   118  	if len(resultSets) > 1 && len(resultSets[1]) > numResults {
   119  		numResults = len(resultSets[1])
   120  	}
   121  	mapResults := make(map[strfmt.UUID]*search.Result, numResults)
   122  	for i := range resultSets {
   123  		weight := float32(weights[i])
   124  		for _, res := range resultSets[i] {
   125  			// If all scores are identical min and max are the same => just set score to the weight.
   126  			score := weight
   127  			if maximum[i] != minimum[i] {
   128  				score *= (res.SecondarySortValue - minimum[i]) / (maximum[i] - minimum[i])
   129  			}
   130  
   131  			previousResult, ok := mapResults[res.ID]
   132  			explainScore := fmt.Sprintf("Hybrid (Result Set %v) Document %v: original score %v, normalized score: %v", names[i], res.ID, res.SecondarySortValue, score)
   133  			if ok {
   134  				score += previousResult.Score
   135  				explainScore += " - " + previousResult.ExplainScore
   136  			}
   137  			res.Score = score
   138  			res.ExplainScore = res.ExplainScore + "\n" + explainScore
   139  
   140  			mapResults[res.ID] = res
   141  		}
   142  	}
   143  
   144  	concat := make([]*search.Result, 0, len(mapResults))
   145  	for _, res := range mapResults {
   146  		concat = append(concat, res)
   147  	}
   148  
   149  	sort.Slice(concat, func(i, j int) bool {
   150  		a_b := float64(concat[j].Score - concat[i].Score)
   151  		if a_b*a_b < 1e-14 {
   152  			return concat[i].SecondarySortValue > concat[j].SecondarySortValue
   153  		}
   154  		return float64(concat[i].Score) > float64(concat[j].Score)
   155  	})
   156  	return concat
   157  }