github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/eval.go (about) 1 // Copyright 2022 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tsearch 12 13 import ( 14 "math" 15 "sort" 16 "strings" 17 18 "github.com/cockroachdb/errors" 19 ) 20 21 // EvalTSQuery runs the provided TSQuery against the provided TSVector, 22 // returning whether or not the query matches the vector. 23 func EvalTSQuery(q TSQuery, v TSVector) (bool, error) { 24 evaluator := tsEvaluator{ 25 v: v, 26 q: q, 27 } 28 return evaluator.eval() 29 } 30 31 type tsEvaluator struct { 32 v TSVector 33 q TSQuery 34 } 35 36 func (e *tsEvaluator) eval() (bool, error) { 37 return e.evalNode(e.q.root) 38 } 39 40 // evalNode is used to evaluate a query node that's not nested within any 41 // followed by operators. it returns true if the match was successful. 42 func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) { 43 switch node.op { 44 case invalid: 45 // If there's no operator we're evaluating a leaf term. 46 prefixMatch := false 47 targetWeight := weightAny 48 if len(node.term.positions) > 0 { 49 targetWeight = node.term.positions[0].weight 50 if targetWeight&weightStar > 0 { 51 prefixMatch = true 52 // Unset the prefix match. 53 targetWeight = node.term.positions[0].weight & ^weightStar 54 } 55 // If no flags are set we can match anything. 56 if targetWeight == 0 { 57 targetWeight = weightAny 58 } 59 } 60 61 // To evaluate a term, we search the vector for a match. 62 target := node.term.lexeme 63 i := sort.Search(len(e.v), func(i int) bool { 64 return e.v[i].lexeme >= target 65 }) 66 if !prefixMatch && i < len(e.v) { 67 return e.v[i].lexeme == target && e.v[i].matchesWeight(targetWeight), nil 68 } 69 for ; i < len(e.v); i++ { 70 t := e.v[i] 71 // If we're prefix matching, continue searching until we either run out 72 // of prefix matches or find one that matches the weight in question. 73 if !strings.HasPrefix(t.lexeme, target) { 74 break 75 } 76 if t.matchesWeight(targetWeight) { 77 return true, nil 78 } 79 } 80 return false, nil 81 case and: 82 // Match if both operands are true. 83 l, err := e.evalNode(node.l) 84 if err != nil || !l { 85 return false, err 86 } 87 return e.evalNode(node.r) 88 case or: 89 // Match if either operand is true. 90 l, err := e.evalNode(node.l) 91 if err != nil || l { 92 return l, err 93 } 94 return e.evalNode(node.r) 95 case not: 96 // Match if the operand is false. 97 ret, err := e.evalNode(node.l) 98 return !ret, err 99 case followedby: 100 // For followed-by queries, we recurse into the special followed-by handler. 101 // Then, we return true if there is at least one position at which the 102 // followed-by query matches. 103 positions, err := e.evalWithinFollowedBy(node) 104 return positions.res, err 105 } 106 return false, errors.AssertionFailedf("invalid operator %d", node.op) 107 } 108 109 // tsPositionSet keeps track of metadata for a followed-by match. It's used to 110 // pass information about followed by queries during evaluation of them. 111 type tsPositionSet struct { 112 // positions is the list of positions that the match is successful at (or, 113 // if invert is true, unsuccessful at). 114 positions []tsPosition 115 // width is the width of the match. This is important to track to deal with 116 // chained followed by queries with possibly different widths (<-> vs <2> etc). 117 // A match of a single term within a followed by has width 0. 118 width int 119 // invert, if true, indicates that this match should be inverted. It's used 120 // to handle followed by matches within not operators. 121 invert bool 122 123 // res indicates that this match found positive results. 124 res bool 125 126 // noPos indicates that this match was missing position information. 127 noPos bool 128 } 129 130 // emitMode is a bitfield that controls the output of followed by matches. 131 type emitMode int 132 133 const ( 134 // emitMatches causes evalFollowedBy to emit matches - positions at which 135 // the left argument is found separated from the right argument by the right 136 // width. 137 emitMatches emitMode = 1 << iota 138 // emitLeftUnmatched causes evalFollowedBy to emit places at which the left 139 // arm doesn't match. 140 emitLeftUnmatched 141 // emitRightUnmatched causes evalFollowedBy to emit places at which the right 142 // arm doesn't match. 143 emitRightUnmatched 144 ) 145 146 // evalFollowedBy handles evaluating a followed by operator. It needs 147 // information about the positions at which the left and right arms of the 148 // followed by operator matches, as well as the offsets for each of the arms: 149 // the number of lexemes apart each of the matches were. 150 // the emitMode controls the output - see the comments on each of the emitMode 151 // values for details. 152 // This function is a little bit confusing, because it's operating on two 153 // input position sets, and not directly on search terms. Its job is to do set 154 // operations on the input sets, depending on emitMode - an intersection or 155 // difference depending on the desired outcome by evalWithinFollowedBy. 156 // This code tries to follow the Postgres implementation in 157 // src/backend/utils/adt/tsvector_op.c. 158 func (e *tsEvaluator) evalFollowedBy( 159 lPositions, rPositions tsPositionSet, lOffset, rOffset int, emitMode emitMode, 160 ) (tsPositionSet, error) { 161 // Followed by makes sure that two terms are separated by exactly n words. 162 // First, find all slots that match for the left expression. 163 164 // Find the offsetted intersection of 2 sorted integer lists, using the 165 // followedN as the offset. 166 var ret tsPositionSet 167 var lIdx, rIdx int 168 // Loop through the two sorted position lists, until the position on the 169 // right is as least as large as the position on the left. 170 for { 171 lExhausted := lIdx >= len(lPositions.positions) 172 rExhausted := rIdx >= len(rPositions.positions) 173 if lExhausted && rExhausted { 174 break 175 } 176 var lPos, rPos int 177 if !lExhausted { 178 lPos = int(lPositions.positions[lIdx].position) + lOffset 179 } else { 180 // Quit unless we're outputting all of the RHS, which we will if we have 181 // a negative match on the LHS. 182 if emitMode&emitRightUnmatched == 0 { 183 break 184 } 185 lPos = math.MaxInt64 186 } 187 if !rExhausted { 188 rPos = int(rPositions.positions[rIdx].position) + rOffset 189 } else { 190 // Quit unless we're outputting all of the LHS, which we will if we have 191 // a negative match on the RHS. 192 if emitMode&emitLeftUnmatched == 0 { 193 break 194 } 195 rPos = math.MaxInt64 196 } 197 198 if lPos < rPos { 199 if emitMode&emitLeftUnmatched > 0 { 200 ret.positions = append(ret.positions, tsPosition{position: uint16(lPos)}) 201 } 202 lIdx++ 203 } else if lPos == rPos { 204 if emitMode&emitMatches > 0 { 205 ret.positions = append(ret.positions, tsPosition{position: uint16(rPos)}) 206 } 207 lIdx++ 208 rIdx++ 209 } else { 210 if emitMode&emitRightUnmatched > 0 { 211 ret.positions = append(ret.positions, tsPosition{position: uint16(rPos)}) 212 } 213 rIdx++ 214 } 215 } 216 if len(ret.positions) > 0 { 217 ret.res = true 218 } 219 return ret, nil 220 } 221 222 // evalWithinFollowedBy is the evaluator for subexpressions of a followed by 223 // operator. Instead of just returning true or false, and possibly short 224 // circuiting on boolean ops, we need to return all of the tspositions at which 225 // each arm of the followed by expression matches. 226 func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tsPositionSet, error) { 227 switch node.op { 228 case invalid: 229 // We're evaluating a leaf (a term). 230 targetWeight := weightAny 231 prefixMatch := false 232 if len(node.term.positions) > 0 { 233 targetWeight = node.term.positions[0].weight 234 if targetWeight&weightStar > 0 { 235 prefixMatch = true 236 // Unset the prefix match. 237 targetWeight = node.term.positions[0].weight & ^weightStar 238 } 239 if targetWeight == 0 { 240 targetWeight = weightAny 241 } 242 } 243 244 // To evaluate a term, we search the vector for a match. 245 target := node.term.lexeme 246 i := sort.Search(len(e.v), func(i int) bool { 247 return e.v[i].lexeme >= target 248 }) 249 if i >= len(e.v) { 250 // No match. 251 return tsPositionSet{}, nil 252 } 253 var ret []tsPosition 254 noPos := false 255 if prefixMatch { 256 for j := i; j < len(e.v); j++ { 257 t := e.v[j] 258 if !strings.HasPrefix(t.lexeme, target) { 259 break 260 } 261 if len(t.positions) == 0 { 262 noPos = true 263 } 264 ret = append(ret, t.positions...) 265 } 266 ret = sortAndUniqTSPositions(ret) 267 ret = filterPositionsByWeight(ret, targetWeight) 268 return tsPositionSet{positions: ret, res: len(ret) > 0, noPos: noPos}, nil 269 } else if e.v[i].lexeme != target { 270 // No match. 271 return tsPositionSet{}, nil 272 } 273 // Return all of the positions at which the term is present and matches the 274 // input weights. 275 positions := filterPositionsByWeight(e.v[i].positions, targetWeight) 276 return tsPositionSet{positions: positions, res: len(positions) > 0, noPos: len(e.v[i].positions) == 0}, nil 277 case or: 278 var lOffset, rOffset, width int 279 280 lPositions, err := e.evalWithinFollowedBy(node.l) 281 if err != nil { 282 return tsPositionSet{}, err 283 } 284 rPositions, err := e.evalWithinFollowedBy(node.r) 285 if err != nil { 286 return tsPositionSet{}, err 287 } 288 if !lPositions.res && !rPositions.res { 289 return tsPositionSet{}, nil 290 } 291 if lPositions.noPos || rPositions.noPos { 292 // Still no position information. 293 return tsPositionSet{noPos: true}, nil 294 } 295 if !lPositions.res { 296 lPositions.positions = nil 297 } 298 if !rPositions.res { 299 rPositions.positions = nil 300 } 301 302 width = lPositions.width 303 if rPositions.width > width { 304 width = rPositions.width 305 } 306 lOffset = width - lPositions.width 307 rOffset = width - rPositions.width 308 309 mode := emitMatches | emitLeftUnmatched | emitRightUnmatched 310 invertResults := false 311 switch { 312 case lPositions.invert && rPositions.invert: 313 invertResults = true 314 mode = emitMatches 315 case lPositions.invert: 316 invertResults = true 317 mode = emitLeftUnmatched 318 case rPositions.invert: 319 invertResults = true 320 mode = emitRightUnmatched 321 } 322 ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode) 323 if invertResults { 324 ret.invert = true 325 ret.res = true 326 } 327 ret.width = width 328 return ret, err 329 case not: 330 ret, err := e.evalWithinFollowedBy(node.l) 331 if err != nil { 332 return tsPositionSet{}, err 333 } 334 if ret.res { 335 if len(ret.positions) > 0 { 336 ret.invert = !ret.invert 337 ret.res = true 338 } else if ret.invert { 339 ret.invert = false 340 ret.res = false 341 } 342 } else if ret.noPos { 343 // We still have no position information, so just propagate. 344 return ret, nil 345 } else { 346 ret.invert = true 347 ret.res = true 348 } 349 return ret, nil 350 case followedby: 351 // Followed by and and have similar handling. 352 fallthrough 353 case and: 354 var lOffset, rOffset, width int 355 356 lPositions, err := e.evalWithinFollowedBy(node.l) 357 if err != nil || !lPositions.res { 358 return tsPositionSet{}, err 359 } 360 rPositions, err := e.evalWithinFollowedBy(node.r) 361 if err != nil || !rPositions.res { 362 return tsPositionSet{}, err 363 } 364 if lPositions.noPos || rPositions.noPos { 365 // Still no position information. 366 return tsPositionSet{noPos: true}, nil 367 } 368 if node.op == followedby { 369 lOffset = int(node.followedN) + rPositions.width 370 width = lOffset + lPositions.width 371 } else { 372 width = lPositions.width 373 if rPositions.width > width { 374 width = rPositions.width 375 } 376 lOffset = width - lPositions.width 377 rOffset = width - rPositions.width 378 } 379 380 mode := emitMatches 381 invertResults := false 382 switch { 383 case lPositions.invert && rPositions.invert: 384 invertResults = true 385 mode |= emitLeftUnmatched | emitRightUnmatched 386 case lPositions.invert: 387 mode = emitRightUnmatched 388 case rPositions.invert: 389 mode = emitLeftUnmatched 390 } 391 ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode) 392 if invertResults { 393 ret.res = true 394 ret.invert = true 395 } 396 ret.width = width 397 return ret, err 398 } 399 return tsPositionSet{}, errors.AssertionFailedf("invalid operator %d", node.op) 400 } 401 402 func filterPositionsByWeight(positions []tsPosition, weight tsWeight) []tsPosition { 403 if weight == weightAny { 404 return positions 405 } 406 var i int 407 var pos tsPosition 408 var filtered = false 409 for i, pos = range positions { 410 // If we filter anything out, copy into a new return slice. 411 if !pos.weight.matches(weight) { 412 filtered = true 413 break 414 } 415 } 416 if !filtered { 417 return positions 418 } 419 ret := make([]tsPosition, i, len(positions)-1) 420 copy(ret, positions[:i]) 421 // Skip the entry we know doesn't match. 422 i += 1 423 for ; i < len(positions); i++ { 424 pos = positions[i] 425 // Filter the rest of the list. 426 if pos.weight.matches(weight) { 427 ret = append(ret, pos) 428 } 429 } 430 return ret 431 } 432 433 // sortAndUniqTSPositions sorts and uniquifies the input tsPosition list by 434 // their position attributes. 435 func sortAndUniqTSPositions(pos []tsPosition) []tsPosition { 436 if len(pos) <= 1 { 437 return pos 438 } 439 sort.Slice(pos, func(i, j int) bool { 440 return pos[i].position < pos[j].position 441 }) 442 // Then distinct: (wouldn't it be nice if Go had generics?) 443 lastUniqueIdx := 0 444 for j := 1; j < len(pos); j++ { 445 if pos[j].position != pos[lastUniqueIdx].position { 446 // We found a unique entry, at index i. The last unique entry in the array 447 // was at lastUniqueIdx, so set the entry after that one to our new unique 448 // entry, and bump lastUniqueIdx for the next loop iteration. 449 lastUniqueIdx++ 450 pos[lastUniqueIdx] = pos[j] 451 } 452 } 453 pos = pos[:lastUniqueIdx+1] 454 if len(pos) > maxTSVectorPositions { 455 // Postgres silently truncates position lists to length 256. 456 pos = pos[:maxTSVectorPositions] 457 } 458 return pos 459 }