github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/encoding.go (about)

     1  // Copyright 2022 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tsearch
    12  
    13  import (
    14  	"bytes"
    15  
    16  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
    17  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
    18  	"github.com/cockroachdb/cockroachdb-parser/pkg/util/encoding"
    19  	"github.com/cockroachdb/errors"
    20  )
    21  
    22  // EncodeTSVector encodes a tsvector into a serialized representation for
    23  // on-disk storage.
    24  func EncodeTSVector(appendTo []byte, vector TSVector) ([]byte, error) {
    25  	appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(vector)))
    26  	for _, term := range vector {
    27  		l := term.lexeme
    28  		appendTo = encoding.EncodeUntaggedBytesValue(appendTo, encoding.UnsafeConvertStringToBytes(l))
    29  		if len(term.positions) > maxTSVectorPositions {
    30  			return nil, pgerror.Newf(pgcode.ProgramLimitExceeded,
    31  				"tsvector position list of size %d too large (maximum is %d)", len(term.positions),
    32  				maxTSVectorPositions)
    33  		}
    34  		if len(l) > maxTSVectorLexemeLen {
    35  			return nil, pgerror.Newf(pgcode.ProgramLimitExceeded,
    36  				"tsvector lexeme of size %d too large (maximum is %d)", len(l),
    37  				maxTSVectorLexemeLen)
    38  		}
    39  		appendTo = encoding.EncodeUint16Ascending(appendTo, uint16(len(term.positions)))
    40  		for _, pos := range term.positions {
    41  			weight, err := pos.weight.TSVectorPGEncoding()
    42  			if err != nil {
    43  				return nil, err
    44  			}
    45  			// Clear the 2 most significant bits. These should never be set,
    46  			// as we always make sure that positions are at most 1 << 14, but
    47  			// better an extra check.
    48  			position := pos.position & (^(uint16(3) << 14))
    49  			out := position | (uint16(weight) << 14)
    50  			appendTo = encoding.EncodeUint16Ascending(appendTo, out)
    51  		}
    52  	}
    53  	return appendTo, nil
    54  }
    55  
    56  // DecodeTSVector decodes a tsvector in disk-storage representation from the
    57  // input byte slice.
    58  func DecodeTSVector(b []byte) (ret TSVector, err error) {
    59  	var nTerms uint32
    60  	var nPositions, position uint16
    61  	b, nTerms, err = encoding.DecodeUint32Ascending(b)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	ret = make([]tsTerm, nTerms)
    66  	for i := uint32(0); i < nTerms; i++ {
    67  		var lexeme []byte
    68  		b, lexeme, err = encoding.DecodeUntaggedBytesValue(b)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  		b, nPositions, err = encoding.DecodeUint16Ascending(b)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  		term := &ret[i]
    77  		term.lexeme = string(lexeme)
    78  		term.positions = make([]tsPosition, nPositions)
    79  		for j := uint16(0); j < nPositions; j++ {
    80  			b, position, err = encoding.DecodeUint16Ascending(b)
    81  			if err != nil {
    82  				return nil, err
    83  			}
    84  			encodedWeight := position >> 14
    85  			weight, err := tsWeightFromVectorPGEncoding(byte(encodedWeight))
    86  			if err != nil {
    87  				return nil, err
    88  			}
    89  			// Clear the 2 most significant bits (they were used for the weight).
    90  			position = position & (^(uint16(3) << 14))
    91  			term.positions[j] = tsPosition{position: position, weight: weight}
    92  		}
    93  	}
    94  	return ret, nil
    95  }
    96  
    97  // EncodeTSVectorPGBinary encodes a tsvector into a serialized representation
    98  // that's identical to Postgres's wire protocol representation.
    99  //
   100  // The below comment explains the wire protocol representation. It is taken from
   101  // this page: https://www.npgsql.org/dev/types.html
   102  //
   103  // tsvector:
   104  //
   105  //	UInt32 number of lexemes
   106  //	for each lexeme:
   107  //	    lexeme text in client encoding, null-terminated
   108  //	    UInt16 number of positions
   109  //	    for each position:
   110  //	        UInt16 WordEntryPos, where the most significant 2 bits is weight, and the 14 least significant bits is pos (can't be 0). Weights 3,2,1,0 represent A,B,C,D
   111  func EncodeTSVectorPGBinary(appendTo []byte, vector TSVector) ([]byte, error) {
   112  	appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(vector)))
   113  	for _, term := range vector {
   114  		l := term.lexeme
   115  		appendTo = append(appendTo, []byte(l)...)
   116  		appendTo = append(appendTo, byte(0))
   117  		i := len(term.positions)
   118  		appendTo = encoding.EncodeUint16Ascending(appendTo, uint16(i))
   119  		for _, pos := range term.positions {
   120  			weight, err := pos.weight.TSVectorPGEncoding()
   121  			if err != nil {
   122  				return nil, err
   123  			}
   124  			out := pos.position | (uint16(weight) << 14)
   125  			appendTo = encoding.EncodeUint16Ascending(appendTo, out)
   126  		}
   127  	}
   128  	return appendTo, nil
   129  }
   130  
   131  // DecodeTSVectorPGBinary decodes a tsvector from the input byte slice which is
   132  // formatted in Postgres binary protocol.
   133  func DecodeTSVectorPGBinary(b []byte) (ret TSVector, err error) {
   134  	var nTerms uint32
   135  	var nPositions, position uint16
   136  	b, nTerms, err = encoding.DecodeUint32Ascending(b)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	ret = make([]tsTerm, nTerms)
   141  	for i := uint32(0); i < nTerms; i++ {
   142  		termIndex := bytes.IndexByte(b, byte(0))
   143  		if termIndex == -1 {
   144  			return nil, pgerror.Newf(pgcode.Syntax, "unterminated string while parsing tsvector: %s", b)
   145  		}
   146  		term := &ret[i]
   147  		term.lexeme = string(b[:termIndex])
   148  		b = b[termIndex+1:]
   149  		b, nPositions, err = encoding.DecodeUint16Ascending(b)
   150  		if err != nil {
   151  			return nil, err
   152  		}
   153  		term.positions = make([]tsPosition, nPositions)
   154  		for j := uint16(0); j < nPositions; j++ {
   155  			b, position, err = encoding.DecodeUint16Ascending(b)
   156  			if err != nil {
   157  				return nil, err
   158  			}
   159  			encodedWeight := position >> 14
   160  			weight, err := tsWeightFromVectorPGEncoding(byte(encodedWeight))
   161  			if err != nil {
   162  				return nil, err
   163  			}
   164  			// Clear the 2 most significant bits (they were used for the weight).
   165  			position = position & (^(uint16(3) << 14))
   166  			term.positions[j] = tsPosition{position: position, weight: weight}
   167  		}
   168  	}
   169  	return ret, nil
   170  }
   171  
   172  // EncodeTSQuery encodes a tsquery into a serialized representation for on-disk
   173  // storage.
   174  func EncodeTSQuery(appendTo []byte, query TSQuery) ([]byte, error) {
   175  	// First, append a uint32 of the number of nodes in the query. We'll come
   176  	// back and fill this in later.
   177  	lengthIdx := len(appendTo)
   178  	appendTo = encoding.EncodeUint32Ascending(appendTo, 0)
   179  	var encoder tsNodeCodec
   180  	var err error
   181  	appendTo, err = encoder.encodeTSNode(query.root, appendTo)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	return encoding.PutUint32Ascending(appendTo, uint32(encoder.nTokens), lengthIdx), nil
   186  }
   187  
   188  // DecodeTSQuery deserializes a serialized TSQuery in on-disk format.
   189  func DecodeTSQuery(b []byte) (ret TSQuery, err error) {
   190  	var nTokens uint32
   191  	b, nTokens, err = encoding.DecodeUint32Ascending(b)
   192  	if err != nil {
   193  		return ret, err
   194  	}
   195  	decoder := tsNodeCodec{nTokens: int(nTokens)}
   196  	_, ret.root, err = decoder.decodeTSNode(b)
   197  	if err != nil {
   198  		return ret, err
   199  	}
   200  	return ret, nil
   201  }
   202  
   203  // EncodeTSQueryPGBinary encodes a tsquery into a serialized representation.
   204  //
   205  // The below comment explains the wire protocol representation. It is taken from
   206  // this page: https://www.npgsql.org/dev/types.html
   207  //
   208  //	the tree written in prefix notation:
   209  //	First the number of tokens (a token is an operand or an operator).
   210  //	For each token:
   211  //	  UInt8 type (1 = val, 2 = oper) followed by
   212  //	  For val: UInt8 weight + UInt8 prefix (1 = yes / 0 = no) + null-terminated string,
   213  //	  For oper: UInt8 oper (1 = not, 2 = and, 3 = or, 4 = phrase).
   214  //	  In case of phrase oper code, an additional UInt16 field is sent (distance value of operator). Default is 1 for <->, otherwise the n value in '<n>'.
   215  func EncodeTSQueryPGBinary(appendTo []byte, query TSQuery) []byte {
   216  	// First, append a uint32 of the number of nodes in the query. We'll come
   217  	// back and fill this in later.
   218  	lengthIdx := len(appendTo)
   219  	appendTo = encoding.EncodeUint32Ascending(appendTo, 0)
   220  	var encoder tsNodeCodec
   221  	appendTo = encoder.encodeTSNodePGBinary(query.root, appendTo)
   222  	return encoding.PutUint32Ascending(appendTo, uint32(encoder.nTokens), lengthIdx)
   223  }
   224  
   225  // DecodeTSQueryPGBinary deserializes a serialized TSQuery in pgwire format.
   226  func DecodeTSQueryPGBinary(b []byte) (ret TSQuery, err error) {
   227  	var nTokens uint32
   228  	b, nTokens, err = encoding.DecodeUint32Ascending(b)
   229  	if err != nil {
   230  		return ret, err
   231  	}
   232  	decoder := tsNodeCodec{nTokens: int(nTokens)}
   233  	_, ret.root, err = decoder.decodeTSNodePGBinary(b)
   234  	if err != nil {
   235  		return ret, err
   236  	}
   237  	return ret, nil
   238  }
   239  
