// Copyright 2015 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package scanner

import (
	"fmt"
	"go/constant"
	"go/token"
	"strconv"
	"strings"
	"unicode/utf8"
	"unsafe"

	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/lexbase"
)

// eof is the sentinel returned by peek/next when the input is exhausted.
const eof = -1
const errUnterminated = "unterminated string"
const errInvalidUTF8 = "invalid UTF-8 byte sequence"
const errInvalidHexNumeric = "invalid hexadecimal numeric literal"
const singleQuote = '\''
const identQuote = '"'

// NewNumValFn allows us to use tree.NewNumVal without a dependency on tree.
var NewNumValFn = func(constant.Value, string, bool) interface{} {
	return struct{}{}
}

// NewPlaceholderFn allows us to use tree.NewPlaceholder without a dependency on
// tree.
var NewPlaceholderFn = func(string) (interface{}, error) {
	return struct{}{}, nil
}

// ScanSymType is the interface for accessing the fields of a yacc symType.
type ScanSymType interface {
	ID() int32
	SetID(int32)
	Pos() int32
	SetPos(int32)
	Str() string
	SetStr(string)
	UnionVal() interface{}
	SetUnionVal(interface{})
}

// Scanner lexes SQL statements.
type Scanner struct {
	// in is the string being scanned.
	in string
	// pos is the byte offset of the next character to scan.
	pos int
	// bytesPrealloc is scratch space handed out by allocBytes/buffer and
	// reclaimed by returnBuffer, to avoid per-token allocations.
	bytesPrealloc []byte

	// Comments is the list of parsed comments from the SQL statement.
	Comments []string

	// lastAttemptedID indicates the ID of the last attempted
	// token. Used to recognize which token an error was encountered
	// on.
	lastAttemptedID int32
	// quoted indicates if the last identifier scanned was
	// quoted. Used to distinguish between quoted and non-quoted in
	// Inspect.
	quoted bool
}

// SQLScanner is a scanner with a SQL specific scan function.
type SQLScanner struct {
	Scanner
}

// In returns the input string.
func (s *Scanner) In() string {
	return s.in
}

// Pos returns the current position being lexed.
func (s *Scanner) Pos() int {
	return s.pos
}

// Init initializes a new Scanner that will process str.
func (s *Scanner) Init(str string) {
	s.in = str
	s.pos = 0
	// Preallocate some buffer space for identifiers etc.
	s.bytesPrealloc = make([]byte, len(str))
}

// Cleanup is used to avoid holding on to memory unnecessarily (for the cases
// where we reuse a Scanner).
func (s *Scanner) Cleanup() {
	s.bytesPrealloc = nil
}

// allocBytes returns a zeroed []byte of the requested length, carved out of
// the preallocated scratch buffer when possible and freshly allocated
// otherwise. The three-index slice pins cap == len so appends by the caller
// cannot stomp the remaining scratch space.
func (s *Scanner) allocBytes(length int) []byte {
	if cap(s.bytesPrealloc) >= length {
		res := s.bytesPrealloc[:length:length]
		s.bytesPrealloc = s.bytesPrealloc[length:cap(s.bytesPrealloc)]
		return res
	}
	return make([]byte, length)
}

// buffer returns an empty []byte buffer that can be appended to. Any unused
// portion can be returned later using returnBuffer.
func (s *Scanner) buffer() []byte {
	buf := s.bytesPrealloc[:0]
	s.bytesPrealloc = nil
	return buf
}

// returnBuffer returns the unused portion of buf to the Scanner, to be used for
// future allocBytes() or buffer() calls. The caller must not use buf again.
func (s *Scanner) returnBuffer(buf []byte) {
	if len(buf) < cap(buf) {
		s.bytesPrealloc = buf[len(buf):]
	}
}

// finishString casts the given buffer to a string and returns the unused
// portion of the buffer. The caller must not use buf again.
func (s *Scanner) finishString(buf []byte) string {
	// Zero-copy []byte -> string conversion; safe only because the caller
	// relinquishes buf (per the contract above).
	str := *(*string)(unsafe.Pointer(&buf))
	s.returnBuffer(buf)
	return str
}

