github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/eval.go (about)

     1  // Copyright 2022 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tsearch
    12  
    13  import (
    14  	"math"
    15  	"sort"
    16  	"strings"
    17  
    18  	"github.com/cockroachdb/errors"
    19  )
    20  
    21  // EvalTSQuery runs the provided TSQuery against the provided TSVector,
    22  // returning whether or not the query matches the vector.
    23  func EvalTSQuery(q TSQuery, v TSVector) (bool, error) {
    24  	evaluator := tsEvaluator{
    25  		v: v,
    26  		q: q,
    27  	}
    28  	return evaluator.eval()
    29  }
    30  
// tsEvaluator carries the inputs of a single EvalTSQuery call: the vector
// being searched and the query being matched against it.
type tsEvaluator struct {
	// v is the vector being searched. The evaluator binary-searches it by
	// lexeme, so it is assumed to be sorted by lexeme.
	v TSVector
	// q is the query being evaluated; evaluation starts at q.root.
	q TSQuery
}
    35  
    36  func (e *tsEvaluator) eval() (bool, error) {
    37  	return e.evalNode(e.q.root)
    38  }
    39  
    40  // evalNode is used to evaluate a query node that's not nested within any
    41  // followed by operators. it returns true if the match was successful.
    42  func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) {
    43  	switch node.op {
    44  	case invalid:
    45  		// If there's no operator we're evaluating a leaf term.
    46  		prefixMatch := false
    47  		targetWeight := weightAny
    48  		if len(node.term.positions) > 0 {
    49  			targetWeight = node.term.positions[0].weight
    50  			if targetWeight&weightStar > 0 {
    51  				prefixMatch = true
    52  				// Unset the prefix match.
    53  				targetWeight = node.term.positions[0].weight & ^weightStar
    54  			}
    55  			// If no flags are set we can match anything.
    56  			if targetWeight == 0 {
    57  				targetWeight = weightAny
    58  			}
    59  		}
    60  
    61  		// To evaluate a term, we search the vector for a match.
    62  		target := node.term.lexeme
    63  		i := sort.Search(len(e.v), func(i int) bool {
    64  			return e.v[i].lexeme >= target
    65  		})
    66  		if !prefixMatch && i < len(e.v) {
    67  			return e.v[i].lexeme == target && e.v[i].matchesWeight(targetWeight), nil
    68  		}
    69  		for ; i < len(e.v); i++ {
    70  			t := e.v[i]
    71  			// If we're prefix matching, continue searching until we either run out
    72  			// of prefix matches or find one that matches the weight in question.
    73  			if !strings.HasPrefix(t.lexeme, target) {
    74  				break
    75  			}
    76  			if t.matchesWeight(targetWeight) {
    77  				return true, nil
    78  			}
    79  		}
    80  		return false, nil
    81  	case and:
    82  		// Match if both operands are true.
    83  		l, err := e.evalNode(node.l)
    84  		if err != nil || !l {
    85  			return false, err
    86  		}
    87  		return e.evalNode(node.r)
    88  	case or:
    89  		// Match if either operand is true.
    90  		l, err := e.evalNode(node.l)
    91  		if err != nil || l {
    92  			return l, err
    93  		}
    94  		return e.evalNode(node.r)
    95  	case not:
    96  		// Match if the operand is false.
    97  		ret, err := e.evalNode(node.l)
    98  		return !ret, err
    99  	case followedby:
   100  		// For followed-by queries, we recurse into the special followed-by handler.
   101  		// Then, we return true if there is at least one position at which the
   102  		// followed-by query matches.
   103  		positions, err := e.evalWithinFollowedBy(node)
   104  		return positions.res, err
   105  	}
   106  	return false, errors.AssertionFailedf("invalid operator %d", node.op)
   107  }
   108  