// tsNodeCodec holds the token counter used while recursively encoding or
// decoding tsquery nodes. The encode methods increment nTokens once per node
// they write; the decode methods start from the declared token count,
// decrement it once per node they read, and fail when it is already zero.
type tsNodeCodec struct {
	nTokens int
}
   243  
// Type tags written as the first byte of every serialized tsquery node:
// tsNodeTypeVal marks a leaf (value) node and tsNodeTypeOper marks an
// operator node.
const (
	tsNodeTypeVal  = 1
	tsNodeTypeOper = 2
)
   248  
   249  func (c *tsNodeCodec) encodeTSNode(node *tsNode, appendTo []byte) ([]byte, error) {
   250  	c.nTokens++
   251  	if node.op == invalid {
   252  		appendTo = append(appendTo, byte(tsNodeTypeVal))
   253  		if len(node.term.positions) > 0 {
   254  			weight := byte(node.term.positions[0].weight & (^weightStar))
   255  			appendTo = append(appendTo, weight)
   256  			prefix := byte(node.term.positions[0].weight >> 4)
   257  			appendTo = append(appendTo, prefix)
   258  		} else {
   259  			appendTo = append(appendTo, byte(0), byte(0))
   260  		}
   261  		if len(node.term.lexeme) > maxTSVectorLexemeLen {
   262  			return nil, pgerror.Newf(pgcode.ProgramLimitExceeded,
   263  				"tsvector lexeme of size %d too large (maximum is %d)", len(node.term.lexeme),
   264  				maxTSVectorLexemeLen)
   265  		}
   266  		appendTo = encoding.EncodeUntaggedBytesValue(appendTo, encoding.UnsafeConvertStringToBytes(node.term.lexeme))
   267  		return appendTo, nil
   268  	}
   269  	appendTo = append(appendTo, byte(tsNodeTypeOper))
   270  	appendTo = append(appendTo, node.op.pgwireEncoding())
   271  	if node.op == followedby {
   272  		if node.followedN > maxTSVectorFollowedBy {
   273  			return nil, pgerror.Newf(pgcode.ProgramLimitExceeded,
   274  				"tsvector followed by argument %d too large (maximum is %d)", node.followedN,
   275  				maxTSVectorLexemeLen)
   276  		}
   277  		appendTo = encoding.EncodeUint16Ascending(appendTo, node.followedN)
   278  	}
   279  	var err error
   280  	appendTo, err = c.encodeTSNode(node.l, appendTo)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	if node.r != nil {
   285  		appendTo, err = c.encodeTSNode(node.r, appendTo)
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  	}
   290  	return appendTo, nil
   291  }
   292  
   293  func (c *tsNodeCodec) encodeTSNodePGBinary(node *tsNode, appendTo []byte) []byte {
   294  	c.nTokens++
   295  	if node.op == invalid {
   296  		appendTo = append(appendTo, byte(tsNodeTypeVal))
   297  		if len(node.term.positions) > 0 {
   298  			weight := byte(node.term.positions[0].weight & (^weightStar))
   299  			appendTo = append(appendTo, weight)
   300  			prefix := byte(node.term.positions[0].weight >> 4)
   301  			appendTo = append(appendTo, prefix)
   302  		} else {
   303  			appendTo = append(appendTo, byte(0), byte(0))
   304  		}
   305  		appendTo = append(appendTo, []byte(node.term.lexeme)...)
   306  		appendTo = append(appendTo, byte(0))
   307  		return appendTo
   308  	}
   309  	appendTo = append(appendTo, byte(tsNodeTypeOper))
   310  	appendTo = append(appendTo, node.op.pgwireEncoding())
   311  	if node.op == followedby {
   312  		appendTo = encoding.EncodeUint16Ascending(appendTo, node.followedN)
   313  	}
   314  	if node.r != nil {
   315  		appendTo = c.encodeTSNodePGBinary(node.r, appendTo)
   316  	}
   317  	appendTo = c.encodeTSNodePGBinary(node.l, appendTo)
   318  	return appendTo
   319  }
   320  
   321  func getOneByte(b []byte) ([]byte, byte, error) {
   322  	if len(b) == 0 {
   323  		return nil, 0, errors.Errorf("insufficient bytes to decode byte")
   324  	}
   325  	return b[1:], b[0], nil
   326  }
   327  
   328  func (c *tsNodeCodec) decodeTSNode(b []byte) ([]byte, *tsNode, error) {
   329  	if c.nTokens == 0 {
   330  		return nil, nil, errors.Errorf("malformed tsquery: too many nodes")
   331  	}
   332  	c.nTokens--
   333  	var err error
   334  	var nodeType byte
   335  	b, nodeType, err = getOneByte(b)
   336  	if err != nil {
   337  		return nil, nil, err
   338  	}
   339  	ret := &tsNode{}
   340  	if nodeType == tsNodeTypeVal {
   341  		// We're at a leaf. Decode and return.
   342  		if len(b) < 2 {
   343  			return nil, nil, errors.Errorf("insufficient bytes to decode value weight")
   344  		}
   345  		weight, prefix := b[0], b[1]
   346  		b = b[2:]
   347  		if weight != 0 || prefix != 0 {
   348  			ret.term.positions = []tsPosition{{weight: tsWeight(weight | (prefix << 4))}}
   349  		}
   350  		// Decode the lexeme.
   351  		var lexeme []byte
   352  		b, lexeme, err = encoding.DecodeUntaggedBytesValue(b)
   353  		if err != nil {
   354  			return nil, nil, err
   355  		}
   356  		ret.term.lexeme = string(lexeme)
   357  		return b, ret, nil
   358  	}
   359  
   360  	// We're at an operator.
   361  	var operType byte
   362  	b, operType, err = getOneByte(b)
   363  	if err != nil {
   364  		return nil, nil, err
   365  	}
   366  	oper, err := tsOperatorFromPgwireEncoding(operType)
   367  	if err != nil {
   368  		return nil, nil, err
   369  	}
   370  	ret.op = oper
   371  	if oper == followedby {
   372  		var followedN uint16
   373  		b, followedN, err = encoding.DecodeUint16Ascending(b)
   374  		if err != nil {
   375  			return nil, nil, err
   376  		}
   377  		ret.followedN = followedN
   378  	}
   379  	b, ret.l, err = c.decodeTSNode(b)
   380  	if err != nil {
   381  		return nil, nil, err
   382  	}
   383  	switch oper {
   384  	// Not doesn't have a right argument.
   385  	case and, or, followedby:
   386  		b, ret.r, err = c.decodeTSNode(b)
   387  		if err != nil {
   388  			return nil, nil, err
   389  		}
   390  	}
   391  	return b, ret, nil
   392  }
   393  
   394  func (c *tsNodeCodec) decodeTSNodePGBinary(b []byte) ([]byte, *tsNode, error) {
   395  	if c.nTokens == 0 {
   396  		return nil, nil, errors.Errorf("malformed tsquery: too many nodes")
   397  	}
   398  	c.nTokens--
   399  	var err error
   400  	var nodeType byte
   401  	b, nodeType, err = getOneByte(b)
   402  	if err != nil {
   403  		return nil, nil, err
   404  	}
   405  	ret := &tsNode{}
   406  	if nodeType == tsNodeTypeVal {
   407  		// We're at a leaf. Decode and return.
   408  		if len(b) < 2 {
   409  			return nil, nil, errors.Errorf("insufficient bytes to decode value weight")
   410  		}
   411  		weight, prefix := b[0], b[1]
   412  		b = b[2:]
   413  		if weight != 0 || prefix != 0 {
   414  			ret.term.positions = []tsPosition{{weight: tsWeight(weight | (prefix << 4))}}
   415  		}
   416  		// Decode the null-terminated lexeme.
   417  		idx := bytes.IndexByte(b, 0)
   418  		if idx == -1 {
   419  			return nil, nil, errors.Errorf("no null-terminated string in tsnode")
   420  		}
   421  		ret.term.lexeme = string(b[:idx])
   422  		return b[idx+1:], ret, nil
   423  	}
   424  
   425  	// We're at an operator.
   426  	var operType byte
   427  	b, operType, err = getOneByte(b)
   428  	if err != nil {
   429  		return nil, nil, err
   430  	}
   431  	oper, err := tsOperatorFromPgwireEncoding(operType)
   432  	if err != nil {
   433  		return nil, nil, err
   434  	}
   435  	ret.op = oper
   436  	if oper == followedby {
   437  		var followedN uint16
   438  		b, followedN, err = encoding.DecodeUint16Ascending(b)
   439  		if err != nil {
   440  			return nil, nil, err
   441  		}
   442  		ret.followedN = followedN
   443  	}
   444  	switch oper {
   445  	// Not doesn't have a right argument.
   446  	case and, or, followedby:
   447  		b, ret.r, err = c.decodeTSNodePGBinary(b)
   448  		if err != nil {
   449  			return nil, nil, err
   450  		}
   451  	}
   452  	b, ret.l, err = c.decodeTSNodePGBinary(b)
   453  	if err != nil {
   454  		return nil, nil, err
   455  	}
   456  	return b, ret, nil
   457  }
   458  
   459  // EncodeInvertedIndexKeys returns a slice of byte slices, one per inverted
   460  // index key for the terms in this tsvector.
   461  func EncodeInvertedIndexKeys(inKey []byte, vector TSVector) ([][]byte, error) {
   462  	outKeys := make([][]byte, 0, len(vector))
   463  	// Note that by construction, TSVector contains only unique terms, so we don't
   464  	// need to de-duplicate terms when constructing the inverted index keys.
   465  	for i := range vector {
   466  		newKey := EncodeInvertedIndexKey(inKey, vector[i].lexeme)
   467  		outKeys = append(outKeys, newKey)
   468  	}
   469  	return outKeys, nil
   470  }
   471  
   472  // EncodeInvertedIndexKey returns the inverted index key for the input lexeme.
   473  func EncodeInvertedIndexKey(inKey []byte, lexeme string) []byte {
   474  	outKey := make([]byte, len(inKey), len(inKey)+len(lexeme))
   475  	copy(outKey, inKey)
   476  	return encoding.EncodeStringAscending(outKey, lexeme)
   477  }