// scanSetup resets per-token state, skips leading whitespace/comments, and
// consumes the first character of the next token. It returns that character
// (or eof) and a bool that is true when an error during whitespace/comment
// skipping already populated lval and scanning must stop.
func (s *Scanner) scanSetup(lval ScanSymType) (int, bool) {
	lval.SetID(0)
	lval.SetPos(int32(s.pos))
	lval.SetStr("EOF")
	s.quoted = false
	s.lastAttemptedID = 0

	if _, ok := s.skipWhitespace(lval, true); !ok {
		return 0, true
	}

	ch := s.next()
	if ch == eof {
		lval.SetPos(int32(s.pos))
		return ch, false
	}

	// Default the token to the single character just read; the caller
	// overrides ID/Str for multi-character tokens.
	lval.SetID(int32(ch))
	lval.SetPos(int32(s.pos - 1))
	lval.SetStr(s.in[lval.Pos():s.pos])
	s.lastAttemptedID = int32(ch)
	return ch, false
}

// Scan scans the next token and populates its information into lval.
func (s *SQLScanner) Scan(lval ScanSymType) {
	ch, skipWhiteSpace := s.scanSetup(lval)

	if skipWhiteSpace {
		return
	}

	switch ch {
	case '$':
		// placeholder? $[0-9]+
		if lexbase.IsDigit(s.peek()) {
			s.scanPlaceholder(lval)
			return
		} else if s.scanDollarQuotedString(lval) {
			lval.SetID(lexbase.SCONST)
			return
		}
		return

	case identQuote:
		// "[^"]"
		s.lastAttemptedID = int32(lexbase.IDENT)
		s.quoted = true
		if s.scanString(lval, identQuote, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.SetID(lexbase.IDENT)
		}
		return

	case singleQuote:
		// '[^']'
		s.lastAttemptedID = int32(lexbase.SCONST)
		if s.scanString(lval, ch, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.SetID(lexbase.SCONST)
		}
		return

	case 'b':
		// Bytes?
		if s.peek() == singleQuote {
			// b'[^']'
			s.lastAttemptedID = int32(lexbase.BCONST)
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, false /* requireUTF8 */) {
				lval.SetID(lexbase.BCONST)
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'r', 'R':
		s.scanIdent(lval)
		return

	case 'e', 'E':
		// Escaped string?
		if s.peek() == singleQuote {
			// [eE]'[^']'
			s.lastAttemptedID = int32(lexbase.SCONST)
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, true /* requireUTF8 */) {
				lval.SetID(lexbase.SCONST)
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'B':
		// Bit array literal?
		if s.peek() == singleQuote {
			// B'[01]*'
			s.pos++
			s.scanBitString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case 'x', 'X':
		// Hex literal?
		if s.peek() == singleQuote {
			// [xX]'[a-f0-9]'
			s.pos++
			s.scanHexString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case '.':
		switch t := s.peek(); {
		case t == '.': // ..
			s.pos++
			lval.SetID(lexbase.DOT_DOT)
			return
		case lexbase.IsDigit(t):
			s.lastAttemptedID = int32(lexbase.FCONST)
			s.scanNumber(lval, ch)
			return
		}
		return

	case '!':
		switch s.peek() {
		case '=': // !=
			s.pos++
			lval.SetID(lexbase.NOT_EQUALS)
			return
		case '~': // !~
			s.pos++
			switch s.peek() {
			case '*': // !~*
				s.pos++
				lval.SetID(lexbase.NOT_REGIMATCH)
				return
			}
			lval.SetID(lexbase.NOT_REGMATCH)
			return
		}
		return

	case '?':
		switch s.peek() {
		case '?': // ??
			s.pos++
			lval.SetID(lexbase.HELPTOKEN)
			return
		case '|': // ?|
			s.pos++
			lval.SetID(lexbase.JSON_SOME_EXISTS)
			return
		case '&': // ?&
			s.pos++
			lval.SetID(lexbase.JSON_ALL_EXISTS)
			return
		}
		return

	case '<':
		switch s.peek() {
		case '<': // <<
			s.pos++
			switch s.peek() {
			case '=': // <<=
				s.pos++
				lval.SetID(lexbase.INET_CONTAINED_BY_OR_EQUALS)
				return
			}
			lval.SetID(lexbase.LSHIFT)
			return
		case '>': // <>
			s.pos++
			lval.SetID(lexbase.NOT_EQUALS)
			return
		case '=': // <=
			s.pos++
			lval.SetID(lexbase.LESS_EQUALS)
			return
		case '@': // <@
			s.pos++
			lval.SetID(lexbase.CONTAINED_BY)
			return
		}
		return

	case '>':
		switch s.peek() {
		case '>': // >>
			s.pos++
			switch s.peek() {
			case '=': // >>=
				s.pos++
				lval.SetID(lexbase.INET_CONTAINS_OR_EQUALS)
				return
			}
			lval.SetID(lexbase.RSHIFT)
			return
		case '=': // >=
			s.pos++
			lval.SetID(lexbase.GREATER_EQUALS)
			return
		}
		return

	case ':':
		switch s.peek() {
		case ':': // ::
			if s.peekN(1) == ':' {
				// :::
				s.pos += 2
				lval.SetID(lexbase.TYPEANNOTATE)
				return
			}
			s.pos++
			lval.SetID(lexbase.TYPECAST)
			return
		}
		return

	case '|':
		switch s.peek() {
		case '|': // ||
			s.pos++
			switch s.peek() {
			case '/': // ||/
				s.pos++
				lval.SetID(lexbase.CBRT)
				return
			}
			lval.SetID(lexbase.CONCAT)
			return
		case '/': // |/
			s.pos++
			lval.SetID(lexbase.SQRT)
			return
		}
		return

	case '/':
		switch s.peek() {
		case '/': // //
			s.pos++
			lval.SetID(lexbase.FLOORDIV)
			return
		}
		return

	case '~':
		switch s.peek() {
		case '*': // ~*
			s.pos++
			lval.SetID(lexbase.REGIMATCH)
			return
		}
		return

	case '@':
		switch s.peek() {
		case '>': // @>
			s.pos++
			lval.SetID(lexbase.CONTAINS)
			return

		case '@': // @@
			s.pos++
			lval.SetID(lexbase.AT_AT)
			return
		}
		return

	case '&':
		switch s.peek() {
		case '&': // &&
			s.pos++
			lval.SetID(lexbase.AND_AND)
			return
		}
		return

	case '-':
		switch s.peek() {
		case '>': // ->
			if s.peekN(1) == '>' {
				// ->>
				s.pos += 2
				lval.SetID(lexbase.FETCHTEXT)
				return
			}
			s.pos++
			lval.SetID(lexbase.FETCHVAL)
			return
		}
		return

	case '#':
		switch s.peek() {
		case '>': // #>
			if s.peekN(1) == '>' {
				// #>>
				s.pos += 2
				lval.SetID(lexbase.FETCHTEXT_PATH)
				return
			}
			s.pos++
			lval.SetID(lexbase.FETCHVAL_PATH)
			return
		case '-': // #-
			s.pos++
			lval.SetID(lexbase.REMOVE_PATH)
			return
		}
		return

	default:
		if lexbase.IsDigit(ch) {
			s.lastAttemptedID = int32(lexbase.ICONST)
			s.scanNumber(lval, ch)
			return
		}
		if lexbase.IsIdentStart(ch) {
			s.scanIdent(lval)
			return
		}
	}

	// Everything else is a single character token which we already initialized
	// lval for above.
}

// peek returns the next character without consuming it, or eof.
func (s *Scanner) peek() int {
	if s.pos >= len(s.in) {
		return eof
	}
	return int(s.in[s.pos])
}

// peekN returns the character n positions past the current one without
// consuming anything, or eof.
func (s *Scanner) peekN(n int) int {
	pos := s.pos + n
	if pos >= len(s.in) {
		return eof
	}
	return int(s.in[pos])
}

// next consumes and returns the next character, or eof.
func (s *Scanner) next() int {
	ch := s.peek()
	if ch != eof {
		s.pos++
	}
	return ch
}

// skipWhitespace advances past whitespace and (when allowComments is set)
// comments, recording any comments found into s.Comments. newline reports
// whether at least one newline was skipped; ok is false when an unterminated
// comment was encountered (lval is then populated with the error).
func (s *Scanner) skipWhitespace(lval ScanSymType, allowComments bool) (newline, ok bool) {
	newline = false
	for {
		startPos := s.pos
		ch := s.peek()
		if ch == '\n' {
			s.pos++
			newline = true
			continue
		}
		if ch == ' ' || ch == '\t' || ch == '\r' || ch == '\f' {
			s.pos++
			continue
		}
		if allowComments {
			if present, cok := s.ScanComment(lval); !cok {
				return false, false
			} else if present {
				// Mark down the comments that we found.
				s.Comments = append(s.Comments, s.in[startPos:s.pos])
				continue
			}
		}
		break
	}
	return newline, true
}

// ScanComment scans the input as a comment. It handles nested /* */ block
// comments and -- line comments. present reports whether a comment was
// consumed; ok is false on an unterminated block comment (lval is then
// populated with the error).
func (s *Scanner) ScanComment(lval ScanSymType) (present, ok bool) {
	start := s.pos
	ch := s.peek()

	if ch == '/' {
		s.pos++
		if s.peek() != '*' {
			// Just a lone '/': back up and report no comment.
			s.pos--
			return false, true
		}
		s.pos++
		depth := 1
		for {
			switch s.next() {
			case '*':
				if s.peek() == '/' {
					s.pos++
					depth--
					if depth == 0 {
						return true, true
					}
					continue
				}

			case '/':
				if s.peek() == '*' {
					s.pos++
					depth++
					continue
				}

			case eof:
				lval.SetID(lexbase.ERROR)
				lval.SetPos(int32(start))
				lval.SetStr("unterminated comment")
				return false, false
			}
		}
	}

	if ch == '-' {
		s.pos++
		if s.peek() != '-' {
			// Just a lone '-': back up and report no comment.
			s.pos--
			return false, true
		}
		// Line comment: consume up to (and including) the newline or eof.
		for {
			switch s.next() {
			case eof, '\n':
				return true, true
			}
		}
	}

	return false, true
}

// lowerCaseAndNormalizeIdent consumes an identifier starting at the character
// just read (s.pos-1) and stores its lowercased, normalized form into lval.
func (s *Scanner) lowerCaseAndNormalizeIdent(lval ScanSymType) {
	s.lastAttemptedID = int32(lexbase.IDENT)
	s.pos--
	start := s.pos
	isASCII := true
	isLower := true

	// Consume the Scanner character by character, stopping after the last legal
	// identifier character. By the end of this function, we need to
	// lowercase and unicode normalize this identifier, which is expensive if
	// there are actual unicode characters in it. If not, it's quite cheap - and
	// if it's lowercase already, there's no work to do. Therefore, we keep track
	// of whether the string is only ASCII or only ASCII lowercase for later.
	for {
		ch := s.peek()
		if ch >= utf8.RuneSelf {
			isASCII = false
		} else if ch >= 'A' && ch <= 'Z' {
			isLower = false
		}

		if !lexbase.IsIdentMiddle(ch) {
			break
		}

		s.pos++
	}

	if isLower && isASCII {
		// Already lowercased - nothing to do.
		lval.SetStr(s.in[start:s.pos])
	} else if isASCII {
		// We know that the identifier we've seen so far is ASCII, so we don't need
		// to unicode normalize. Instead, just lowercase as normal.
		b := s.allocBytes(s.pos - start)
		_ = b[s.pos-start-1] // For bounds check elimination.
		for i, c := range s.in[start:s.pos] {
			if c >= 'A' && c <= 'Z' {
				c += 'a' - 'A'
			}
			b[i] = byte(c)
		}
		lval.SetStr(*(*string)(unsafe.Pointer(&b)))
	} else {
		// The string has unicode in it. No choice but to run Normalize.
		lval.SetStr(lexbase.NormalizeName(s.in[start:s.pos]))
	}
}

// scanIdent scans an identifier or keyword, handling the experimental_ and
// testing_ prefixes that gate access to allowlisted experimental keywords.
func (s *Scanner) scanIdent(lval ScanSymType) {
	s.lowerCaseAndNormalizeIdent(lval)

	isExperimental := false
	kw := lval.Str()
	switch {
	case strings.HasPrefix(lval.Str(), "experimental_"):
		kw = lval.Str()[13:]
		isExperimental = true
	case strings.HasPrefix(lval.Str(), "testing_"):
		kw = lval.Str()[8:]
		isExperimental = true
	}
	lval.SetID(lexbase.GetKeywordID(kw))
	if lval.ID() != lexbase.IDENT {
		if isExperimental {
			if _, ok := lexbase.AllowedExperimental[kw]; !ok {
				// If the parsed token is not on the allowlisted set of keywords,
				// then it might have been intended to be parsed as something else.
				// In that case, re-tokenize the original string.
				lval.SetID(lexbase.GetKeywordID(lval.Str()))
			} else {
				// It is an allowlisted keyword, so remember the shortened
				// keyword for further processing.
				lval.SetStr(kw)
			}
		}
	} else {
		// If the word after experimental_ or testing_ is an identifier,
		// then we might have classified it incorrectly after removing the
		// experimental_/testing_ prefix.
		lval.SetID(lexbase.GetKeywordID(lval.Str()))
	}
}

// scanNumber scans integer, decimal, float (with exponent), and 0x-prefixed
// hexadecimal literals. ch is the already-consumed first character ('.' or a
// digit). On success it sets lval to ICONST or FCONST with a NewNumValFn
// union value; on failure it sets lexbase.ERROR with a message.
func (s *Scanner) scanNumber(lval ScanSymType, ch int) {
	start := s.pos - 1
	isHex := false
	hasDecimal := ch == '.'
	hasExponent := false

	for {
		ch := s.peek()
		if (isHex && lexbase.IsHexDigit(ch)) || lexbase.IsDigit(ch) {
			s.pos++
			continue
		}
		if ch == 'x' || ch == 'X' {
			// An 'x' is only legal as the second character of "0x...".
			if isHex || s.in[start] != '0' || s.pos != start+1 {
				lval.SetID(lexbase.ERROR)
				lval.SetStr(errInvalidHexNumeric)
				return
			}
			s.pos++
			isHex = true
			continue
		}
		if isHex {
			break
		}
		if ch == '.' {
			if hasDecimal || hasExponent {
				break
			}
			s.pos++
			if s.peek() == '.' {
				// Found ".." while scanning a number: back up to the end of the
				// integer.
				s.pos--
				break
			}
			hasDecimal = true
			continue
		}
		if ch == 'e' || ch == 'E' {
			if hasExponent {
				break
			}
			hasExponent = true
			s.pos++
			ch = s.peek()
			if ch == '-' || ch == '+' {
				s.pos++
			}
			ch = s.peek()
			if !lexbase.IsDigit(ch) {
				lval.SetID(lexbase.ERROR)
				lval.SetStr("invalid floating point literal")
				return
			}
			continue
		}
		break
	}

	// Disallow identifier after numerical constants e.g. "124foo".
	if lexbase.IsIdentStart(s.peek()) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(fmt.Sprintf("trailing junk after numeric literal at or near %q", s.in[start:s.pos+1]))
		return
	}

	lval.SetStr(s.in[start:s.pos])
	if hasDecimal || hasExponent {
		lval.SetID(lexbase.FCONST)
		floatConst := constant.MakeFromLiteral(lval.Str(), token.FLOAT, 0)
		if floatConst.Kind() == constant.Unknown {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf("could not make constant float from literal %q", lval.Str()))
			return
		}
		lval.SetUnionVal(NewNumValFn(floatConst, lval.Str(), false /* negative */))
	} else {
		// "0x" with no digits after it is invalid.
		if isHex && s.pos == start+2 {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errInvalidHexNumeric)
			return
		}

		// Strip off leading zeros from non-hex (decimal) literals so that
		// constant.MakeFromLiteral doesn't inappropriately interpret the
		// string as an octal literal. Note: we can't use strings.TrimLeft
		// here, because it will truncate '0' to ''.
		if !isHex {
			for len(lval.Str()) > 1 && lval.Str()[0] == '0' {
				lval.SetStr(lval.Str()[1:])
			}
		}

		lval.SetID(lexbase.ICONST)
		intConst := constant.MakeFromLiteral(lval.Str(), token.INT, 0)
		if intConst.Kind() == constant.Unknown {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf("could not make constant int from literal %q", lval.Str()))
			return
		}
		lval.SetUnionVal(NewNumValFn(intConst, lval.Str(), false /* negative */))
	}
}