// tsPositionSet keeps track of metadata for a followed-by match. It's used to
// pass information about followed-by queries between the recursive calls of
// evalWithinFollowedBy, and it is what evalFollowedBy produces.
type tsPositionSet struct {
	// positions is the list of positions that the match is successful at (or,
	// if invert is true, unsuccessful at). evalFollowedBy merge-joins two of
	// these lists, so they are expected to be sorted by position.
	positions []tsPosition
	// width is the width of the match. This is important to track to deal with
	// chained followed by queries with possibly different widths (<-> vs <2> etc).
	// A match of a single term within a followed by has width 0.
	width int
	// invert, if true, indicates that this match should be inverted: the
	// expression matches everywhere except at positions. It's used to handle
	// followed-by matches within not operators. An inverted set with no
	// positions means "matches everywhere".
	invert bool

	// res indicates that this match found positive results. For an inverted
	// set, res can be true even when positions is empty.
	res bool

	// noPos indicates that a matched lexeme was missing position information,
	// so the positional match could not be evaluated.
	noPos bool
}
   129  
   130  // emitMode is a bitfield that controls the output of followed by matches.
   131  type emitMode int
   132  
   133  const (
   134  	// emitMatches causes evalFollowedBy to emit matches - positions at which
   135  	// the left argument is found separated from the right argument by the right
   136  	// width.
   137  	emitMatches emitMode = 1 << iota
   138  	// emitLeftUnmatched causes evalFollowedBy to emit places at which the left
   139  	// arm doesn't match.
   140  	emitLeftUnmatched
   141  	// emitRightUnmatched causes evalFollowedBy to emit places at which the right
   142  	// arm doesn't match.
   143  	emitRightUnmatched
   144  )
   145  
   146  // evalFollowedBy handles evaluating a followed by operator. It needs
   147  // information about the positions at which the left and right arms of the
   148  // followed by operator matches, as well as the offsets for each of the arms:
   149  // the number of lexemes apart each of the matches were.
   150  // the emitMode controls the output - see the comments on each of the emitMode
   151  // values for details.
   152  // This function is a little bit confusing, because it's operating on two
   153  // input position sets, and not directly on search terms. Its job is to do set
   154  // operations on the input sets, depending on emitMode - an intersection or
   155  // difference depending on the desired outcome by evalWithinFollowedBy.
   156  // This code tries to follow the Postgres implementation in
   157  // src/backend/utils/adt/tsvector_op.c.
   158  func (e *tsEvaluator) evalFollowedBy(
   159  	lPositions, rPositions tsPositionSet, lOffset, rOffset int, emitMode emitMode,
   160  ) (tsPositionSet, error) {
   161  	// Followed by makes sure that two terms are separated by exactly n words.
   162  	// First, find all slots that match for the left expression.
   163  
   164  	// Find the offsetted intersection of 2 sorted integer lists, using the
   165  	// followedN as the offset.
   166  	var ret tsPositionSet
   167  	var lIdx, rIdx int
   168  	// Loop through the two sorted position lists, until the position on the
   169  	// right is as least as large as the position on the left.
   170  	for {
   171  		lExhausted := lIdx >= len(lPositions.positions)
   172  		rExhausted := rIdx >= len(rPositions.positions)
   173  		if lExhausted && rExhausted {
   174  			break
   175  		}
   176  		var lPos, rPos int
   177  		if !lExhausted {
   178  			lPos = int(lPositions.positions[lIdx].position) + lOffset
   179  		} else {
   180  			// Quit unless we're outputting all of the RHS, which we will if we have
   181  			// a negative match on the LHS.
   182  			if emitMode&emitRightUnmatched == 0 {
   183  				break
   184  			}
   185  			lPos = math.MaxInt64
   186  		}
   187  		if !rExhausted {
   188  			rPos = int(rPositions.positions[rIdx].position) + rOffset
   189  		} else {
   190  			// Quit unless we're outputting all of the LHS, which we will if we have
   191  			// a negative match on the RHS.
   192  			if emitMode&emitLeftUnmatched == 0 {
   193  				break
   194  			}
   195  			rPos = math.MaxInt64
   196  		}
   197  
   198  		if lPos < rPos {
   199  			if emitMode&emitLeftUnmatched > 0 {
   200  				ret.positions = append(ret.positions, tsPosition{position: uint16(lPos)})
   201  			}
   202  			lIdx++
   203  		} else if lPos == rPos {
   204  			if emitMode&emitMatches > 0 {
   205  				ret.positions = append(ret.positions, tsPosition{position: uint16(rPos)})
   206  			}
   207  			lIdx++
   208  			rIdx++
   209  		} else {
   210  			if emitMode&emitRightUnmatched > 0 {
   211  				ret.positions = append(ret.positions, tsPosition{position: uint16(rPos)})
   212  			}
   213  			rIdx++
   214  		}
   215  	}
   216  	if len(ret.positions) > 0 {
   217  		ret.res = true
   218  	}
   219  	return ret, nil
   220  }
   221  
   222  // evalWithinFollowedBy is the evaluator for subexpressions of a followed by
   223  // operator. Instead of just returning true or false, and possibly short
   224  // circuiting on boolean ops, we need to return all of the tspositions at which
   225  // each arm of the followed by expression matches.
   226  func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tsPositionSet, error) {
   227  	switch node.op {
   228  	case invalid:
   229  		// We're evaluating a leaf (a term).
   230  		targetWeight := weightAny
   231  		prefixMatch := false
   232  		if len(node.term.positions) > 0 {
   233  			targetWeight = node.term.positions[0].weight
   234  			if targetWeight&weightStar > 0 {
   235  				prefixMatch = true
   236  				// Unset the prefix match.
   237  				targetWeight = node.term.positions[0].weight & ^weightStar
   238  			}
   239  			if targetWeight == 0 {
   240  				targetWeight = weightAny
   241  			}
   242  		}
   243  
   244  		// To evaluate a term, we search the vector for a match.
   245  		target := node.term.lexeme
   246  		i := sort.Search(len(e.v), func(i int) bool {
   247  			return e.v[i].lexeme >= target
   248  		})
   249  		if i >= len(e.v) {
   250  			// No match.
   251  			return tsPositionSet{}, nil
   252  		}
   253  		var ret []tsPosition
   254  		noPos := false
   255  		if prefixMatch {
   256  			for j := i; j < len(e.v); j++ {
   257  				t := e.v[j]
   258  				if !strings.HasPrefix(t.lexeme, target) {
   259  					break
   260  				}
   261  				if len(t.positions) == 0 {
   262  					noPos = true
   263  				}
   264  				ret = append(ret, t.positions...)
   265  			}
   266  			ret = sortAndUniqTSPositions(ret)
   267  			ret = filterPositionsByWeight(ret, targetWeight)
   268  			return tsPositionSet{positions: ret, res: len(ret) > 0, noPos: noPos}, nil
   269  		} else if e.v[i].lexeme != target {
   270  			// No match.
   271  			return tsPositionSet{}, nil
   272  		}
   273  		// Return all of the positions at which the term is present and matches the
   274  		// input weights.
   275  		positions := filterPositionsByWeight(e.v[i].positions, targetWeight)
   276  		return tsPositionSet{positions: positions, res: len(positions) > 0, noPos: len(e.v[i].positions) == 0}, nil
   277  	case or:
   278  		var lOffset, rOffset, width int
   279  
   280  		lPositions, err := e.evalWithinFollowedBy(node.l)
   281  		if err != nil {
   282  			return tsPositionSet{}, err
   283  		}
   284  		rPositions, err := e.evalWithinFollowedBy(node.r)
   285  		if err != nil {
   286  			return tsPositionSet{}, err
   287  		}
   288  		if !lPositions.res && !rPositions.res {
   289  			return tsPositionSet{}, nil
   290  		}
   291  		if lPositions.noPos || rPositions.noPos {
   292  			// Still no position information.
   293  			return tsPositionSet{noPos: true}, nil
   294  		}
   295  		if !lPositions.res {
   296  			lPositions.positions = nil
   297  		}
   298  		if !rPositions.res {
   299  			rPositions.positions = nil
   300  		}
   301  
   302  		width = lPositions.width
   303  		if rPositions.width > width {
   304  			width = rPositions.width
   305  		}
   306  		lOffset = width - lPositions.width
   307  		rOffset = width - rPositions.width
   308  
   309  		mode := emitMatches | emitLeftUnmatched | emitRightUnmatched
   310  		invertResults := false
   311  		switch {
   312  		case lPositions.invert && rPositions.invert:
   313  			invertResults = true
   314  			mode = emitMatches
   315  		case lPositions.invert:
   316  			invertResults = true
   317  			mode = emitLeftUnmatched
   318  		case rPositions.invert:
   319  			invertResults = true
   320  			mode = emitRightUnmatched
   321  		}
   322  		ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode)
   323  		if invertResults {
   324  			ret.invert = true
   325  			ret.res = true
   326  		}
   327  		ret.width = width
   328  		return ret, err
   329  	case not:
   330  		ret, err := e.evalWithinFollowedBy(node.l)
   331  		if err != nil {
   332  			return tsPositionSet{}, err
   333  		}
   334  		if ret.res {
   335  			if len(ret.positions) > 0 {
   336  				ret.invert = !ret.invert
   337  				ret.res = true
   338  			} else if ret.invert {
   339  				ret.invert = false
   340  				ret.res = false
   341  			}
   342  		} else if ret.noPos {
   343  			// We still have no position information, so just propagate.
   344  			return ret, nil
   345  		} else {
   346  			ret.invert = true
   347  			ret.res = true
   348  		}
   349  		return ret, nil
   350  	case followedby:
   351  		// Followed by and and have similar handling.
   352  		fallthrough
   353  	case and:
   354  		var lOffset, rOffset, width int
   355  
   356  		lPositions, err := e.evalWithinFollowedBy(node.l)
   357  		if err != nil || !lPositions.res {
   358  			return tsPositionSet{}, err
   359  		}
   360  		rPositions, err := e.evalWithinFollowedBy(node.r)
   361  		if err != nil || !rPositions.res {
   362  			return tsPositionSet{}, err
   363  		}
   364  		if lPositions.noPos || rPositions.noPos {
   365  			// Still no position information.
   366  			return tsPositionSet{noPos: true}, nil
   367  		}
   368  		if node.op == followedby {
   369  			lOffset = int(node.followedN) + rPositions.width
   370  			width = lOffset + lPositions.width
   371  		} else {
   372  			width = lPositions.width
   373  			if rPositions.width > width {
   374  				width = rPositions.width
   375  			}
   376  			lOffset = width - lPositions.width
   377  			rOffset = width - rPositions.width
   378  		}
   379  
   380  		mode := emitMatches
   381  		invertResults := false
   382  		switch {
   383  		case lPositions.invert && rPositions.invert:
   384  			invertResults = true
   385  			mode |= emitLeftUnmatched | emitRightUnmatched
   386  		case lPositions.invert:
   387  			mode = emitRightUnmatched
   388  		case rPositions.invert:
   389  			mode = emitLeftUnmatched
   390  		}
   391  		ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode)
   392  		if invertResults {
   393  			ret.res = true
   394  			ret.invert = true
   395  		}
   396  		ret.width = width
   397  		return ret, err
   398  	}
   399  	return tsPositionSet{}, errors.AssertionFailedf("invalid operator %d", node.op)
   400  }
   401  
   402  func filterPositionsByWeight(positions []tsPosition, weight tsWeight) []tsPosition {
   403  	if weight == weightAny {
   404  		return positions
   405  	}
   406  	var i int
   407  	var pos tsPosition
   408  	var filtered = false
   409  	for i, pos = range positions {
   410  		// If we filter anything out, copy into a new return slice.
   411  		if !pos.weight.matches(weight) {
   412  			filtered = true
   413  			break
   414  		}
   415  	}
   416  	if !filtered {
   417  		return positions
   418  	}
   419  	ret := make([]tsPosition, i, len(positions)-1)
   420  	copy(ret, positions[:i])
   421  	// Skip the entry we know doesn't match.
   422  	i += 1
   423  	for ; i < len(positions); i++ {
   424  		pos = positions[i]
   425  		// Filter the rest of the list.
   426  		if pos.weight.matches(weight) {
   427  			ret = append(ret, pos)
   428  		}
   429  	}
   430  	return ret
   431  }
   432  
   433  // sortAndUniqTSPositions sorts and uniquifies the input tsPosition list by
   434  // their position attributes.
   435  func sortAndUniqTSPositions(pos []tsPosition) []tsPosition {
   436  	if len(pos) <= 1 {
   437  		return pos
   438  	}
   439  	sort.Slice(pos, func(i, j int) bool {
   440  		return pos[i].position < pos[j].position
   441  	})
   442  	// Then distinct: (wouldn't it be nice if Go had generics?)
   443  	lastUniqueIdx := 0
   444  	for j := 1; j < len(pos); j++ {
   445  		if pos[j].position != pos[lastUniqueIdx].position {
   446  			// We found a unique entry, at index i. The last unique entry in the array
   447  			// was at lastUniqueIdx, so set the entry after that one to our new unique
   448  			// entry, and bump lastUniqueIdx for the next loop iteration.
   449  			lastUniqueIdx++
   450  			pos[lastUniqueIdx] = pos[j]
   451  		}
   452  	}
   453  	pos = pos[:lastUniqueIdx+1]
   454  	if len(pos) > maxTSVectorPositions {
   455  		// Postgres silently truncates position lists to length 256.
   456  		pos = pos[:maxTSVectorPositions]
   457  	}
   458  	return pos
   459  }