github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/rank.go (about) 1 // Copyright 2023 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tsearch 12 13 import ( 14 "math" 15 "sort" 16 "strings" 17 ) 18 19 // defaultWeights is the default list of weights corresponding to the tsvector 20 // lexeme weights D, C, B, and A. 21 var defaultWeights = [4]float32{0.1, 0.2, 0.4, 1.0} 22 23 // Bitmask for the normalization integer. These define different ranking 24 // behaviors. They're defined in Postgres in tsrank.c. 25 // 0, the default, ignores the document length. 26 // 1 devides the rank by 1 + the logarithm of the document length. 27 // 2 divides the rank by the document length. 28 // 4 divides the rank by the mean harmonic distance between extents. 29 // 30 // NOTE: This is only implemented by ts_rank_cd, which is currently not 31 // implemented by CockroachDB. This constant is left for consistency with 32 // the original PostgreSQL source code. 33 // 34 // 8 divides the rank by the number of unique words in document. 35 // 16 divides the rank by 1 + the logarithm of the number of unique words in document. 36 // 32 divides the rank by itself + 1. 37 type rankBehavior int 38 39 const ( 40 // rankNoNorm is the default. It ignores the document length. 41 rankNoNorm rankBehavior = 0x0 42 // rankNormLoglength divides the rank by 1 + the logarithm of the document length. 43 rankNormLoglength = 0x01 44 // rankNormLength divides the rank by the document length. 45 rankNormLength = 0x02 46 // rankNormExtdist divides the rank by the mean harmonic distance between extents. 47 // Note, this is only implemented by ts_rank_cd, which is not currently implemented 48 // by CockroachDB. The constant is kept for consistency with Postgres. 49 rankNormExtdist = 0x04 50 // rankNormUniq divides the rank by the number of unique words in document. 51 rankNormUniq = 0x08 52 // rankNormLoguniq divides the rank by 1 + the logarithm of the number of unique words in document. 53 rankNormLoguniq = 0x10 54 // rankNormRdivrplus1 divides the rank by itself + 1. 55 rankNormRdivrplus1 = 0x20 56 ) 57 58 // Defeat the unused linter. 59 var _ = rankNoNorm 60 var _ = rankNormExtdist 61 62 // cntLen returns the count of represented lexemes in a tsvector, including 63 // the number of repeated lexemes in the vector. 64 func cntLen(v TSVector) int { 65 var ret int 66 for i := range v { 67 posLen := len(v[i].positions) 68 if posLen > 0 { 69 ret += posLen 70 } else { 71 ret += 1 72 } 73 } 74 return ret 75 } 76 77 // Rank implements the ts_rank functionality, which ranks a tsvector against a 78 // tsquery. The weights parameter is a list of weights corresponding to the 79 // tsvector lexeme weights D, C, B, and A. The method parameter is a bitmask 80 // defining different ranking behaviors, defined in the rankBehavior type 81 // above in this file. The default ranking behavior is 0, which doesn't perform 82 // any normalization based on the document length. 83 // 84 // N.B.: this function is directly translated from the calc_rank function in 85 // tsrank.c, which contains almost no comments. As of this time, I am unable 86 // to sufficiently explain how this ranker works, but I'm confident that the 87 // implementation is at least compatible with Postgres. 88 // https://github.com/postgres/postgres/blob/765f5df726918bcdcfd16bcc5418e48663d1dd59/src/backend/utils/adt/tsrank.c#L357 89 func Rank(weights []float32, v TSVector, q TSQuery, method int) (float32, error) { 90 w := defaultWeights 91 if weights != nil { 92 copy(w[:4], weights[:4]) 93 } 94 if len(v) == 0 || q.root == nil { 95 return 0, nil 96 } 97 var res float32 98 if q.root.op == and || q.root.op == followedby { 99 res = rankAnd(w, v, q) 100 } else { 101 res = rankOr(w, v, q) 102 } 103 if res < 0 { 104 // This constant is taken from the Postgres source code, unfortunately I 105 // don't understand its meaning. 106 res = 1e-20 107 } 108 if method&rankNormLoglength > 0 { 109 res /= float32(math.Log(float64(cntLen(v)+1)) / math.Log(2.0)) 110 } 111 112 if method&rankNormLength > 0 { 113 l := cntLen(v) 114 if l > 0 { 115 res /= float32(l) 116 } 117 } 118 // rankNormExtDist is not applicable - it's only used for ts_rank_cd. 119 120 if method&rankNormUniq > 0 { 121 res /= float32(len(v)) 122 } 123 124 if method&rankNormLoguniq > 0 { 125 res /= float32(math.Log(float64(len(v)+1)) / math.Log(2.0)) 126 } 127 128 if method&rankNormRdivrplus1 > 0 { 129 res /= res + 1 130 } 131 132 return res, nil 133 } 134 135 func sortAndDistinctQueryTerms(q TSQuery) []*tsNode { 136 // Extract all leaf nodes from the query tree. 137 leafNodes := make([]*tsNode, 0) 138 var extractTerms func(q *tsNode) 139 extractTerms = func(q *tsNode) { 140 if q == nil { 141 return 142 } 143 if q.op != invalid { 144 extractTerms(q.l) 145 extractTerms(q.r) 146 } else { 147 leafNodes = append(leafNodes, q) 148 } 149 } 150 extractTerms(q.root) 151 // Sort the terms. 152 sort.Slice(leafNodes, func(i, j int) bool { 153 return leafNodes[i].term.lexeme < leafNodes[j].term.lexeme 154 }) 155 // Then distinct: (wouldn't it be nice if Go had generics?) 156 lastUniqueIdx := 0 157 for j := 1; j < len(leafNodes); j++ { 158 if leafNodes[j].term.lexeme != leafNodes[lastUniqueIdx].term.lexeme { 159 // We found a unique entry, at index i. The last unique entry in the array 160 // was at lastUniqueIdx, so set the entry after that one to our new unique 161 // entry, and bump lastUniqueIdx for the next loop iteration. 162 lastUniqueIdx++ 163 leafNodes[lastUniqueIdx] = leafNodes[j] 164 } 165 } 166 leafNodes = leafNodes[:lastUniqueIdx+1] 167 return leafNodes 168 } 169 170 // findRankMatches finds all matches for a given query term in a tsvector, 171 // regardless of the expected query weight. 172 // query is the term being matched. v is the tsvector being searched. 173 // matches is a slice of matches to append to, to save on allocations as this 174 // function is called in a loop. 175 func findRankMatches(query *tsNode, v TSVector, matches [][]tsPosition) [][]tsPosition { 176 target := query.term.lexeme 177 i := sort.Search(len(v), func(i int) bool { 178 return v[i].lexeme >= target 179 }) 180 if i >= len(v) { 181 return matches 182 } 183 if query.term.isPrefixMatch() { 184 for j := i; j < len(v); j++ { 185 t := v[j] 186 if !strings.HasPrefix(t.lexeme, target) { 187 break 188 } 189 matches = append(matches, t.positions) 190 } 191 } else if v[i].lexeme == target { 192 matches = append(matches, v[i].positions) 193 } 194 return matches 195 } 196 197 // rankOr computes the rank for a query with an OR operator at its root. 198 // It takes the same parameters as TSRank. 199 func rankOr(weights [4]float32, v TSVector, q TSQuery) float32 { 200 queryLeaves := sortAndDistinctQueryTerms(q) 201 var matches = make([][]tsPosition, 0) 202 var res float32 203 for i := range queryLeaves { 204 matches = matches[:0] 205 matches = findRankMatches(queryLeaves[i], v, matches) 206 if len(matches) == 0 { 207 continue 208 } 209 resj := float32(0.0) 210 wjm := float32(-1.0) 211 jm := 0 212 for _, innerMatches := range matches { 213 for j, pos := range innerMatches { 214 termWeight := pos.weight.val() 215 weight := weights[termWeight] 216 resj = resj + weight/float32((j+1)*(j+1)) 217 if weight > wjm { 218 wjm = weight 219 jm = j 220 } 221 } 222 } 223 // Explanation from Postgres tsrank.c: 224 // limit (sum(1/i^2),i=1,inf) = pi^2/6 225 // resj = sum(wi/i^2),i=1,noccurence, 226 // wi - should be sorted desc, 227 // don't sort for now, just choose maximum weight. This should be corrected 228 // Oleg Bartunov 229 res = res + (wjm+resj-wjm/float32((jm+1)*(jm+1)))/1.64493406685 230 } 231 if len(queryLeaves) > 0 { 232 res /= float32(len(queryLeaves)) 233 } 234 return res 235 } 236 237 // rankAnd computes the rank for a query with an AND or followed-by operator at 238 // its root. It takes the same parameters as TSRank. 239 func rankAnd(weights [4]float32, v TSVector, q TSQuery) float32 { 240 queryLeaves := sortAndDistinctQueryTerms(q) 241 if len(queryLeaves) < 2 { 242 return rankOr(weights, v, q) 243 } 244 pos := make([][]tsPosition, len(queryLeaves)) 245 res := float32(-1) 246 var matches = make([][]tsPosition, 0) 247 for i := range queryLeaves { 248 matches = matches[:0] 249 matches = findRankMatches(queryLeaves[i], v, matches) 250 for _, innerMatches := range matches { 251 pos[i] = innerMatches 252 // Loop back through the earlier position matches 253 for k := 0; k < i; k++ { 254 if pos[k] == nil { 255 continue 256 } 257 for l := range pos[i] { 258 // For each of the earlier matches 259 for p := range pos[k] { 260 dist := int(pos[i][l].position) - int(pos[k][p].position) 261 if dist < 0 { 262 dist = -dist 263 } 264 if dist != 0 { 265 curw := float32(math.Sqrt(float64(weights[pos[i][l].weight.val()] * weights[pos[k][p].weight.val()] * wordDistance(dist)))) 266 if res < 0 { 267 res = curw 268 } else { 269 res = 1.0 - (1.0-res)*(1.0-curw) 270 } 271 } 272 } 273 } 274 } 275 } 276 } 277 return res 278 } 279 280 // Returns a weight of a word collocation. See Postgres tsrank.c. 281 func wordDistance(dist int) float32 { 282 if dist > 100 { 283 return 1e-30 284 } 285 return float32(1.0 / (1.005 + 0.05*math.Exp(float64(float32(dist)/1.5-2)))) 286 }