// scanPlaceholder scans the digits of a $N placeholder (the '$' has already
// been consumed) and builds the placeholder via NewPlaceholderFn.
func (s *Scanner) scanPlaceholder(lval ScanSymType) {
	s.lastAttemptedID = int32(lexbase.PLACEHOLDER)
	start := s.pos
	for lexbase.IsDigit(s.peek()) {
		s.pos++
	}
	lval.SetStr(s.in[start:s.pos])

	placeholder, err := NewPlaceholderFn(lval.Str())
	if err != nil {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(err.Error())
		return
	}
	lval.SetID(lexbase.PLACEHOLDER)
	lval.SetUnionVal(placeholder)
}

// scanHexString scans the content inside x'....', decoding pairs of hex
// digits into bytes. An odd number of digits is an error.
func (s *Scanner) scanHexString(lval ScanSymType, ch int) bool {
	s.lastAttemptedID = int32(lexbase.BCONST)
	buf := s.buffer()

	var curbyte byte
	bytep := 0
	const errInvalidBytesLiteral = "invalid hexadecimal bytes literal"
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			curbyte = (curbyte << 4) | byte(b-'0')
		case 'a', 'b', 'c', 'd', 'e', 'f':
			curbyte = (curbyte << 4) | byte(b-'a'+10)
		case 'A', 'B', 'C', 'D', 'E', 'F':
			curbyte = (curbyte << 4) | byte(b-'A'+10)
		default:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errInvalidBytesLiteral)
			return false
		}
		bytep++

		// Emit one byte for every two hex digits consumed.
		if bytep > 1 {
			buf = append(buf, curbyte)
			bytep = 0
			curbyte = 0
		}
	}

	// A dangling half-byte means an odd number of hex digits.
	if bytep != 0 {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidBytesLiteral)
		return false
	}

	lval.SetID(lexbase.BCONST)
	lval.SetStr(s.finishString(buf))
	return true
}

