github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/token.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "fmt" 21 "strconv" 22 "strings" 23 24 "github.com/vedadiyan/sqlparser/pkg/sqltypes" 25 ) 26 27 const ( 28 eofChar = 0x100 29 ) 30 31 // Tokenizer is the struct used to generate SQL 32 // tokens for the parser. 33 type Tokenizer struct { 34 AllowComments bool 35 SkipSpecialComments bool 36 SkipToEnd bool 37 LastError error 38 ParseTree Statement 39 BindVars map[string]struct{} 40 41 lastToken string 42 posVarIndex int 43 partialDDL Statement 44 nesting int 45 multi bool 46 specialComment *Tokenizer 47 48 Pos int 49 buf string 50 } 51 52 // NewStringTokenizer creates a new Tokenizer for the 53 // sql string. 54 func NewStringTokenizer(sql string) *Tokenizer { 55 return &Tokenizer{ 56 buf: sql, 57 BindVars: make(map[string]struct{}), 58 } 59 } 60 61 // Lex returns the next token form the Tokenizer. 62 // This function is used by go yacc. 63 func (tkn *Tokenizer) Lex(lval *yySymType) int { 64 if tkn.SkipToEnd { 65 return tkn.skipStatement() 66 } 67 68 typ, val := tkn.Scan() 69 for typ == COMMENT { 70 if tkn.AllowComments { 71 break 72 } 73 typ, val = tkn.Scan() 74 } 75 if typ == 0 || typ == ';' || typ == LEX_ERROR { 76 // If encounter end of statement or invalid token, 77 // we should not accept partially parsed DDLs. They 78 // should instead result in parser errors. See the 79 // Parse function to see how this is handled. 80 tkn.partialDDL = nil 81 } 82 lval.str = val 83 tkn.lastToken = val 84 return typ 85 } 86 87 // PositionedErr holds context related to parser errors 88 type PositionedErr struct { 89 Err string 90 Pos int 91 Near string 92 } 93 94 func (p PositionedErr) Error() string { 95 if p.Near != "" { 96 return fmt.Sprintf("%s at position %v near '%s'", p.Err, p.Pos, p.Near) 97 } 98 return fmt.Sprintf("%s at position %v", p.Err, p.Pos) 99 } 100 101 // Error is called by go yacc if there's a parsing error. 102 func (tkn *Tokenizer) Error(err string) { 103 tkn.LastError = PositionedErr{Err: err, Pos: tkn.Pos + 1, Near: tkn.lastToken} 104 105 // Try and re-sync to the next statement 106 tkn.skipStatement() 107 } 108 109 // Scan scans the tokenizer for the next token and returns 110 // the token type and an optional value. 111 func (tkn *Tokenizer) Scan() (int, string) { 112 if tkn.specialComment != nil { 113 // Enter specialComment scan mode. 114 // for scanning such kind of comment: /*! MySQL-specific code */ 115 specialComment := tkn.specialComment 116 tok, val := specialComment.Scan() 117 if tok != 0 { 118 // return the specialComment scan result as the result 119 return tok, val 120 } 121 // leave specialComment scan mode after all stream consumed. 122 tkn.specialComment = nil 123 } 124 125 tkn.skipBlank() 126 switch ch := tkn.cur(); { 127 case ch == '@': 128 tokenID := AT_ID 129 tkn.skip(1) 130 if tkn.cur() == '@' { 131 tokenID = AT_AT_ID 132 tkn.skip(1) 133 } 134 var tID int 135 var tBytes string 136 if tkn.cur() == '`' { 137 tkn.skip(1) 138 tID, tBytes = tkn.scanLiteralIdentifier() 139 } else if tkn.cur() == eofChar { 140 return LEX_ERROR, "" 141 } else { 142 tID, tBytes = tkn.scanIdentifier(true) 143 } 144 if tID == LEX_ERROR { 145 return tID, "" 146 } 147 return tokenID, tBytes 148 case isLetter(ch): 149 if ch == 'X' || ch == 'x' { 150 if tkn.peek(1) == '\'' { 151 tkn.skip(2) 152 return tkn.scanHex() 153 } 154 } 155 if ch == 'B' || ch == 'b' { 156 if tkn.peek(1) == '\'' { 157 tkn.skip(2) 158 return tkn.scanBitLiteral() 159 } 160 } 161 // N\'literal' is used to create a string in the national character set 162 if ch == 'N' || ch == 'n' { 163 nxt := tkn.peek(1) 164 if nxt == '\'' || nxt == '"' { 165 tkn.skip(2) 166 return tkn.scanString(nxt, NCHAR_STRING) 167 } 168 } 169 return tkn.scanIdentifier(false) 170 case isDigit(ch): 171 return tkn.scanNumber() 172 case ch == ':': 173 return tkn.scanBindVar() 174 case ch == ';': 175 if tkn.multi { 176 // In multi mode, ';' is treated as EOF. So, we don't advance. 177 // Repeated calls to Scan will keep returning 0 until ParseNext 178 // forces the advance. 179 return 0, "" 180 } 181 tkn.skip(1) 182 return ';', "" 183 case ch == eofChar: 184 return 0, "" 185 default: 186 if ch == '.' && isDigit(tkn.peek(1)) { 187 return tkn.scanNumber() 188 } 189 190 tkn.skip(1) 191 switch ch { 192 case '=', ',', '(', ')', '+', '*', '%', '^', '~': 193 return int(ch), "" 194 case '&': 195 if tkn.cur() == '&' { 196 tkn.skip(1) 197 return AND, "" 198 } 199 return int(ch), "" 200 case '|': 201 if tkn.cur() == '|' { 202 tkn.skip(1) 203 return OR, "" 204 } 205 return int(ch), "" 206 case '?': 207 tkn.posVarIndex++ 208 buf := make([]byte, 0, 8) 209 buf = append(buf, ":v"...) 210 buf = strconv.AppendInt(buf, int64(tkn.posVarIndex), 10) 211 return VALUE_ARG, string(buf) 212 case '.': 213 return int(ch), "" 214 case '/': 215 switch tkn.cur() { 216 case '/': 217 tkn.skip(1) 218 return tkn.scanCommentType1(2) 219 case '*': 220 tkn.skip(1) 221 if tkn.cur() == '!' && !tkn.SkipSpecialComments { 222 tkn.skip(1) 223 return tkn.scanMySQLSpecificComment() 224 } 225 return tkn.scanCommentType2() 226 default: 227 return int(ch), "" 228 } 229 case '#': 230 return tkn.scanCommentType1(1) 231 case '-': 232 switch tkn.cur() { 233 case '-': 234 nextChar := tkn.peek(1) 235 if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar { 236 tkn.skip(1) 237 return tkn.scanCommentType1(2) 238 } 239 case '>': 240 tkn.skip(1) 241 if tkn.cur() == '>' { 242 tkn.skip(1) 243 return JSON_UNQUOTE_EXTRACT_OP, "" 244 } 245 return JSON_EXTRACT_OP, "" 246 } 247 return int(ch), "" 248 case '<': 249 switch tkn.cur() { 250 case '>': 251 tkn.skip(1) 252 return NE, "" 253 case '<': 254 tkn.skip(1) 255 return SHIFT_LEFT, "" 256 case '=': 257 tkn.skip(1) 258 switch tkn.cur() { 259 case '>': 260 tkn.skip(1) 261 return NULL_SAFE_EQUAL, "" 262 default: 263 return LE, "" 264 } 265 default: 266 return int(ch), "" 267 } 268 case '>': 269 switch tkn.cur() { 270 case '=': 271 tkn.skip(1) 272 return GE, "" 273 case '>': 274 tkn.skip(1) 275 return SHIFT_RIGHT, "" 276 default: 277 return int(ch), "" 278 } 279 case '!': 280 if tkn.cur() == '=' { 281 tkn.skip(1) 282 return NE, "" 283 } 284 return int(ch), "" 285 case '\'', '"': 286 return tkn.scanString(ch, STRING) 287 case '`': 288 return tkn.scanLiteralIdentifier() 289 default: 290 return LEX_ERROR, string(byte(ch)) 291 } 292 } 293 } 294 295 // skipStatement scans until end of statement. 296 func (tkn *Tokenizer) skipStatement() int { 297 tkn.SkipToEnd = false 298 for { 299 typ, _ := tkn.Scan() 300 if typ == 0 || typ == ';' || typ == LEX_ERROR { 301 return typ 302 } 303 } 304 } 305 306 // skipBlank skips the cursor while it finds whitespace 307 func (tkn *Tokenizer) skipBlank() { 308 ch := tkn.cur() 309 for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 310 tkn.skip(1) 311 ch = tkn.cur() 312 } 313 } 314 315 // scanIdentifier scans a language keyword or @-encased variable 316 func (tkn *Tokenizer) scanIdentifier(isVariable bool) (int, string) { 317 start := tkn.Pos 318 tkn.skip(1) 319 320 for { 321 ch := tkn.cur() 322 if !isLetter(ch) && !isDigit(ch) && !(isVariable && isCarat(ch)) { 323 break 324 } 325 tkn.skip(1) 326 } 327 keywordName := tkn.buf[start:tkn.Pos] 328 if keywordID, found := keywordLookupTable.LookupString(keywordName); found { 329 return keywordID, keywordName 330 } 331 // dual must always be case-insensitive 332 if keywordASCIIMatch(keywordName, "dual") { 333 return ID, "dual" 334 } 335 return ID, keywordName 336 } 337 338 // scanHex scans a hex numeral; assumes x' or X' has already been scanned 339 func (tkn *Tokenizer) scanHex() (int, string) { 340 start := tkn.Pos 341 tkn.scanMantissa(16) 342 hex := tkn.buf[start:tkn.Pos] 343 if tkn.cur() != '\'' { 344 return LEX_ERROR, hex 345 } 346 tkn.skip(1) 347 if len(hex)%2 != 0 { 348 return LEX_ERROR, hex 349 } 350 return HEX, hex 351 } 352 353 // scanBitLiteral scans a binary numeric literal; assumes b' or B' has already been scanned 354 func (tkn *Tokenizer) scanBitLiteral() (int, string) { 355 start := tkn.Pos 356 tkn.scanMantissa(2) 357 bit := tkn.buf[start:tkn.Pos] 358 if tkn.cur() != '\'' { 359 return LEX_ERROR, bit 360 } 361 tkn.skip(1) 362 return BIT_LITERAL, bit 363 } 364 365 // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may 366 // contain escape sequences instead of it. This method is only called from 367 // scanLiteralIdentifier once the first escape sequence is found in the identifier. 368 // The provided `buf` contains the contents of the identifier that have been scanned 369 // so far. 370 func (tkn *Tokenizer) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) { 371 backTickSeen := true 372 for { 373 if backTickSeen { 374 if tkn.cur() != '`' { 375 break 376 } 377 backTickSeen = false 378 buf.WriteByte('`') 379 tkn.skip(1) 380 continue 381 } 382 // The previous char was not a backtick. 383 switch tkn.cur() { 384 case '`': 385 backTickSeen = true 386 case eofChar: 387 // Premature EOF. 388 return LEX_ERROR, buf.String() 389 default: 390 buf.WriteByte(byte(tkn.cur())) 391 // keep scanning 392 } 393 tkn.skip(1) 394 } 395 return ID, buf.String() 396 } 397 398 // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier 399 // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier 400 // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow 401 func (tkn *Tokenizer) scanLiteralIdentifier() (int, string) { 402 start := tkn.Pos 403 for { 404 switch tkn.cur() { 405 case '`': 406 if tkn.peek(1) != '`' { 407 if tkn.Pos == start { 408 return LEX_ERROR, "" 409 } 410 tkn.skip(1) 411 return ID, tkn.buf[start : tkn.Pos-1] 412 } 413 414 var buf strings.Builder 415 buf.WriteString(tkn.buf[start:tkn.Pos]) 416 tkn.skip(1) 417 return tkn.scanLiteralIdentifierSlow(&buf) 418 case eofChar: 419 // Premature EOF. 420 return LEX_ERROR, tkn.buf[start:tkn.Pos] 421 default: 422 tkn.skip(1) 423 } 424 } 425 } 426 427 // scanBindVar scans a bind variable; assumes a ':' has been scanned right before 428 func (tkn *Tokenizer) scanBindVar() (int, string) { 429 start := tkn.Pos 430 token := VALUE_ARG 431 432 tkn.skip(1) 433 // If : is followed by a digit, then it is an offset value arg. Example - :1, :10 434 if isDigit(tkn.cur()) { 435 tkn.scanMantissa(10) 436 return OFFSET_ARG, tkn.buf[start+1 : tkn.Pos] 437 } 438 // If : is followed by another : it is a list arg. Example ::v1, ::list 439 if tkn.cur() == ':' { 440 token = LIST_ARG 441 tkn.skip(1) 442 } 443 if !isLetter(tkn.cur()) { 444 return LEX_ERROR, tkn.buf[start:tkn.Pos] 445 } 446 // If : is followed by a letter, it is a bindvariable. Example :v1, :v2 447 for { 448 ch := tkn.cur() 449 if !isLetter(ch) && !isDigit(ch) && ch != '.' { 450 break 451 } 452 tkn.skip(1) 453 } 454 return token, tkn.buf[start:tkn.Pos] 455 } 456 457 // scanMantissa scans a sequence of numeric characters with the same base. 458 // This is a helper function only called from the numeric scanners 459 func (tkn *Tokenizer) scanMantissa(base int) { 460 for digitVal(tkn.cur()) < base { 461 tkn.skip(1) 462 } 463 } 464 465 // scanNumber scans any SQL numeric literal, either floating point or integer 466 func (tkn *Tokenizer) scanNumber() (int, string) { 467 start := tkn.Pos 468 token := INTEGRAL 469 470 if tkn.cur() == '.' { 471 token = DECIMAL 472 tkn.skip(1) 473 tkn.scanMantissa(10) 474 goto exponent 475 } 476 477 // 0x construct. 478 if tkn.cur() == '0' { 479 tkn.skip(1) 480 if tkn.cur() == 'x' || tkn.cur() == 'X' { 481 token = HEXNUM 482 tkn.skip(1) 483 tkn.scanMantissa(16) 484 goto exit 485 } 486 if tkn.cur() == 'b' || tkn.cur() == 'B' { 487 token = BITNUM 488 tkn.skip(1) 489 tkn.scanMantissa(2) 490 goto exit 491 } 492 } 493 494 tkn.scanMantissa(10) 495 496 if tkn.cur() == '.' { 497 token = DECIMAL 498 tkn.skip(1) 499 tkn.scanMantissa(10) 500 } 501 502 exponent: 503 if tkn.cur() == 'e' || tkn.cur() == 'E' { 504 token = FLOAT 505 tkn.skip(1) 506 if tkn.cur() == '+' || tkn.cur() == '-' { 507 tkn.skip(1) 508 } 509 tkn.scanMantissa(10) 510 } 511 512 exit: 513 if isLetter(tkn.cur()) { 514 // A letter cannot immediately follow a float number. 515 if token == FLOAT || token == DECIMAL { 516 return LEX_ERROR, tkn.buf[start:tkn.Pos] 517 } 518 // A letter seen after a few numbers means that we should parse this 519 // as an identifier and not a number. 520 for { 521 ch := tkn.cur() 522 if !isLetter(ch) && !isDigit(ch) { 523 break 524 } 525 tkn.skip(1) 526 } 527 return ID, tkn.buf[start:tkn.Pos] 528 } 529 530 return token, tkn.buf[start:tkn.Pos] 531 } 532 533 // scanString scans a string surrounded by the given `delim`, which can be 534 // either single or double quotes. Assumes that the given delimiter has just 535 // been scanned. If the skin contains any escape sequences, this function 536 // will fall back to scanStringSlow 537 func (tkn *Tokenizer) scanString(delim uint16, typ int) (int, string) { 538 start := tkn.Pos 539 540 for { 541 switch tkn.cur() { 542 case delim: 543 if tkn.peek(1) != delim { 544 tkn.skip(1) 545 return typ, tkn.buf[start : tkn.Pos-1] 546 } 547 fallthrough 548 549 case '\\': 550 var buffer strings.Builder 551 buffer.WriteString(tkn.buf[start:tkn.Pos]) 552 return tkn.scanStringSlow(&buffer, delim, typ) 553 554 case eofChar: 555 return LEX_ERROR, tkn.buf[start:tkn.Pos] 556 } 557 558 tkn.skip(1) 559 } 560 } 561 562 // scanString scans a string surrounded by the given `delim` and containing escape 563 // sequencse. The given `buffer` contains the contents of the string that have 564 // been scanned so far. 565 func (tkn *Tokenizer) scanStringSlow(buffer *strings.Builder, delim uint16, typ int) (int, string) { 566 for { 567 ch := tkn.cur() 568 if ch == eofChar { 569 // Unterminated string. 570 return LEX_ERROR, buffer.String() 571 } 572 573 if ch != delim && ch != '\\' { 574 // Scan ahead to the next interesting character. 575 start := tkn.Pos 576 for ; tkn.Pos < len(tkn.buf); tkn.Pos++ { 577 ch = uint16(tkn.buf[tkn.Pos]) 578 if ch == delim || ch == '\\' { 579 break 580 } 581 } 582 583 buffer.WriteString(tkn.buf[start:tkn.Pos]) 584 if tkn.Pos >= len(tkn.buf) { 585 // Reached the end of the buffer without finding a delim or 586 // escape character. 587 tkn.skip(1) 588 continue 589 } 590 } 591 tkn.skip(1) // Read one past the delim or escape character. 592 593 if ch == '\\' { 594 if tkn.cur() == eofChar { 595 // String terminates mid escape character. 596 return LEX_ERROR, buffer.String() 597 } 598 // Preserve escaping of % and _ 599 if tkn.cur() == '%' || tkn.cur() == '_' { 600 buffer.WriteByte('\\') 601 ch = tkn.cur() 602 } else if decodedChar := sqltypes.SQLDecodeMap[byte(tkn.cur())]; decodedChar == sqltypes.DontEscape { 603 ch = tkn.cur() 604 } else { 605 ch = uint16(decodedChar) 606 } 607 } else if ch == delim && tkn.cur() != delim { 608 // Correctly terminated string, which is not a double delim. 609 break 610 } 611 612 buffer.WriteByte(byte(ch)) 613 tkn.skip(1) 614 } 615 616 return typ, buffer.String() 617 } 618 619 // scanCommentType1 scans a SQL line-comment, which is applied until the end 620 // of the line. The given prefix length varies based on whether the comment 621 // is started with '//', '--' or '#'. 622 func (tkn *Tokenizer) scanCommentType1(prefixLen int) (int, string) { 623 start := tkn.Pos - prefixLen 624 for tkn.cur() != eofChar { 625 if tkn.cur() == '\n' { 626 tkn.skip(1) 627 break 628 } 629 tkn.skip(1) 630 } 631 return COMMENT, tkn.buf[start:tkn.Pos] 632 } 633 634 // scanCommentType2 scans a '/*' delimited comment; assumes the opening 635 // prefix has already been scanned 636 func (tkn *Tokenizer) scanCommentType2() (int, string) { 637 start := tkn.Pos - 2 638 for { 639 if tkn.cur() == '*' { 640 tkn.skip(1) 641 if tkn.cur() == '/' { 642 tkn.skip(1) 643 break 644 } 645 continue 646 } 647 if tkn.cur() == eofChar { 648 return LEX_ERROR, tkn.buf[start:tkn.Pos] 649 } 650 tkn.skip(1) 651 } 652 return COMMENT, tkn.buf[start:tkn.Pos] 653 } 654 655 // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*` 656 func (tkn *Tokenizer) scanMySQLSpecificComment() (int, string) { 657 start := tkn.Pos - 3 658 for { 659 if tkn.cur() == '*' { 660 tkn.skip(1) 661 if tkn.cur() == '/' { 662 tkn.skip(1) 663 break 664 } 665 continue 666 } 667 if tkn.cur() == eofChar { 668 return LEX_ERROR, tkn.buf[start:tkn.Pos] 669 } 670 tkn.skip(1) 671 } 672 673 commentVersion, sql := ExtractMysqlComment(tkn.buf[start:tkn.Pos]) 674 675 if mySQLParserVersion >= commentVersion { 676 // Only add the special comment to the tokenizer if the version of MySQL is higher or equal to the comment version 677 tkn.specialComment = NewStringTokenizer(sql) 678 } 679 680 return tkn.Scan() 681 } 682 683 func (tkn *Tokenizer) cur() uint16 { 684 return tkn.peek(0) 685 } 686 687 func (tkn *Tokenizer) skip(dist int) { 688 tkn.Pos += dist 689 } 690 691 func (tkn *Tokenizer) peek(dist int) uint16 { 692 if tkn.Pos+dist >= len(tkn.buf) { 693 return eofChar 694 } 695 return uint16(tkn.buf[tkn.Pos+dist]) 696 } 697 698 // reset clears any internal state. 699 func (tkn *Tokenizer) reset() { 700 tkn.ParseTree = nil 701 tkn.partialDDL = nil 702 tkn.specialComment = nil 703 tkn.posVarIndex = 0 704 tkn.nesting = 0 705 tkn.SkipToEnd = false 706 } 707 708 func isLetter(ch uint16) bool { 709 return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$' 710 } 711 712 func isCarat(ch uint16) bool { 713 return ch == '.' || ch == '\'' || ch == '"' || ch == '`' 714 } 715 716 func digitVal(ch uint16) int { 717 switch { 718 case '0' <= ch && ch <= '9': 719 return int(ch) - '0' 720 case 'a' <= ch && ch <= 'f': 721 return int(ch) - 'a' + 10 722 case 'A' <= ch && ch <= 'F': 723 return int(ch) - 'A' + 10 724 } 725 return 16 // larger than any legal digit val 726 } 727 728 func isDigit(ch uint16) bool { 729 return '0' <= ch && ch <= '9' 730 }