github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/token.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "fmt" 21 "strconv" 22 "strings" 23 24 "github.com/team-ide/go-dialect/vitess/sqltypes" 25 ) 26 27 const ( 28 eofChar = 0x100 29 ) 30 31 // Tokenizer is the struct used to generate SQL 32 // tokens for the parser. 33 type Tokenizer struct { 34 AllowComments bool 35 SkipSpecialComments bool 36 SkipToEnd bool 37 LastError error 38 ParseTree Statement 39 BindVars map[string]struct{} 40 41 lastToken string 42 posVarIndex int 43 partialDDL Statement 44 nesting int 45 multi bool 46 specialComment *Tokenizer 47 48 Pos int 49 buf string 50 51 // MySQLVersion is the version of MySQL that the parser would emulate 52 MySQLVersion string 53 } 54 55 func (tkn *Tokenizer) GetMySQLVersion() string { 56 if tkn.MySQLVersion == "" { 57 return "50709" // default version if nothing else is stated 58 } 59 return tkn.MySQLVersion 60 } 61 62 func (tkn *Tokenizer) SetMySQLVersion(version string) (err error) { 63 convVersion, err := convertMySQLVersionToCommentVersion(version) 64 if err != nil { 65 } else { 66 tkn.MySQLVersion = convVersion 67 } 68 return 69 } 70 71 // NewStringTokenizer creates a new Tokenizer for the 72 // sql string. 73 func NewStringTokenizer(sql string) *Tokenizer { 74 75 return &Tokenizer{ 76 buf: sql, 77 BindVars: make(map[string]struct{}), 78 } 79 } 80 81 // Lex returns the next token form the Tokenizer. 82 // This function is used by go yacc. 83 func (tkn *Tokenizer) Lex(lval *yySymType) int { 84 if tkn.SkipToEnd { 85 return tkn.skipStatement() 86 } 87 88 typ, val := tkn.Scan() 89 for typ == COMMENT { 90 if tkn.AllowComments { 91 break 92 } 93 typ, val = tkn.Scan() 94 } 95 if typ == 0 || typ == ';' || typ == LEX_ERROR { 96 // If encounter end of statement or invalid token, 97 // we should not accept partially parsed DDLs. They 98 // should instead result in parser errors. See the 99 // Parse function to see how this is handled. 100 tkn.partialDDL = nil 101 } 102 lval.str = val 103 tkn.lastToken = val 104 return typ 105 } 106 107 // PositionedErr holds context related to parser errors 108 type PositionedErr struct { 109 Err string 110 Pos int 111 Near string 112 } 113 114 func (p PositionedErr) Error() string { 115 if p.Near != "" { 116 return fmt.Sprintf("%s at position %v near '%s'", p.Err, p.Pos, p.Near) 117 } 118 return fmt.Sprintf("%s at position %v", p.Err, p.Pos) 119 } 120 121 // Error is called by go yacc if there's a parsing error. 122 func (tkn *Tokenizer) Error(err string) { 123 tkn.LastError = PositionedErr{Err: err, Pos: tkn.Pos + 1, Near: tkn.lastToken} 124 125 // Try and re-sync to the next statement 126 tkn.skipStatement() 127 } 128 129 // Scan scans the tokenizer for the next token and returns 130 // the token type and an optional value. 131 func (tkn *Tokenizer) Scan() (int, string) { 132 if tkn.specialComment != nil { 133 // Enter specialComment scan mode. 134 // for scanning such kind of comment: /*! MySQL-specific code */ 135 specialComment := tkn.specialComment 136 tok, val := specialComment.Scan() 137 if tok != 0 { 138 // return the specialComment scan result as the result 139 return tok, val 140 } 141 // leave specialComment scan mode after all stream consumed. 142 tkn.specialComment = nil 143 } 144 145 tkn.skipBlank() 146 switch ch := tkn.cur(); { 147 case ch == '@': 148 tokenID := AT_ID 149 tkn.skip(1) 150 if tkn.cur() == '@' { 151 tokenID = AT_AT_ID 152 tkn.skip(1) 153 } 154 var tID int 155 var tBytes string 156 if tkn.cur() == '`' { 157 tkn.skip(1) 158 tID, tBytes = tkn.scanLiteralIdentifier() 159 } else if tkn.cur() == eofChar { 160 return LEX_ERROR, "" 161 } else { 162 tID, tBytes = tkn.scanIdentifier(true) 163 } 164 if tID == LEX_ERROR { 165 return tID, "" 166 } 167 return tokenID, tBytes 168 case isLetter(ch): 169 if ch == 'X' || ch == 'x' { 170 if tkn.peek(1) == '\'' { 171 tkn.skip(2) 172 return tkn.scanHex() 173 } 174 } 175 if ch == 'B' || ch == 'b' { 176 if tkn.peek(1) == '\'' { 177 tkn.skip(2) 178 return tkn.scanBitLiteral() 179 } 180 } 181 // N\'literal' is used to create a string in the national character set 182 if ch == 'N' || ch == 'n' { 183 nxt := tkn.peek(1) 184 if nxt == '\'' || nxt == '"' { 185 tkn.skip(2) 186 return tkn.scanString(nxt, NCHAR_STRING) 187 } 188 } 189 return tkn.scanIdentifier(false) 190 case isDigit(ch): 191 return tkn.scanNumber() 192 case ch == ':': 193 return tkn.scanBindVar() 194 case ch == ';': 195 if tkn.multi { 196 // In multi mode, ';' is treated as EOF. So, we don't advance. 197 // Repeated calls to Scan will keep returning 0 until ParseNext 198 // forces the advance. 199 return 0, "" 200 } 201 tkn.skip(1) 202 return ';', "" 203 case ch == eofChar: 204 return 0, "" 205 default: 206 if ch == '.' && isDigit(tkn.peek(1)) { 207 return tkn.scanNumber() 208 } 209 210 tkn.skip(1) 211 switch ch { 212 case '=', ',', '(', ')', '+', '*', '%', '^', '~': 213 return int(ch), "" 214 case '&': 215 if tkn.cur() == '&' { 216 tkn.skip(1) 217 return AND, "" 218 } 219 return int(ch), "" 220 case '|': 221 if tkn.cur() == '|' { 222 tkn.skip(1) 223 return OR, "" 224 } 225 return int(ch), "" 226 case '?': 227 tkn.posVarIndex++ 228 buf := make([]byte, 0, 8) 229 buf = append(buf, ":v"...) 230 buf = strconv.AppendInt(buf, int64(tkn.posVarIndex), 10) 231 return VALUE_ARG, string(buf) 232 case '.': 233 return int(ch), "" 234 case '/': 235 switch tkn.cur() { 236 case '/': 237 tkn.skip(1) 238 return tkn.scanCommentType1(2) 239 case '*': 240 tkn.skip(1) 241 if tkn.cur() == '!' && !tkn.SkipSpecialComments { 242 tkn.skip(1) 243 return tkn.scanMySQLSpecificComment() 244 } 245 return tkn.scanCommentType2() 246 default: 247 return int(ch), "" 248 } 249 case '#': 250 return tkn.scanCommentType1(1) 251 case '-': 252 switch tkn.cur() { 253 case '-': 254 nextChar := tkn.peek(1) 255 if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar { 256 tkn.skip(1) 257 return tkn.scanCommentType1(2) 258 } 259 case '>': 260 tkn.skip(1) 261 if tkn.cur() == '>' { 262 tkn.skip(1) 263 return JSON_UNQUOTE_EXTRACT_OP, "" 264 } 265 return JSON_EXTRACT_OP, "" 266 } 267 return int(ch), "" 268 case '<': 269 switch tkn.cur() { 270 case '>': 271 tkn.skip(1) 272 return NE, "" 273 case '<': 274 tkn.skip(1) 275 return SHIFT_LEFT, "" 276 case '=': 277 tkn.skip(1) 278 switch tkn.cur() { 279 case '>': 280 tkn.skip(1) 281 return NULL_SAFE_EQUAL, "" 282 default: 283 return LE, "" 284 } 285 default: 286 return int(ch), "" 287 } 288 case '>': 289 switch tkn.cur() { 290 case '=': 291 tkn.skip(1) 292 return GE, "" 293 case '>': 294 tkn.skip(1) 295 return SHIFT_RIGHT, "" 296 default: 297 return int(ch), "" 298 } 299 case '!': 300 if tkn.cur() == '=' { 301 tkn.skip(1) 302 return NE, "" 303 } 304 return int(ch), "" 305 case '\'', '"': 306 return tkn.scanString(ch, STRING) 307 case '`': 308 return tkn.scanLiteralIdentifier() 309 default: 310 return LEX_ERROR, string(byte(ch)) 311 } 312 } 313 } 314 315 // skipStatement scans until end of statement. 316 func (tkn *Tokenizer) skipStatement() int { 317 tkn.SkipToEnd = false 318 for { 319 typ, _ := tkn.Scan() 320 if typ == 0 || typ == ';' || typ == LEX_ERROR { 321 return typ 322 } 323 } 324 } 325 326 // skipBlank skips the cursor while it finds whitespace 327 func (tkn *Tokenizer) skipBlank() { 328 ch := tkn.cur() 329 for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 330 tkn.skip(1) 331 ch = tkn.cur() 332 } 333 } 334 335 // scanIdentifier scans a language keyword or @-encased variable 336 func (tkn *Tokenizer) scanIdentifier(isVariable bool) (int, string) { 337 start := tkn.Pos 338 tkn.skip(1) 339 340 for { 341 ch := tkn.cur() 342 if !isLetter(ch) && !isDigit(ch) && ch != '@' && !(isVariable && isCarat(ch)) { 343 break 344 } 345 if ch == '@' { 346 isVariable = true 347 } 348 tkn.skip(1) 349 } 350 keywordName := tkn.buf[start:tkn.Pos] 351 if keywordID, found := keywordLookupTable.LookupString(keywordName); found { 352 return keywordID, keywordName 353 } 354 // dual must always be case-insensitive 355 if keywordASCIIMatch(keywordName, "dual") { 356 return ID, "dual" 357 } 358 return ID, keywordName 359 } 360 361 // scanHex scans a hex numeral; assumes x' or X' has already been scanned 362 func (tkn *Tokenizer) scanHex() (int, string) { 363 start := tkn.Pos 364 tkn.scanMantissa(16) 365 hex := tkn.buf[start:tkn.Pos] 366 if tkn.cur() != '\'' { 367 return LEX_ERROR, hex 368 } 369 tkn.skip(1) 370 if len(hex)%2 != 0 { 371 return LEX_ERROR, hex 372 } 373 return HEX, hex 374 } 375 376 // scanBitLiteral scans a binary numeric literal; assumes b' or B' has already been scanned 377 func (tkn *Tokenizer) scanBitLiteral() (int, string) { 378 start := tkn.Pos 379 tkn.scanMantissa(2) 380 bit := tkn.buf[start:tkn.Pos] 381 if tkn.cur() != '\'' { 382 return LEX_ERROR, bit 383 } 384 tkn.skip(1) 385 return BIT_LITERAL, bit 386 } 387 388 // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may 389 // contain escape sequences instead of it. This method is only called from 390 // scanLiteralIdentifier once the first escape sequence is found in the identifier. 391 // The provided `buf` contains the contents of the identifier that have been scanned 392 // so far. 393 func (tkn *Tokenizer) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) { 394 backTickSeen := true 395 for { 396 if backTickSeen { 397 if tkn.cur() != '`' { 398 break 399 } 400 backTickSeen = false 401 buf.WriteByte('`') 402 tkn.skip(1) 403 continue 404 } 405 // The previous char was not a backtick. 406 switch tkn.cur() { 407 case '`': 408 backTickSeen = true 409 case eofChar: 410 // Premature EOF. 411 return LEX_ERROR, buf.String() 412 default: 413 buf.WriteByte(byte(tkn.cur())) 414 // keep scanning 415 } 416 tkn.skip(1) 417 } 418 return ID, buf.String() 419 } 420 421 // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier 422 // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier 423 // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow 424 func (tkn *Tokenizer) scanLiteralIdentifier() (int, string) { 425 start := tkn.Pos 426 for { 427 switch tkn.cur() { 428 case '`': 429 if tkn.peek(1) != '`' { 430 if tkn.Pos == start { 431 return LEX_ERROR, "" 432 } 433 tkn.skip(1) 434 return ID, tkn.buf[start : tkn.Pos-1] 435 } 436 437 var buf strings.Builder 438 buf.WriteString(tkn.buf[start:tkn.Pos]) 439 tkn.skip(1) 440 return tkn.scanLiteralIdentifierSlow(&buf) 441 case eofChar: 442 // Premature EOF. 443 return LEX_ERROR, tkn.buf[start:tkn.Pos] 444 default: 445 tkn.skip(1) 446 } 447 } 448 } 449 450 // scanBindVar scans a bind variable; assumes a ':' has been scanned right before 451 func (tkn *Tokenizer) scanBindVar() (int, string) { 452 start := tkn.Pos 453 token := VALUE_ARG 454 455 tkn.skip(1) 456 if tkn.cur() == ':' { 457 token = LIST_ARG 458 tkn.skip(1) 459 } 460 if !isLetter(tkn.cur()) { 461 return LEX_ERROR, tkn.buf[start:tkn.Pos] 462 } 463 for { 464 ch := tkn.cur() 465 if !isLetter(ch) && !isDigit(ch) && ch != '.' { 466 break 467 } 468 tkn.skip(1) 469 } 470 return token, tkn.buf[start:tkn.Pos] 471 } 472 473 // scanMantissa scans a sequence of numeric characters with the same base. 474 // This is a helper function only called from the numeric scanners 475 func (tkn *Tokenizer) scanMantissa(base int) { 476 for digitVal(tkn.cur()) < base { 477 tkn.skip(1) 478 } 479 } 480 481 // scanNumber scans any SQL numeric literal, either floating point or integer 482 func (tkn *Tokenizer) scanNumber() (int, string) { 483 start := tkn.Pos 484 token := INTEGRAL 485 486 if tkn.cur() == '.' { 487 token = DECIMAL 488 tkn.skip(1) 489 tkn.scanMantissa(10) 490 goto exponent 491 } 492 493 // 0x construct. 494 if tkn.cur() == '0' { 495 tkn.skip(1) 496 if tkn.cur() == 'x' || tkn.cur() == 'X' { 497 token = HEXNUM 498 tkn.skip(1) 499 tkn.scanMantissa(16) 500 goto exit 501 } 502 } 503 504 tkn.scanMantissa(10) 505 506 if tkn.cur() == '.' { 507 token = DECIMAL 508 tkn.skip(1) 509 tkn.scanMantissa(10) 510 } 511 512 exponent: 513 if tkn.cur() == 'e' || tkn.cur() == 'E' { 514 token = FLOAT 515 tkn.skip(1) 516 if tkn.cur() == '+' || tkn.cur() == '-' { 517 tkn.skip(1) 518 } 519 tkn.scanMantissa(10) 520 } 521 522 exit: 523 if isLetter(tkn.cur()) { 524 // A letter cannot immediately follow a float number. 525 if token == FLOAT || token == DECIMAL { 526 return LEX_ERROR, tkn.buf[start:tkn.Pos] 527 } 528 // A letter seen after a few numbers means that we should parse this 529 // as an identifier and not a number. 530 for { 531 ch := tkn.cur() 532 if !isLetter(ch) && !isDigit(ch) { 533 break 534 } 535 tkn.skip(1) 536 } 537 return ID, tkn.buf[start:tkn.Pos] 538 } 539 540 return token, tkn.buf[start:tkn.Pos] 541 } 542 543 // scanString scans a string surrounded by the given `delim`, which can be 544 // either single or double quotes. Assumes that the given delimiter has just 545 // been scanned. If the skin contains any escape sequences, this function 546 // will fall back to scanStringSlow 547 func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, string) { 548 start := tkn.Pos 549 550 for { 551 switch tkn.cur() { 552 case delim: 553 if tkn.peek(1) != delim { 554 tkn.skip(1) 555 return typ, tkn.buf[start : tkn.Pos-1] 556 } 557 fallthrough 558 559 case '\\': 560 var buffer strings.Builder 561 buffer.WriteString(tkn.buf[start:tkn.Pos]) 562 return tkn.scanStringSlow(&buffer, delim, typ) 563 564 case eofChar: 565 return LEX_ERROR, tkn.buf[start:tkn.Pos] 566 } 567 568 tkn.skip(1) 569 } 570 } 571 572 // scanString scans a string surrounded by the given `delim` and containing escape 573 // sequencse. The given `buffer` contains the contents of the string that have 574 // been scanned so far. 575 func (tkn *Tokenizer) scanStringSlow(buffer *strings.Builder, delim uint16, typ int) (int, string) { 576 for { 577 ch := tkn.cur() 578 if ch == eofChar { 579 // Unterminated string. 580 return LEX_ERROR, buffer.String() 581 } 582 583 if ch != delim && ch != '\\' { 584 // Scan ahead to the next interesting character. 585 start := tkn.Pos 586 for ; tkn.Pos < len(tkn.buf); tkn.Pos++ { 587 ch = uint16(tkn.buf[tkn.Pos]) 588 if ch == delim || ch == '\\' { 589 break 590 } 591 } 592 593 buffer.WriteString(tkn.buf[start:tkn.Pos]) 594 if tkn.Pos >= len(tkn.buf) { 595 // Reached the end of the buffer without finding a delim or 596 // escape character. 597 tkn.skip(1) 598 continue 599 } 600 } 601 tkn.skip(1) // Read one past the delim or escape character. 602 603 if ch == '\\' { 604 if tkn.cur() == eofChar { 605 // String terminates mid escape character. 606 return LEX_ERROR, buffer.String() 607 } 608 if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.cur())]; decodedChar == sqltypes.DontEscape { 609 ch = tkn.cur() 610 } else { 611 ch = uint16(decodedChar) 612 } 613 } else if ch == delim && tkn.cur() != delim { 614 // Correctly terminated string, which is not a double delim. 615 break 616 } 617 618 buffer.WriteByte(byte(ch)) 619 tkn.skip(1) 620 } 621 622 return typ, buffer.String() 623 } 624 625 // scanCommentType1 scans a SQL line-comment, which is applied until the end 626 // of the line. The given prefix length varies based on whether the comment 627 // is started with '//', '--' or '#'. 628 func (tkn *Tokenizer) scanCommentType1(prefixLen int) (int, string) { 629 start := tkn.Pos - prefixLen 630 for tkn.cur() != eofChar { 631 if tkn.cur() == '\n' { 632 tkn.skip(1) 633 break 634 } 635 tkn.skip(1) 636 } 637 return COMMENT, tkn.buf[start:tkn.Pos] 638 } 639 640 // scanCommentType2 scans a '/*' delimited comment; assumes the opening 641 // prefix has already been scanned 642 func (tkn *Tokenizer) scanCommentType2() (int, string) { 643 start := tkn.Pos - 2 644 for { 645 if tkn.cur() == '*' { 646 tkn.skip(1) 647 if tkn.cur() == '/' { 648 tkn.skip(1) 649 break 650 } 651 continue 652 } 653 if tkn.cur() == eofChar { 654 return LEX_ERROR, tkn.buf[start:tkn.Pos] 655 } 656 tkn.skip(1) 657 } 658 return COMMENT, tkn.buf[start:tkn.Pos] 659 } 660 661 // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*` 662 func (tkn *Tokenizer) scanMySQLSpecificComment() (int, string) { 663 start := tkn.Pos - 3 664 for { 665 if tkn.cur() == '*' { 666 tkn.skip(1) 667 if tkn.cur() == '/' { 668 tkn.skip(1) 669 break 670 } 671 continue 672 } 673 if tkn.cur() == eofChar { 674 return LEX_ERROR, tkn.buf[start:tkn.Pos] 675 } 676 tkn.skip(1) 677 } 678 679 commentVersion, sql := ExtractMysqlComment(tkn.buf[start:tkn.Pos]) 680 681 if tkn.GetMySQLVersion() >= commentVersion { 682 // Only add the special comment to the tokenizer if the version of MySQL is higher or equal to the comment version 683 tkn.specialComment = NewStringTokenizer(sql) 684 } 685 686 return tkn.Scan() 687 } 688 689 func (tkn *Tokenizer) cur() uint16 { 690 return tkn.peek(0) 691 } 692 693 func (tkn *Tokenizer) skip(dist int) { 694 tkn.Pos += dist 695 } 696 697 func (tkn *Tokenizer) peek(dist int) uint16 { 698 if tkn.Pos+dist >= len(tkn.buf) { 699 return eofChar 700 } 701 return uint16(tkn.buf[tkn.Pos+dist]) 702 } 703 704 // reset clears any internal state. 705 func (tkn *Tokenizer) reset() { 706 tkn.ParseTree = nil 707 tkn.partialDDL = nil 708 tkn.specialComment = nil 709 tkn.posVarIndex = 0 710 tkn.nesting = 0 711 tkn.SkipToEnd = false 712 } 713 714 func isLetter(ch uint16) bool { 715 return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$' 716 } 717 718 func isCarat(ch uint16) bool { 719 return ch == '.' || ch == '\'' || ch == '"' || ch == '`' 720 } 721 722 func digitVal(ch uint16) int { 723 switch { 724 case '0' <= ch && ch <= '9': 725 return int(ch) - '0' 726 case 'a' <= ch && ch <= 'f': 727 return int(ch) - 'a' + 10 728 case 'A' <= ch && ch <= 'F': 729 return int(ch) - 'A' + 10 730 } 731 return 16 // larger than any legal digit val 732 } 733 734 func isDigit(ch uint16) bool { 735 return '0' <= ch && ch <= '9' 736 }