// scanBitString scans the content inside B'....'. Only '0' and '1' digits
// are accepted.
func (s *Scanner) scanBitString(lval ScanSymType, ch int) bool {
	s.lastAttemptedID = int32(lexbase.BITCONST)
	buf := s.buffer()
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1':
			buf = append(buf, byte(b))
		default:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf(`"%c" is not a valid binary digit`, rune(b)))
			return false
		}
	}

	lval.SetID(lexbase.BITCONST)
	lval.SetStr(s.finishString(buf))
	return true
}

// scanString scans the content inside '...'. This is used for simple
// string literals '...' but also e'....' and b'...'. For x'...', see
// scanHexString().
func (s *Scanner) scanString(lval ScanSymType, ch int, allowEscapes, requireUTF8 bool) bool {
	buf := s.buffer()
	var runeTmp [utf8.UTFMax]byte
	start := s.pos
outer:
	for {
		switch s.next() {
		case ch:
			buf = append(buf, s.in[start:s.pos-1]...)
			if s.peek() == ch {
				// Double quote is translated into a single quote that is part of the
				// string.
				start = s.pos
				s.pos++
				continue
			}

			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}

			// SQL allows joining adjacent single-quoted strings separated by
			// whitespace as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that is the
			// standard.
			if ch == singleQuote && s.peek() == singleQuote && newline {
				s.pos++
				start = s.pos
				continue
			}
			break outer

		case '\\':
			t := s.peek()

			if allowEscapes {
				buf = append(buf, s.in[start:s.pos-1]...)
				if t == ch {
					// Backslash-escaped quote character.
					start = s.pos
					s.pos++
					continue
				}

				switch t {
				case 'a', 'b', 'f', 'n', 'r', 't', 'v', 'x', 'X', 'u', 'U', '\\',
					'0', '1', '2', '3', '4', '5', '6', '7':
					var tmp string
					if t == 'X' && len(s.in[s.pos:]) >= 3 {
						// UnquoteChar doesn't handle 'X' so we create a temporary string
						// for it to parse.
						tmp = "\\x" + s.in[s.pos+1:s.pos+3]
					} else {
						tmp = s.in[s.pos-1:]
					}
					v, multibyte, tail, err := strconv.UnquoteChar(tmp, byte(ch))
					if err != nil {
						lval.SetID(lexbase.ERROR)
						lval.SetStr(err.Error())
						return false
					}
					if v < utf8.RuneSelf || !multibyte {
						buf = append(buf, byte(v))
					} else {
						n := utf8.EncodeRune(runeTmp[:], v)
						buf = append(buf, runeTmp[:n]...)
					}
					// Advance past however many input bytes UnquoteChar consumed.
					s.pos += len(tmp) - len(tail) - 1
					start = s.pos
					continue
				}

				// If we end up here, it's a redundant escape - simply drop the
				// backslash. For example, e'\"' is equivalent to e'"', and
				// e'\d\b' to e'd\b'. This is what Postgres does:
				// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
				start = s.pos
			}

		case eof:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errUnterminated)
			return false
		}
	}

	if requireUTF8 && !utf8.Valid(buf) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidUTF8)
		return false
	}

	if ch == identQuote {
		lval.SetStr(lexbase.NormalizeString(s.finishString(buf)))
	} else {
		lval.SetStr(s.finishString(buf))
	}
	return true
}

