github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/tsearch/encoding.go (about) 1 // Copyright 2022 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package tsearch 12 13 import ( 14 "bytes" 15 16 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode" 17 "github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror" 18 "github.com/cockroachdb/cockroachdb-parser/pkg/util/encoding" 19 "github.com/cockroachdb/errors" 20 ) 21 22 // EncodeTSVector encodes a tsvector into a serialized representation for 23 // on-disk storage. 24 func EncodeTSVector(appendTo []byte, vector TSVector) ([]byte, error) { 25 appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(vector))) 26 for _, term := range vector { 27 l := term.lexeme 28 appendTo = encoding.EncodeUntaggedBytesValue(appendTo, encoding.UnsafeConvertStringToBytes(l)) 29 if len(term.positions) > maxTSVectorPositions { 30 return nil, pgerror.Newf(pgcode.ProgramLimitExceeded, 31 "tsvector position list of size %d too large (maximum is %d)", len(term.positions), 32 maxTSVectorPositions) 33 } 34 if len(l) > maxTSVectorLexemeLen { 35 return nil, pgerror.Newf(pgcode.ProgramLimitExceeded, 36 "tsvector lexeme of size %d too large (maximum is %d)", len(l), 37 maxTSVectorLexemeLen) 38 } 39 appendTo = encoding.EncodeUint16Ascending(appendTo, uint16(len(term.positions))) 40 for _, pos := range term.positions { 41 weight, err := pos.weight.TSVectorPGEncoding() 42 if err != nil { 43 return nil, err 44 } 45 // Clear the 2 most significant bits. These should never be set, 46 // as we always make sure that positions are at most 1 << 14, but 47 // better an extra check. 48 position := pos.position & (^(uint16(3) << 14)) 49 out := position | (uint16(weight) << 14) 50 appendTo = encoding.EncodeUint16Ascending(appendTo, out) 51 } 52 } 53 return appendTo, nil 54 } 55 56 // DecodeTSVector decodes a tsvector in disk-storage representation from the 57 // input byte slice. 58 func DecodeTSVector(b []byte) (ret TSVector, err error) { 59 var nTerms uint32 60 var nPositions, position uint16 61 b, nTerms, err = encoding.DecodeUint32Ascending(b) 62 if err != nil { 63 return nil, err 64 } 65 ret = make([]tsTerm, nTerms) 66 for i := uint32(0); i < nTerms; i++ { 67 var lexeme []byte 68 b, lexeme, err = encoding.DecodeUntaggedBytesValue(b) 69 if err != nil { 70 return nil, err 71 } 72 b, nPositions, err = encoding.DecodeUint16Ascending(b) 73 if err != nil { 74 return nil, err 75 } 76 term := &ret[i] 77 term.lexeme = string(lexeme) 78 term.positions = make([]tsPosition, nPositions) 79 for j := uint16(0); j < nPositions; j++ { 80 b, position, err = encoding.DecodeUint16Ascending(b) 81 if err != nil { 82 return nil, err 83 } 84 encodedWeight := position >> 14 85 weight, err := tsWeightFromVectorPGEncoding(byte(encodedWeight)) 86 if err != nil { 87 return nil, err 88 } 89 // Clear the 2 most significant bits (they were used for the weight). 90 position = position & (^(uint16(3) << 14)) 91 term.positions[j] = tsPosition{position: position, weight: weight} 92 } 93 } 94 return ret, nil 95 } 96 97 // EncodeTSVectorPGBinary encodes a tsvector into a serialized representation 98 // that's identical to Postgres's wire protocol representation. 99 // 100 // The below comment explains the wire protocol representation. It is taken from 101 // this page: https://www.npgsql.org/dev/types.html 102 // 103 // tsvector: 104 // 105 // UInt32 number of lexemes 106 // for each lexeme: 107 // lexeme text in client encoding, null-terminated 108 // UInt16 number of positions 109 // for each position: 110 // UInt16 WordEntryPos, where the most significant 2 bits is weight, and the 14 least significant bits is pos (can't be 0). Weights 3,2,1,0 represent A,B,C,D 111 func EncodeTSVectorPGBinary(appendTo []byte, vector TSVector) ([]byte, error) { 112 appendTo = encoding.EncodeUint32Ascending(appendTo, uint32(len(vector))) 113 for _, term := range vector { 114 l := term.lexeme 115 appendTo = append(appendTo, []byte(l)...) 116 appendTo = append(appendTo, byte(0)) 117 i := len(term.positions) 118 appendTo = encoding.EncodeUint16Ascending(appendTo, uint16(i)) 119 for _, pos := range term.positions { 120 weight, err := pos.weight.TSVectorPGEncoding() 121 if err != nil { 122 return nil, err 123 } 124 out := pos.position | (uint16(weight) << 14) 125 appendTo = encoding.EncodeUint16Ascending(appendTo, out) 126 } 127 } 128 return appendTo, nil 129 } 130 131 // DecodeTSVectorPGBinary decodes a tsvector from the input byte slice which is 132 // formatted in Postgres binary protocol. 133 func DecodeTSVectorPGBinary(b []byte) (ret TSVector, err error) { 134 var nTerms uint32 135 var nPositions, position uint16 136 b, nTerms, err = encoding.DecodeUint32Ascending(b) 137 if err != nil { 138 return nil, err 139 } 140 ret = make([]tsTerm, nTerms) 141 for i := uint32(0); i < nTerms; i++ { 142 termIndex := bytes.IndexByte(b, byte(0)) 143 if termIndex == -1 { 144 return nil, pgerror.Newf(pgcode.Syntax, "unterminated string while parsing tsvector: %s", b) 145 } 146 term := &ret[i] 147 term.lexeme = string(b[:termIndex]) 148 b = b[termIndex+1:] 149 b, nPositions, err = encoding.DecodeUint16Ascending(b) 150 if err != nil { 151 return nil, err 152 } 153 term.positions = make([]tsPosition, nPositions) 154 for j := uint16(0); j < nPositions; j++ { 155 b, position, err = encoding.DecodeUint16Ascending(b) 156 if err != nil { 157 return nil, err 158 } 159 encodedWeight := position >> 14 160 weight, err := tsWeightFromVectorPGEncoding(byte(encodedWeight)) 161 if err != nil { 162 return nil, err 163 } 164 // Clear the 2 most significant bits (they were used for the weight). 165 position = position & (^(uint16(3) << 14)) 166 term.positions[j] = tsPosition{position: position, weight: weight} 167 } 168 } 169 return ret, nil 170 } 171 172 // EncodeTSQuery encodes a tsquery into a serialized representation for on-disk 173 // storage. 174 func EncodeTSQuery(appendTo []byte, query TSQuery) ([]byte, error) { 175 // First, append a uint32 of the number of nodes in the query. We'll come 176 // back and fill this in later. 177 lengthIdx := len(appendTo) 178 appendTo = encoding.EncodeUint32Ascending(appendTo, 0) 179 var encoder tsNodeCodec 180 var err error 181 appendTo, err = encoder.encodeTSNode(query.root, appendTo) 182 if err != nil { 183 return nil, err 184 } 185 return encoding.PutUint32Ascending(appendTo, uint32(encoder.nTokens), lengthIdx), nil 186 } 187 188 // DecodeTSQuery deserializes a serialized TSQuery in on-disk format. 189 func DecodeTSQuery(b []byte) (ret TSQuery, err error) { 190 var nTokens uint32 191 b, nTokens, err = encoding.DecodeUint32Ascending(b) 192 if err != nil { 193 return ret, err 194 } 195 decoder := tsNodeCodec{nTokens: int(nTokens)} 196 _, ret.root, err = decoder.decodeTSNode(b) 197 if err != nil { 198 return ret, err 199 } 200 return ret, nil 201 } 202 203 // EncodeTSQueryPGBinary encodes a tsquery into a serialized representation. 204 // 205 // The below comment explains the wire protocol representation. It is taken from 206 // this page: https://www.npgsql.org/dev/types.html 207 // 208 // the tree written in prefix notation: 209 // First the number of tokens (a token is an operand or an operator). 210 // For each token: 211 // UInt8 type (1 = val, 2 = oper) followed by 212 // For val: UInt8 weight + UInt8 prefix (1 = yes / 0 = no) + null-terminated string, 213 // For oper: UInt8 oper (1 = not, 2 = and, 3 = or, 4 = phrase). 214 // In case of phrase oper code, an additional UInt16 field is sent (distance value of operator). Default is 1 for <->, otherwise the n value in '<n>'. 215 func EncodeTSQueryPGBinary(appendTo []byte, query TSQuery) []byte { 216 // First, append a uint32 of the number of nodes in the query. We'll come 217 // back and fill this in later. 218 lengthIdx := len(appendTo) 219 appendTo = encoding.EncodeUint32Ascending(appendTo, 0) 220 var encoder tsNodeCodec 221 appendTo = encoder.encodeTSNodePGBinary(query.root, appendTo) 222 return encoding.PutUint32Ascending(appendTo, uint32(encoder.nTokens), lengthIdx) 223 } 224 225 // DecodeTSQueryPGBinary deserializes a serialized TSQuery in pgwire format. 226 func DecodeTSQueryPGBinary(b []byte) (ret TSQuery, err error) { 227 var nTokens uint32 228 b, nTokens, err = encoding.DecodeUint32Ascending(b) 229 if err != nil { 230 return ret, err 231 } 232 decoder := tsNodeCodec{nTokens: int(nTokens)} 233 _, ret.root, err = decoder.decodeTSNodePGBinary(b) 234 if err != nil { 235 return ret, err 236 } 237 return ret, nil 238 } 239 240 type tsNodeCodec struct { 241 nTokens int 242 } 243 244 const ( 245 tsNodeTypeVal = 1 246 tsNodeTypeOper = 2 247 ) 248 249 func (c *tsNodeCodec) encodeTSNode(node *tsNode, appendTo []byte) ([]byte, error) { 250 c.nTokens++ 251 if node.op == invalid { 252 appendTo = append(appendTo, byte(tsNodeTypeVal)) 253 if len(node.term.positions) > 0 { 254 weight := byte(node.term.positions[0].weight & (^weightStar)) 255 appendTo = append(appendTo, weight) 256 prefix := byte(node.term.positions[0].weight >> 4) 257 appendTo = append(appendTo, prefix) 258 } else { 259 appendTo = append(appendTo, byte(0), byte(0)) 260 } 261 if len(node.term.lexeme) > maxTSVectorLexemeLen { 262 return nil, pgerror.Newf(pgcode.ProgramLimitExceeded, 263 "tsvector lexeme of size %d too large (maximum is %d)", len(node.term.lexeme), 264 maxTSVectorLexemeLen) 265 } 266 appendTo = encoding.EncodeUntaggedBytesValue(appendTo, encoding.UnsafeConvertStringToBytes(node.term.lexeme)) 267 return appendTo, nil 268 } 269 appendTo = append(appendTo, byte(tsNodeTypeOper)) 270 appendTo = append(appendTo, node.op.pgwireEncoding()) 271 if node.op == followedby { 272 if node.followedN > maxTSVectorFollowedBy { 273 return nil, pgerror.Newf(pgcode.ProgramLimitExceeded, 274 "tsvector followed by argument %d too large (maximum is %d)", node.followedN, 275 maxTSVectorLexemeLen) 276 } 277 appendTo = encoding.EncodeUint16Ascending(appendTo, node.followedN) 278 } 279 var err error 280 appendTo, err = c.encodeTSNode(node.l, appendTo) 281 if err != nil { 282 return nil, err 283 } 284 if node.r != nil { 285 appendTo, err = c.encodeTSNode(node.r, appendTo) 286 if err != nil { 287 return nil, err 288 } 289 } 290 return appendTo, nil 291 } 292 293 func (c *tsNodeCodec) encodeTSNodePGBinary(node *tsNode, appendTo []byte) []byte { 294 c.nTokens++ 295 if node.op == invalid { 296 appendTo = append(appendTo, byte(tsNodeTypeVal)) 297 if len(node.term.positions) > 0 { 298 weight := byte(node.term.positions[0].weight & (^weightStar)) 299 appendTo = append(appendTo, weight) 300 prefix := byte(node.term.positions[0].weight >> 4) 301 appendTo = append(appendTo, prefix) 302 } else { 303 appendTo = append(appendTo, byte(0), byte(0)) 304 } 305 appendTo = append(appendTo, []byte(node.term.lexeme)...) 306 appendTo = append(appendTo, byte(0)) 307 return appendTo 308 } 309 appendTo = append(appendTo, byte(tsNodeTypeOper)) 310 appendTo = append(appendTo, node.op.pgwireEncoding()) 311 if node.op == followedby { 312 appendTo = encoding.EncodeUint16Ascending(appendTo, node.followedN) 313 } 314 if node.r != nil { 315 appendTo = c.encodeTSNodePGBinary(node.r, appendTo) 316 } 317 appendTo = c.encodeTSNodePGBinary(node.l, appendTo) 318 return appendTo 319 } 320 321 func getOneByte(b []byte) ([]byte, byte, error) { 322 if len(b) == 0 { 323 return nil, 0, errors.Errorf("insufficient bytes to decode byte") 324 } 325 return b[1:], b[0], nil 326 } 327 328 func (c *tsNodeCodec) decodeTSNode(b []byte) ([]byte, *tsNode, error) { 329 if c.nTokens == 0 { 330 return nil, nil, errors.Errorf("malformed tsquery: too many nodes") 331 } 332 c.nTokens-- 333 var err error 334 var nodeType byte 335 b, nodeType, err = getOneByte(b) 336 if err != nil { 337 return nil, nil, err 338 } 339 ret := &tsNode{} 340 if nodeType == tsNodeTypeVal { 341 // We're at a leaf. Decode and return. 342 if len(b) < 2 { 343 return nil, nil, errors.Errorf("insufficient bytes to decode value weight") 344 } 345 weight, prefix := b[0], b[1] 346 b = b[2:] 347 if weight != 0 || prefix != 0 { 348 ret.term.positions = []tsPosition{{weight: tsWeight(weight | (prefix << 4))}} 349 } 350 // Decode the lexeme. 351 var lexeme []byte 352 b, lexeme, err = encoding.DecodeUntaggedBytesValue(b) 353 if err != nil { 354 return nil, nil, err 355 } 356 ret.term.lexeme = string(lexeme) 357 return b, ret, nil 358 } 359 360 // We're at an operator. 361 var operType byte 362 b, operType, err = getOneByte(b) 363 if err != nil { 364 return nil, nil, err 365 } 366 oper, err := tsOperatorFromPgwireEncoding(operType) 367 if err != nil { 368 return nil, nil, err 369 } 370 ret.op = oper 371 if oper == followedby { 372 var followedN uint16 373 b, followedN, err = encoding.DecodeUint16Ascending(b) 374 if err != nil { 375 return nil, nil, err 376 } 377 ret.followedN = followedN 378 } 379 b, ret.l, err = c.decodeTSNode(b) 380 if err != nil { 381 return nil, nil, err 382 } 383 switch oper { 384 // Not doesn't have a right argument. 385 case and, or, followedby: 386 b, ret.r, err = c.decodeTSNode(b) 387 if err != nil { 388 return nil, nil, err 389 } 390 } 391 return b, ret, nil 392 } 393 394 func (c *tsNodeCodec) decodeTSNodePGBinary(b []byte) ([]byte, *tsNode, error) { 395 if c.nTokens == 0 { 396 return nil, nil, errors.Errorf("malformed tsquery: too many nodes") 397 } 398 c.nTokens-- 399 var err error 400 var nodeType byte 401 b, nodeType, err = getOneByte(b) 402 if err != nil { 403 return nil, nil, err 404 } 405 ret := &tsNode{} 406 if nodeType == tsNodeTypeVal { 407 // We're at a leaf. Decode and return. 408 if len(b) < 2 { 409 return nil, nil, errors.Errorf("insufficient bytes to decode value weight") 410 } 411 weight, prefix := b[0], b[1] 412 b = b[2:] 413 if weight != 0 || prefix != 0 { 414 ret.term.positions = []tsPosition{{weight: tsWeight(weight | (prefix << 4))}} 415 } 416 // Decode the null-terminated lexeme. 417 idx := bytes.IndexByte(b, 0) 418 if idx == -1 { 419 return nil, nil, errors.Errorf("no null-terminated string in tsnode") 420 } 421 ret.term.lexeme = string(b[:idx]) 422 return b[idx+1:], ret, nil 423 } 424 425 // We're at an operator. 426 var operType byte 427 b, operType, err = getOneByte(b) 428 if err != nil { 429 return nil, nil, err 430 } 431 oper, err := tsOperatorFromPgwireEncoding(operType) 432 if err != nil { 433 return nil, nil, err 434 } 435 ret.op = oper 436 if oper == followedby { 437 var followedN uint16 438 b, followedN, err = encoding.DecodeUint16Ascending(b) 439 if err != nil { 440 return nil, nil, err 441 } 442 ret.followedN = followedN 443 } 444 switch oper { 445 // Not doesn't have a right argument. 446 case and, or, followedby: 447 b, ret.r, err = c.decodeTSNodePGBinary(b) 448 if err != nil { 449 return nil, nil, err 450 } 451 } 452 b, ret.l, err = c.decodeTSNodePGBinary(b) 453 if err != nil { 454 return nil, nil, err 455 } 456 return b, ret, nil 457 } 458 459 // EncodeInvertedIndexKeys returns a slice of byte slices, one per inverted 460 // index key for the terms in this tsvector. 461 func EncodeInvertedIndexKeys(inKey []byte, vector TSVector) ([][]byte, error) { 462 outKeys := make([][]byte, 0, len(vector)) 463 // Note that by construction, TSVector contains only unique terms, so we don't 464 // need to de-duplicate terms when constructing the inverted index keys. 465 for i := range vector { 466 newKey := EncodeInvertedIndexKey(inKey, vector[i].lexeme) 467 outKeys = append(outKeys, newKey) 468 } 469 return outKeys, nil 470 } 471 472 // EncodeInvertedIndexKey returns the inverted index key for the input lexeme. 473 func EncodeInvertedIndexKey(inKey []byte, lexeme string) []byte { 474 outKey := make([]byte, len(inKey), len(inKey)+len(lexeme)) 475 copy(outKey, inKey) 476 return encoding.EncodeStringAscending(outKey, lexeme) 477 }