github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/tsvector.go

// Copyright 2022 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package tsearch

import (
	"math/bits"
	"sort"
	"strconv"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/blevesearch/snowballstem"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
	"github.com/cockroachdb/errors"
)

// This file defines the TSVector data structure, which is used to implement
// Postgres's tsvector text search mechanism.
// See https://www.postgresql.org/docs/current/datatype-textsearch.html for
// context on what each of the pieces does.
//
// TSVector is ultimately used to represent a document as a posting list: a
// list of lexemes in the doc (typically without stop words like common or
// short words, and typically stemmed using a stemming algorithm like snowball
// stemming (https://snowballstem.org/)), along with the set of positions at
// which those lexemes occur within the document.
//
// Typically, this posting list is then stored in an inverted index within the
// database to accelerate searches of terms within the document.
//
// The key structures are:
// - tsTerm is a document term (also referred to as a lexeme) along with a
//   position list, which records the positions within a document at which the
//   term appears, each with an optional weight that controls matching.
// - tsTerm is also used during parsing of both TSQueries and TSVectors, so a
//   tsTerm can also represent a TSQuery operator.
// - tsWeight represents the weight of a given lexeme. It's also used for
//   queries, where an additional "star" weight indicates prefix matching.
// - TSVector is a list of tsTerms, ordered by their lexeme.

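// A rough illustrative sketch, hand-constructed rather than taken from the
// original file: the document "a fat cat" could be represented with the types
// defined below, giving each lexeme its 1-based position within the document.
var _ = TSVector{
	{lexeme: "a", positions: []tsPosition{{position: 1}}},
	{lexeme: "cat", positions: []tsPosition{{position: 3}}},
	{lexeme: "fat", positions: []tsPosition{{position: 2}}},
}
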
// tsWeight is a bitfield that represents the weight of a given term. When
// stored in a TSVector, only one of the bits will be set. The default weight
// is D; as a result, we store 0 for the weight of terms with weight D or no
// specified weight. The weightStar value is never set in a TSVector weight.
//
// tsWeight is also used inside of TSQueries, to specify the weight to search.
// Within TSQueries, the absence of a weight is the default, and indicates that
// the search term should match terms of any weight. If one or more of the
// weights are set in a search term, it indicates that the query should match
// only terms with the given weights.
type tsWeight byte

const (
	// These enum values are a bitfield and must be kept in order.
	weightD tsWeight = 1 << iota
	weightC
	weightB
	weightA
	// weightStar is a special "weight" that can be specified only in a search
	// term. It indicates prefix matching, which will allow the term to match any
	// document term that begins with the search term.
	weightStar
	invalidWeight

	weightAny = weightA | weightB | weightC | weightD
)

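// A quick illustration of the bitfield layout: weightD == 1, weightC == 2,
// weightB == 4, and weightA == 8, so compound query weights are bitwise ORs
// of the individual bits.
var _ = []tsWeight{
	weightD,           // 0b00001
	weightA | weightB, // 0b01100: matches terms weighted A or B
	weightAny,         // 0b01111: matches terms of any weight
}
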
// NB: must be kept in sync with stringSize().
func (w tsWeight) writeString(buf *strings.Builder) {
	if w&weightStar != 0 {
		buf.WriteByte('*')
	}
	if w&weightA != 0 {
		buf.WriteByte('A')
	}
	if w&weightB != 0 {
		buf.WriteByte('B')
	}
	if w&weightC != 0 {
		buf.WriteByte('C')
	}
	if w&weightD != 0 {
		buf.WriteByte('D')
	}
}

// stringSize returns the length of the string that corresponds to this
// tsWeight.
// NB: must be kept in sync with writeString().
func (w tsWeight) stringSize() int {
	// Count the number of bits set in the lowest 5 bits.
	return bits.OnesCount8(uint8(w & 31))
}

// TSVectorPGEncoding returns the PG-compatible wire protocol encoding for a
// given weight. Note that this encoding is only valid for TSVector weights,
// which have at most one weight bit set. In a TSQuery, a single lexeme might
// have more than one weight set, which is not encodable using this scheme.
func (w tsWeight) TSVectorPGEncoding() (byte, error) {
	switch w {
	case weightA:
		return 3, nil
	case weightB:
		return 2, nil
	case weightC:
		return 1, nil
	case weightD, 0:
		return 0, nil
	}
	return 0, errors.Errorf("invalid tsvector weight %d", w)
}

func (w tsWeight) val() int {
	b, err := w.TSVectorPGEncoding()
	if err != nil {
		panic(err)
	}
	return int(b)
}

// matches returns true if the receiver is matched by the input tsquery weight.
func (w tsWeight) matches(queryWeight tsWeight) bool {
	if queryWeight == weightAny {
		return true
	}
	if w&queryWeight > 0 {
		return true
	}
	// If we're querying for D, and the receiver has no weight, that's also a
	// match.
	return queryWeight&weightD > 0 && w == 0
}

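// A sketch of the matching rules above: an unweighted vector term (stored as
// zero) is matched by a query for D, since D is the implicit default weight.
var _ = []bool{
	tsWeight(0).matches(weightD),       // true: no weight is treated as D
	weightA.matches(weightA | weightB), // true: the weight bits overlap
	weightC.matches(weightA),           // false: no overlap, and C isn't D
}
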
func tsWeightFromVectorPGEncoding(b byte) (tsWeight, error) {
	switch b {
	case 3:
		return weightA, nil
	case 2:
		return weightB, nil
	case 1:
		return weightC, nil
	case 0:
		// We don't explicitly return weightD, since it's the default.
		return 0, nil
	}
	return 0, errors.Errorf("invalid encoded tsvector weight %d", b)
}

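// A sketch of the wire-format round trip through the two helpers above
// (exampleWeightRoundTrip is a hypothetical helper, not part of the API):
// weightB encodes to 2 and decodes back to weightB, while weightD collapses
// to the zero value on the way back.
func exampleWeightRoundTrip() (tsWeight, error) {
	b, err := weightB.TSVectorPGEncoding() // b == 2
	if err != nil {
		return 0, err
	}
	return tsWeightFromVectorPGEncoding(b) // returns weightB
}
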
// tsPosition is a position within a document, along with an optional weight.
type tsPosition struct {
	position uint16
	weight   tsWeight
}

// tsTerm is either a lexeme and position list, or an operator (when parsing
// a TSQuery).
type tsTerm struct {
	// lexeme is at most 2046 bytes.
	lexeme    string
	positions []tsPosition

	// The operator and followedN fields are only used when parsing a TSQuery.
	operator tsOperator
	// followedN is set only when operator = followedby. It is at most 16384.
	followedN uint16
}

func newLexemeTerm(lexeme string) (tsTerm, error) {
	if len(lexeme) > 2046 {
		return tsTerm{}, pgerror.Newf(pgcode.ProgramLimitExceeded, "word is too long (%d bytes, max 2046 bytes)", len(lexeme))
	}
	return tsTerm{lexeme: lexeme}, nil
}

// NB: must be kept in sync with stringSize().
func (t tsTerm) writeString(buf *strings.Builder) {
	if t.operator != 0 {
		switch t.operator {
		case and:
			buf.WriteString("&")
			return
		case or:
			buf.WriteString("|")
			return
		case not:
			buf.WriteString("!")
			return
		case lparen:
			buf.WriteString("(")
			return
		case rparen:
			buf.WriteString(")")
			return
		case followedby:
			buf.WriteString("<")
			if t.followedN == 1 {
				buf.WriteString("-")
			} else {
				buf.WriteString(strconv.Itoa(int(t.followedN)))
			}
			buf.WriteString(">")
			return
		}
	}

	buf.WriteByte('\'')
	for _, r := range t.lexeme {
		if r == '\'' {
			// Single quotes are escaped as double single quotes inside of a TSVector.
			buf.WriteString(`''`)
		} else {
			buf.WriteRune(r)
		}
	}
	buf.WriteByte('\'')
	for i, pos := range t.positions {
		if i > 0 {
			buf.WriteByte(',')
		} else {
			buf.WriteByte(':')
		}
		if pos.position > 0 {
			buf.WriteString(strconv.Itoa(int(pos.position)))
		}
		pos.weight.writeString(buf)
	}
}

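// For illustration (exampleTermString is a hypothetical helper): writeString
// single-quotes the lexeme, doubles any embedded single quotes, and appends
// the position list after a colon, so the term below renders as 'don''t':3B.
func exampleTermString() string {
	var buf strings.Builder
	t := tsTerm{lexeme: "don't", positions: []tsPosition{{position: 3, weight: weightB}}}
	t.writeString(&buf)
	return buf.String()
}
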
// stringSize returns the length of the string representation of this tsTerm.
// NB: must be kept in sync with writeString().
func (t tsTerm) stringSize() int {
	if t.operator != 0 {
		switch t.operator {
		case and, or, not, lparen, rparen:
			return 1
		case followedby:
			if t.followedN == 1 {
				return 3 // '<->'
			}
			return 2 + len(strconv.Itoa(int(t.followedN))) // fmt.Sprintf("<%d>", t.followedN)
		}
	}
	size := 1 // '\''
	for _, r := range t.lexeme {
		if r == '\'' {
			// Single quotes are escaped as double single quotes inside of a
			// TSVector.
			size += 2
		} else {
			// Compare as uint32 to correctly handle negative runes.
			if uint32(r) < utf8.RuneSelf {
				size++
			} else {
				size += utf8.RuneLen(r)
			}
		}
	}
	size++                   // '\''
	size += len(t.positions) // ':' or ',' for each position
	for _, pos := range t.positions {
		if pos.position > 0 {
			size += len(strconv.Itoa(int(pos.position)))
		}
		size += pos.weight.stringSize()
	}
	return size
}

func (t tsTerm) matchesWeight(targetWeight tsWeight) bool {
	if targetWeight == weightAny {
		return true
	}
	if len(t.positions) == 0 {
		// A "stripped" tsvector (no associated positions) always matches any input
		// weight.
		return true
	}
	for _, pos := range t.positions {
		if pos.weight.matches(targetWeight) {
			return true
		}
	}
	return false
}

func (t tsTerm) isPrefixMatch() bool {
	return len(t.positions) >= 1 && t.positions[0].weight&weightStar != 0
}

// TSVector is a sorted list of terms, each of which is a lexeme that might have
// an associated position within an original document.
type TSVector []tsTerm

func (t TSVector) String() string {
	var buf strings.Builder
	for i, term := range t {
		if i > 0 {
			buf.WriteByte(' ')
		}
		term.writeString(&buf)
	}
	return buf.String()
}

// StringSize returns the length of the string that would be returned by a
// String() call, without actually constructing that string.
func (t TSVector) StringSize() int {
	var size int
	if len(t) > 0 {
		size = len(t) - 1 // one space between each pair of terms
	}
	for _, term := range t {
		size += term.stringSize()
	}
	return size
}

// ParseTSVector produces a TSVector from an input string. The result is
// sorted by lexeme and de-duplicated, but is not automatically stemmed or
// stop-worded.
func ParseTSVector(input string) (TSVector, error) {
	parser := tsVectorLexer{
		input: input,
		state: expectingTerm,
	}
	ret, err := parser.lex()
	if err != nil {
		return ret, err
	}

	return normalizeTSVector(ret)
}

func normalizeTSVector(ret TSVector) (TSVector, error) {
	if len(ret) > 1 {
		// Sort and de-duplicate the resultant TSVector.
		sort.Slice(ret, func(i, j int) bool {
			return ret[i].lexeme < ret[j].lexeme
		})
		// Then de-duplicate adjacent entries in place:
		lastUniqueIdx := 0
		for j := 1; j < len(ret); j++ {
			if ret[j].lexeme != ret[lastUniqueIdx].lexeme {
				// We found a unique entry, at index j. The last unique entry in the
				// array was at lastUniqueIdx, so set the entry after that one to our
				// new unique entry, and bump lastUniqueIdx for the next loop iteration.
				// First, sort and unique the position list now that we've collapsed all
				// of the identical lexemes.
				ret[lastUniqueIdx].positions = sortAndUniqTSPositions(ret[lastUniqueIdx].positions)
				lastUniqueIdx++
				ret[lastUniqueIdx] = ret[j]
			} else {
				// The last entries were not unique. Collapse their positions into the
				// first entry's list.
				ret[lastUniqueIdx].positions = append(ret[lastUniqueIdx].positions, ret[j].positions...)
			}
		}
		ret = ret[:lastUniqueIdx+1]
	}
	if len(ret) >= 1 {
		// Make sure to sort and uniq the position list even if there's only 1
		// entry.
		lastIdx := len(ret) - 1
		ret[lastIdx].positions = sortAndUniqTSPositions(ret[lastIdx].positions)
	}
	return ret, nil
}

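// A sketch of the normalization above (exampleParseNormalized is a
// hypothetical helper): duplicate lexemes are merged with their position
// lists combined, and the result is ordered by lexeme, so the input below
// comes back as 'a':1 'b':2,4.
func exampleParseNormalized() (TSVector, error) {
	return ParseTSVector("'b':2 'a':1 'b':4")
}
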
var validCharTables = []*unicode.RangeTable{unicode.Letter, unicode.Number}

// TSParse is the function that splits an input text into a list of tokens.
// For now, the parser that we use is very simple: it splits the input into
// tokens by treating every non-letter, non-number character as whitespace.
// (Lowercasing and other normalization happen later, in TSLexize.)
//
// The Postgres text search parser is much, much more sophisticated. The
// documentation (https://www.postgresql.org/docs/current/textsearch-parsers.html)
// gives more information, but roughly, each token is categorized into one of
// about 20 different buckets, such as asciiword, url, email, host, float, int,
// version, tag, etc. It uses very specific rules to produce these outputs.
// Another interesting transformation is returning multiple tokens for a
// hyphenated word, including a token that represents the entire hyphenated word,
// as well as one for each of the hyphenated components.
//
// It's not clear whether we need to exactly mimic this functionality; likely,
// we will eventually want to.
func TSParse(input string) []string {
	return strings.FieldsFunc(input, func(r rune) bool {
		return !unicode.IsOneOf(validCharTables, r)
	})
}

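// For illustration: every non-letter, non-number rune is a separator, so
// punctuation splits tokens, and no lowercasing happens at this stage.
var _ = TSParse("e-mail: John@example.com") // ["e" "mail" "John" "example" "com"]
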
// TSLexize implements the "dictionary" construct that's exposed via ts_lexize.
// It gets invoked once per input token to produce an output lexeme during
// routines like to_tsvector and to_tsquery.
// It returns true in its second return value to indicate that a stop word was
// found.
func TSLexize(config string, token string) (lexeme string, stopWord bool, err error) {
	stopwords, ok := stopwordsMap[config]
	if !ok {
		return "", false, pgerror.Newf(pgcode.UndefinedObject, "text search configuration %q does not exist", config)
	}

	lower := strings.ToLower(token)
	if _, ok := stopwords[lower]; ok {
		return "", true, nil
	}
	stemmer, err := getStemmer(config)
	if err != nil {
		return "", false, err
	}
	env := snowballstem.NewEnv(lower)
	stemmer(env)
	return env.Current(), false, nil
}

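// A sketch of TSLexize (exampleLexize is a hypothetical helper), assuming the
// "english" configuration is present in stopwordsMap: stop words are flagged,
// and other tokens are lowercased and snowball-stemmed.
func exampleLexize() {
	_, stop, _ := TSLexize("english", "The")       // stop == true: stop word
	lexeme, _, _ := TSLexize("english", "Running") // lexeme == "run"
	_, _ = stop, lexeme
}
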
// DocumentToTSVector parses an input document into lexemes, removes stop words,
// stems and normalizes the lexemes, and returns a TSVector annotated with
// lexeme positions according to a text search configuration passed by name.
func DocumentToTSVector(config string, input string) (TSVector, error) {
	tokens := TSParse(input)
	vector := make(TSVector, 0, len(tokens))
	for i := range tokens {
		lexeme, stopWord, err := TSLexize(config, tokens[i])
		if err != nil {
			return nil, err
		}
		if stopWord {
			continue
		}

		term := tsTerm{lexeme: lexeme}
		pos := i + 1
		if pos > maxTSVectorPosition {
			// Postgres silently truncates positions larger than 16383 to 16383.
			pos = maxTSVectorPosition
		}
		term.positions = []tsPosition{{position: uint16(pos)}}
		vector = append(vector, term)
	}
	return normalizeTSVector(vector)
}
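
// An end-to-end sketch (exampleDocumentToTSVector is a hypothetical helper),
// again assuming the "english" configuration exists: "The" is dropped as a
// stop word, "cats" stems to "cat", and the original 1-based token position
// (2) is retained, so the result prints as 'cat':2.
func exampleDocumentToTSVector() (TSVector, error) {
	return DocumentToTSVector("english", "The cats")
}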