// scanDollarQuotedString scans for so called dollar-quoted strings, which start/end with either $$ or $tag$, where
// tag is some arbitrary string. e.g. $$a string$$ or $escaped$a string$escaped$.
func (s *Scanner) scanDollarQuotedString(lval ScanSymType) bool {
	s.lastAttemptedID = int32(lexbase.SCONST)
	buf := s.buffer()
	start := s.pos

	foundStartTag := false
	possibleEndTag := false
	startTagIndex := -1
	var startTag string

outer:
	for {
		ch := s.peek()
		switch ch {
		case '$':
			s.pos++
			if foundStartTag {
				if possibleEndTag {
					if len(startTag) == startTagIndex {
						// Found end tag.
						buf = append(buf, s.in[start+len(startTag)+1:s.pos-len(startTag)-2]...)
						break outer
					} else {
						// Was not the end tag but the current $ might be the start of the end tag we are looking for, so
						// just reset the startTagIndex.
						startTagIndex = 0
					}
				} else {
					possibleEndTag = true
					startTagIndex = 0
				}
			} else {
				startTag = s.in[start : s.pos-1]
				foundStartTag = true
			}

		case eof:
			if foundStartTag {
				// A start tag was found, therefore we expect an end tag before the eof, otherwise it is an error.
				lval.SetID(lexbase.ERROR)
				lval.SetStr(errUnterminated)
			} else {
				// This is not a dollar-quoted string, reset the pos back to the start.
				s.pos = start
			}
			return false

		default:
			// If we haven't found a start tag yet, check whether the current character is valid for a tag.
			if !foundStartTag && !lexbase.IsIdentStart(ch) && !lexbase.IsDigit(ch) {
				return false
			}
			s.pos++
			if possibleEndTag {
				// Check whether this could be the end tag.
				if startTagIndex >= len(startTag) || ch != int(startTag[startTagIndex]) {
					// This is not the end tag we are looking for.
					possibleEndTag = false
					startTagIndex = -1
				} else {
					startTagIndex++
				}
			}
		}
	}

	if !utf8.Valid(buf) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidUTF8)
		return false
	}

	lval.SetStr(s.finishString(buf))
	return true
}

