github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/rank.go (about)

     1  // Copyright 2023 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tsearch
    12  
    13  import (
    14  	"math"
    15  	"sort"
    16  	"strings"
    17  )
    18  
    19  // defaultWeights is the default list of weights corresponding to the tsvector
    20  // lexeme weights D, C, B, and A.
    21  var defaultWeights = [4]float32{0.1, 0.2, 0.4, 1.0}
    22  
    23  // Bitmask for the normalization integer. These define different ranking
    24  // behaviors. They're defined in Postgres in tsrank.c.
    25  // 0, the default, ignores the document length.
    26  // 1 devides the rank by 1 + the logarithm of the document length.
    27  // 2 divides the rank by the document length.
    28  // 4 divides the rank by the mean harmonic distance between extents.
    29  //
    30  //	NOTE: This is only implemented by ts_rank_cd, which is currently not
    31  //	implemented by CockroachDB. This constant is left for consistency with
    32  //	the original PostgreSQL source code.
    33  //
    34  // 8 divides the rank by the number of unique words in document.
    35  // 16 divides the rank by 1 + the logarithm of the number of unique words in document.
    36  // 32 divides the rank by itself + 1.
    37  type rankBehavior int
    38  
    39  const (
    40  	// rankNoNorm is the default. It ignores the document length.
    41  	rankNoNorm rankBehavior = 0x0
    42  	// rankNormLoglength divides the rank by 1 + the logarithm of the document length.
    43  	rankNormLoglength = 0x01
    44  	// rankNormLength divides the rank by the document length.
    45  	rankNormLength = 0x02
    46  	// rankNormExtdist divides the rank by the mean harmonic distance between extents.
    47  	// Note, this is only implemented by ts_rank_cd, which is not currently implemented
    48  	// by CockroachDB. The constant is kept for consistency with Postgres.
    49  	rankNormExtdist = 0x04
    50  	// rankNormUniq divides the rank by the number of unique words in document.
    51  	rankNormUniq = 0x08
    52  	// rankNormLoguniq divides the rank by 1 + the logarithm of the number of unique words in document.
    53  	rankNormLoguniq = 0x10
    54  	// rankNormRdivrplus1 divides the rank by itself + 1.
    55  	rankNormRdivrplus1 = 0x20
    56  )
    57  
    58  // Defeat the unused linter.
    59  var _ = rankNoNorm
    60  var _ = rankNormExtdist
    61  
    62  // cntLen returns the count of represented lexemes in a tsvector, including
    63  // the number of repeated lexemes in the vector.
    64  func cntLen(v TSVector) int {
    65  	var ret int
    66  	for i := range v {
    67  		posLen := len(v[i].positions)
    68  		if posLen > 0 {
    69  			ret += posLen
    70  		} else {
    71  			ret += 1
    72  		}
    73  	}
    74  	return ret
    75  }
    76  
    77  // Rank implements the ts_rank functionality, which ranks a tsvector against a
    78  // tsquery. The weights parameter is a list of weights corresponding to the
    79  // tsvector lexeme weights D, C, B, and A. The method parameter is a bitmask
    80  // defining different ranking behaviors, defined in the rankBehavior type
    81  // above in this file. The default ranking behavior is 0, which doesn't perform
    82  // any normalization based on the document length.
    83  //
    84  // N.B.: this function is directly translated from the calc_rank function in
    85  // tsrank.c, which contains almost no comments. As of this time, I am unable
    86  // to sufficiently explain how this ranker works, but I'm confident that the
    87  // implementation is at least compatible with Postgres.
    88  // https://github.com/postgres/postgres/blob/765f5df726918bcdcfd16bcc5418e48663d1dd59/src/backend/utils/adt/tsrank.c#L357
    89  func Rank(weights []float32, v TSVector, q TSQuery, method int) (float32, error) {
    90  	w := defaultWeights
    91  	if weights != nil {
    92  		copy(w[:4], weights[:4])
    93  	}
    94  	if len(v) == 0 || q.root == nil {
    95  		return 0, nil
    96  	}
    97  	var res float32
    98  	if q.root.op == and || q.root.op == followedby {
    99  		res = rankAnd(w, v, q)
   100  	} else {
   101  		res = rankOr(w, v, q)
   102  	}
   103  	if res < 0 {
   104  		// This constant is taken from the Postgres source code, unfortunately I
   105  		// don't understand its meaning.
   106  		res = 1e-20
   107  	}
   108  	if method&rankNormLoglength > 0 {
   109  		res /= float32(math.Log(float64(cntLen(v)+1)) / math.Log(2.0))
   110  	}
   111  
   112  	if method&rankNormLength > 0 {
   113  		l := cntLen(v)
   114  		if l > 0 {
   115  			res /= float32(l)
   116  		}
   117  	}
   118  	// rankNormExtDist is not applicable - it's only used for ts_rank_cd.
   119  
   120  	if method&rankNormUniq > 0 {
   121  		res /= float32(len(v))
   122  	}
   123  
   124  	if method&rankNormLoguniq > 0 {
   125  		res /= float32(math.Log(float64(len(v)+1)) / math.Log(2.0))
   126  	}
   127  
   128  	if method&rankNormRdivrplus1 > 0 {
   129  		res /= res + 1
   130  	}
   131  
   132  	return res, nil
   133  }
   134  
   135  func sortAndDistinctQueryTerms(q TSQuery) []*tsNode {
   136  	// Extract all leaf nodes from the query tree.
   137  	leafNodes := make([]*tsNode, 0)
   138  	var extractTerms func(q *tsNode)
   139  	extractTerms = func(q *tsNode) {
   140  		if q == nil {
   141  			return
   142  		}
   143  		if q.op != invalid {
   144  			extractTerms(q.l)
   145  			extractTerms(q.r)
   146  		} else {
   147  			leafNodes = append(leafNodes, q)
   148  		}
   149  	}
   150  	extractTerms(q.root)
   151  	// Sort the terms.
   152  	sort.Slice(leafNodes, func(i, j int) bool {
   153  		return leafNodes[i].term.lexeme < leafNodes[j].term.lexeme
   154  	})
   155  	// Then distinct: (wouldn't it be nice if Go had generics?)
   156  	lastUniqueIdx := 0
   157  	for j := 1; j < len(leafNodes); j++ {
   158  		if leafNodes[j].term.lexeme != leafNodes[lastUniqueIdx].term.lexeme {
   159  			// We found a unique entry, at index i. The last unique entry in the array
   160  			// was at lastUniqueIdx, so set the entry after that one to our new unique
   161  			// entry, and bump lastUniqueIdx for the next loop iteration.
   162  			lastUniqueIdx++
   163  			leafNodes[lastUniqueIdx] = leafNodes[j]
   164  		}
   165  	}
   166  	leafNodes = leafNodes[:lastUniqueIdx+1]
   167  	return leafNodes
   168  }
   169  
   170  // findRankMatches finds all matches for a given query term in a tsvector,
   171  // regardless of the expected query weight.
   172  // query is the term being matched. v is the tsvector being searched.
   173  // matches is a slice of matches to append to, to save on allocations as this
   174  // function is called in a loop.
   175  func findRankMatches(query *tsNode, v TSVector, matches [][]tsPosition) [][]tsPosition {
   176  	target := query.term.lexeme
   177  	i := sort.Search(len(v), func(i int) bool {
   178  		return v[i].lexeme >= target
   179  	})
   180  	if i >= len(v) {
   181  		return matches
   182  	}
   183  	if query.term.isPrefixMatch() {
   184  		for j := i; j < len(v); j++ {
   185  			t := v[j]
   186  			if !strings.HasPrefix(t.lexeme, target) {
   187  				break
   188  			}
   189  			matches = append(matches, t.positions)
   190  		}
   191  	} else if v[i].lexeme == target {
   192  		matches = append(matches, v[i].positions)
   193  	}
   194  	return matches
   195  }
   196  
   197  // rankOr computes the rank for a query with an OR operator at its root.
   198  // It takes the same parameters as TSRank.
   199  func rankOr(weights [4]float32, v TSVector, q TSQuery) float32 {
   200  	queryLeaves := sortAndDistinctQueryTerms(q)
   201  	var matches = make([][]tsPosition, 0)
   202  	var res float32
   203  	for i := range queryLeaves {
   204  		matches = matches[:0]
   205  		matches = findRankMatches(queryLeaves[i], v, matches)
   206  		if len(matches) == 0 {
   207  			continue
   208  		}
   209  		resj := float32(0.0)
   210  		wjm := float32(-1.0)
   211  		jm := 0
   212  		for _, innerMatches := range matches {
   213  			for j, pos := range innerMatches {
   214  				termWeight := pos.weight.val()
   215  				weight := weights[termWeight]
   216  				resj = resj + weight/float32((j+1)*(j+1))
   217  				if weight > wjm {
   218  					wjm = weight
   219  					jm = j
   220  				}
   221  			}
   222  		}
   223  		// Explanation from Postgres tsrank.c:
   224  		// limit (sum(1/i^2),i=1,inf) = pi^2/6
   225  		// resj = sum(wi/i^2),i=1,noccurence,
   226  		// wi - should be sorted desc,
   227  		// don't sort for now, just choose maximum weight. This should be corrected
   228  		// Oleg Bartunov
   229  		res = res + (wjm+resj-wjm/float32((jm+1)*(jm+1)))/1.64493406685
   230  	}
   231  	if len(queryLeaves) > 0 {
   232  		res /= float32(len(queryLeaves))
   233  	}
   234  	return res
   235  }
   236  
   237  // rankAnd computes the rank for a query with an AND or followed-by operator at
   238  // its root. It takes the same parameters as TSRank.
   239  func rankAnd(weights [4]float32, v TSVector, q TSQuery) float32 {
   240  	queryLeaves := sortAndDistinctQueryTerms(q)
   241  	if len(queryLeaves) < 2 {
   242  		return rankOr(weights, v, q)
   243  	}
   244  	pos := make([][]tsPosition, len(queryLeaves))
   245  	res := float32(-1)
   246  	var matches = make([][]tsPosition, 0)
   247  	for i := range queryLeaves {
   248  		matches = matches[:0]
   249  		matches = findRankMatches(queryLeaves[i], v, matches)
   250  		for _, innerMatches := range matches {
   251  			pos[i] = innerMatches
   252  			// Loop back through the earlier position matches
   253  			for k := 0; k < i; k++ {
   254  				if pos[k] == nil {
   255  					continue
   256  				}
   257  				for l := range pos[i] {
   258  					// For each of the earlier matches
   259  					for p := range pos[k] {
   260  						dist := int(pos[i][l].position) - int(pos[k][p].position)
   261  						if dist < 0 {
   262  							dist = -dist
   263  						}
   264  						if dist != 0 {
   265  							curw := float32(math.Sqrt(float64(weights[pos[i][l].weight.val()] * weights[pos[k][p].weight.val()] * wordDistance(dist))))
   266  							if res < 0 {
   267  								res = curw
   268  							} else {
   269  								res = 1.0 - (1.0-res)*(1.0-curw)
   270  							}
   271  						}
   272  					}
   273  				}
   274  			}
   275  		}
   276  	}
   277  	return res
   278  }
   279  
   280  // Returns a weight of a word collocation. See Postgres tsrank.c.
   281  func wordDistance(dist int) float32 {
   282  	if dist > 100 {
   283  		return 1e-30
   284  	}
   285  	return float32(1.0 / (1.005 + 0.05*math.Exp(float64(float32(dist)/1.5-2))))
   286  }