// HasMultipleStatements returns true if the sql string contains more than one
// statement. An error is returned if an invalid token was encountered.
func HasMultipleStatements(sql string) (multipleStmt bool, err error) {
	var s SQLScanner
	var lval fakeSym
	s.Init(sql)
	count := 0
	for {
		done, hasToks, err := s.scanOne(&lval)
		if err != nil {
			return false, err
		}
		if hasToks {
			count++
		}
		if done || count > 1 {
			break
		}
	}
	return count > 1, nil
}

// scanOne is a simplified version of (*Parser).scanOneStmt() for use
// by HasMultipleStatements().
func (s *SQLScanner) scanOne(lval *fakeSym) (done, hasToks bool, err error) {
	// Scan the first token.
	for {
		s.Scan(lval)
		if lval.id == 0 {
			return true, false, nil
		}
		if lval.id != ';' {
			break
		}
	}

	var preValID int32
	// This is used to track the degree of nested `BEGIN ATOMIC ... END` function
	// body context. When greater than zero, it means that we're scanning through
	// the function body of a `CREATE FUNCTION` statement. ';' character is only
	// a separator of sql statements within the body instead of a finishing line
	// of the `CREATE FUNCTION` statement.
	curFuncBodyCnt := 0
	for {
		if lval.id == lexbase.ERROR {
			return true, true, fmt.Errorf("scan error: %s", lval.s)
		}
		preValID = lval.id
		s.Scan(lval)
		if preValID == lexbase.BEGIN && lval.id == lexbase.ATOMIC {
			curFuncBodyCnt++
		}
		if curFuncBodyCnt > 0 && lval.id == lexbase.END {
			curFuncBodyCnt--
		}
		if lval.id == 0 || (curFuncBodyCnt == 0 && lval.id == ';') {
			return (lval.id == 0), true, nil
		}
	}
}

// LastLexicalToken returns the last lexical token. If the string has no lexical
// tokens, returns 0 and ok=false.
func LastLexicalToken(sql string) (lastTok int, ok bool) {
	var s SQLScanner
	var lval fakeSym
	s.Init(sql)
	for {
		last := lval.ID()
		s.Scan(&lval)
		if lval.ID() == 0 {
			return int(last), last != 0
		}
	}
}

// FirstLexicalToken returns the first lexical token.
// Returns 0 if there is no token.
func FirstLexicalToken(sql string) (tok int) {
	var s SQLScanner
	var lval fakeSym
	s.Init(sql)
	s.Scan(&lval)
	id := lval.ID()
	return int(id)
}

// fakeSym is a simplified symbol type for use by
// HasMultipleStatements.
type fakeSym struct {
	id  int32
	pos int32
	s   string
}

var _ ScanSymType = (*fakeSym)(nil)

func (s fakeSym) ID() int32                  { return s.id }
func (s *fakeSym) SetID(id int32)            { s.id = id }
func (s fakeSym) Pos() int32                 { return s.pos }
func (s *fakeSym) SetPos(p int32)            { s.pos = p }
func (s fakeSym) Str() string                { return s.s }
func (s *fakeSym) SetStr(v string)           { s.s = v }
func (s fakeSym) UnionVal() interface{}      { return nil }
func (s fakeSym) SetUnionVal(v interface{})  {}

// InspectToken is the type of token that can be scanned by Inspect.
type InspectToken struct {
	ID      int32
	MaybeID int32
	Start   int32
	End     int32
	Str     string
	Quoted  bool
}

// Inspect analyses the string and returns the tokens found in it. If
// an incomplete token was encountered at the end, an InspectToken
// entry with ID -1 is appended.
//
// If a syntax error was encountered, it is returned as a token with
// type ERROR.
//
// See TestInspect and the examples in testdata/inspect for more details.
func Inspect(sql string) []InspectToken {
	var s SQLScanner
	var lval fakeSym
	var tokens []InspectToken
	s.Init(sql)
	for {
		s.Scan(&lval)
		tok := InspectToken{
			ID:      lval.id,
			MaybeID: s.lastAttemptedID,
			Str:     lval.s,
			Start:   lval.pos,
			End:     int32(s.pos),
			Quoted:  s.quoted,
		}

		// A special affordance for unterminated quoted identifiers: try
		// to find the normalized text of the identifier found so far.
		if lval.id == lexbase.ERROR && s.lastAttemptedID == lexbase.IDENT && s.quoted {
			// Re-scan the fragment with a closing quote appended so the
			// normalized identifier text can be recovered.
			maybeIdent := sql[tok.Start:tok.End] + "\""
			var si SQLScanner
			si.Init(maybeIdent)
			si.Scan(&lval)
			if lval.id == lexbase.IDENT {
				tok.Str = lval.s
			}
			// NOTE(review): the sub-scan above overwrites lval, so the
			// termination check below sees the sub-scan's id (e.g. IDENT)
			// rather than the outer scan's ERROR, and the loop takes one
			// extra iteration before stopping at EOF. Looks intentional as
			// a way to still emit the trailing token — confirm against
			// TestInspect before changing.
		}

		tokens = append(tokens, tok)
		if lval.id == 0 || lval.id == lexbase.ERROR {
			return tokens
		}
	}
}