// Copyright 2015 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package parser

import (
	"fmt"
	"go/constant"
	"go/token"
	"strconv"
	"strings"
	"unicode/utf8"
	"unsafe"

	"github.com/cockroachdb/cockroach/pkg/sql/lex"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
)

// eof is the sentinel returned by peek/next once the input is exhausted.
const eof = -1

// Error messages produced while lexing malformed literals.
const errUnterminated = "unterminated string"
const errInvalidUTF8 = "invalid UTF-8 byte sequence"
const errInvalidHexNumeric = "invalid hexadecimal numeric literal"

// Quote characters for string literals and quoted identifiers.
const singleQuote = '\''
const identQuote = '"'

// scanner lexes SQL statements.
type scanner struct {
	// in is the full input string being scanned.
	in string
	// pos is the byte offset of the next character to scan.
	pos int
	// bytesPrealloc is scratch space carved up by allocBytes/buffer so that
	// lexing a statement performs few small allocations.
	bytesPrealloc []byte
}

// makeScanner returns a scanner initialized to lex str.
func makeScanner(str string) scanner {
	var s scanner
	s.init(str)
	return s
}

// init (re)initializes the scanner to lex str from the beginning.
func (s *scanner) init(str string) {
	s.in = str
	s.pos = 0
	// Preallocate some buffer space for identifiers etc.
	s.bytesPrealloc = make([]byte, len(str))
}

// cleanup is used to avoid holding on to memory unnecessarily (for the cases
// where we reuse a scanner).
func (s *scanner) cleanup() {
	s.bytesPrealloc = nil
}

// allocBytes returns a []byte of the given length, carving it out of the
// preallocated scratch space when enough remains and falling back to a
// fresh allocation otherwise.
func (s *scanner) allocBytes(length int) []byte {
	if len(s.bytesPrealloc) >= length {
		// Three-index slice so appends to res cannot clobber the remainder.
		res := s.bytesPrealloc[:length:length]
		s.bytesPrealloc = s.bytesPrealloc[length:]
		return res
	}
	return make([]byte, length)
}

// buffer returns an empty []byte buffer that can be appended to. Any unused
// portion can be returned later using returnBuffer.
func (s *scanner) buffer() []byte {
	// Hand the entire remaining scratch space to the caller as an empty,
	// appendable buffer; returnBuffer gives the unused tail back.
	buf := s.bytesPrealloc[:0]
	s.bytesPrealloc = nil
	return buf
}

// returnBuffer returns the unused portion of buf to the scanner, to be used for
// future allocBytes() or buffer() calls. The caller must not use buf again.
func (s *scanner) returnBuffer(buf []byte) {
	if len(buf) < cap(buf) {
		s.bytesPrealloc = buf[len(buf):]
	}
}

// finishString casts the given buffer to a string and returns the unused
// portion of the buffer. The caller must not use buf again.
func (s *scanner) finishString(buf []byte) string {
	// Zero-copy []byte -> string conversion; safe because buf is never
	// written to again per the contract above.
	str := *(*string)(unsafe.Pointer(&buf))
	s.returnBuffer(buf)
	return str
}

// scan lexes the next token from the input, filling in lval with its id,
// position, and text (and, for literals/placeholders, a value). An id of 0
// signals EOF; ERROR signals a lexing error with the message in lval.str.
func (s *scanner) scan(lval *sqlSymType) {
	lval.id = 0
	lval.pos = int32(s.pos)
	lval.str = "EOF"

	if _, ok := s.skipWhitespace(lval, true); !ok {
		return
	}

	ch := s.next()
	if ch == eof {
		lval.pos = int32(s.pos)
		return
	}

	// Default: a single-character token whose id is its own character code.
	lval.id = int32(ch)
	lval.pos = int32(s.pos - 1)
	lval.str = s.in[lval.pos:s.pos]

	switch ch {
	case '$':
		// placeholder? $[0-9]+
		if lex.IsDigit(s.peek()) {
			s.scanPlaceholder(lval)
			return
		} else if s.scanDollarQuotedString(lval) {
			lval.id = SCONST
			return
		}
		return

	case identQuote:
		// "[^"]"
		if s.scanString(lval, identQuote, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.id = IDENT
		}
		return

	case singleQuote:
		// '[^']'
		if s.scanString(lval, ch, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.id = SCONST
		}
		return

	case 'b':
		// Bytes?
		if s.peek() == singleQuote {
			// b'[^']'
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, false /* requireUTF8 */) {
				lval.id = BCONST
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'r', 'R':
		s.scanIdent(lval)
		return

	case 'e', 'E':
		// Escaped string?
		if s.peek() == singleQuote {
			// [eE]'[^']'
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, true /* requireUTF8 */) {
				lval.id = SCONST
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'B':
		// Bit array literal?
		if s.peek() == singleQuote {
			// B'[01]*'
			s.pos++
			s.scanBitString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case 'x', 'X':
		// Hex literal?
		if s.peek() == singleQuote {
			// [xX]'[a-f0-9]'
			s.pos++
			s.scanHexString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case '.':
		switch t := s.peek(); {
		case t == '.': // ..
			s.pos++
			lval.id = DOT_DOT
			return
		case lex.IsDigit(t):
			s.scanNumber(lval, ch)
			return
		}
		return

	case '!':
		switch s.peek() {
		case '=': // !=
			s.pos++
			lval.id = NOT_EQUALS
			return
		case '~': // !~
			s.pos++
			switch s.peek() {
			case '*': // !~*
				s.pos++
				lval.id = NOT_REGIMATCH
				return
			}
			lval.id = NOT_REGMATCH
			return
		}
		return

	case '?':
		switch s.peek() {
		case '?': // ??
			s.pos++
			lval.id = HELPTOKEN
			return
		case '|': // ?|
			s.pos++
			lval.id = JSON_SOME_EXISTS
			return
		case '&': // ?&
			s.pos++
			lval.id = JSON_ALL_EXISTS
			return
		}
		return

	case '<':
		switch s.peek() {
		case '<': // <<
			s.pos++
			switch s.peek() {
			case '=': // <<=
				s.pos++
				lval.id = INET_CONTAINED_BY_OR_EQUALS
				return
			}
			lval.id = LSHIFT
			return
		case '>': // <>
			s.pos++
			lval.id = NOT_EQUALS
			return
		case '=': // <=
			s.pos++
			lval.id = LESS_EQUALS
			return
		case '@': // <@
			s.pos++
			lval.id = CONTAINED_BY
			return
		}
		return

	case '>':
		switch s.peek() {
		case '>': // >>
			s.pos++
			switch s.peek() {
			case '=': // >>=
				s.pos++
				lval.id = INET_CONTAINS_OR_EQUALS
				return
			}
			lval.id = RSHIFT
			return
		case '=': // >=
			s.pos++
			lval.id = GREATER_EQUALS
			return
		}
		return

	case ':':
		switch s.peek() {
		case ':': // ::
			if s.peekN(1) == ':' {
				// :::
				s.pos += 2
				lval.id = TYPEANNOTATE
				return
			}
			s.pos++
			lval.id = TYPECAST
			return
		}
		return

	case '|':
		switch s.peek() {
		case '|': // ||
			s.pos++
			switch s.peek() {
			case '/': // ||/
				s.pos++
				lval.id = CBRT
				return
			}
			lval.id = CONCAT
			return
		case '/': // |/
			s.pos++
			lval.id = SQRT
			return
		}
		return

	case '/':
		switch s.peek() {
		case '/': // //
			s.pos++
			lval.id = FLOORDIV
			return
		}
		return

	case '~':
		switch s.peek() {
		case '*': // ~*
			s.pos++
			lval.id = REGIMATCH
			return
		}
		return

	case '@':
		switch s.peek() {
		case '>': // @>
			s.pos++
			lval.id = CONTAINS
			return
		}
		return

	case '&':
		switch s.peek() {
		case '&': // &&
			s.pos++
			lval.id = AND_AND
			return
		}
		return

	case '-':
		switch s.peek() {
		case '>': // ->
			if s.peekN(1) == '>' {
				// ->>
				s.pos += 2
				lval.id = FETCHTEXT
				return
			}
			s.pos++
			lval.id = FETCHVAL
			return
		}
		return

	case '#':
		switch s.peek() {
		case '>': // #>
			if s.peekN(1) == '>' {
				// #>>
				s.pos += 2
				lval.id = FETCHTEXT_PATH
				return
			}
			s.pos++
			lval.id = FETCHVAL_PATH
			return
		case '-': // #-
			s.pos++
			lval.id = REMOVE_PATH
			return
		}
		return

	default:
		if lex.IsDigit(ch) {
			s.scanNumber(lval, ch)
			return
		}
		if lex.IsIdentStart(ch) {
			s.scanIdent(lval)
			return
		}
	}

	// Everything else is a single character token which we already initialized
	// lval for above.
}

// peek returns the next byte without consuming it, or eof at end of input.
func (s *scanner) peek() int {
	if s.pos >= len(s.in) {
		return eof
	}
	return int(s.in[s.pos])
}

// peekN returns the byte n positions past the current one without consuming
// anything, or eof if that position is past the end of input.
func (s *scanner) peekN(n int) int {
	pos := s.pos + n
	if pos >= len(s.in) {
		return eof
	}
	return int(s.in[pos])
}

// next consumes and returns the next byte, or eof at end of input.
func (s *scanner) next() int {
	ch := s.peek()
	if ch != eof {
		s.pos++
	}
	return ch
}

// skipWhitespace skips spaces, tabs, carriage returns, form feeds and
// newlines, and (when allowComments is set) SQL comments. It reports whether
// at least one newline was skipped; ok=false means a malformed comment was
// found, with the error recorded in lval.
func (s *scanner) skipWhitespace(lval *sqlSymType, allowComments bool) (newline, ok bool) {
	newline = false
	for {
		ch := s.peek()
		if ch == '\n' {
			s.pos++
			newline = true
			continue
		}
		if ch == ' ' || ch == '\t' || ch == '\r' || ch == '\f' {
			s.pos++
			continue
		}
		if allowComments {
			if present, cok := s.scanComment(lval); !cok {
				return false, false
			} else if present {
				continue
			}
		}
		break
	}
	return newline, true
}

// scanComment consumes one SQL comment if one starts at the current position:
// either a nestable /* ... */ block comment or a -- line comment. present
// reports whether a comment was consumed; ok=false means an unterminated
// block comment, with the error recorded in lval.
func (s *scanner) scanComment(lval *sqlSymType) (present, ok bool) {
	start := s.pos
	ch := s.peek()

	if ch == '/' {
		s.pos++
		if s.peek() != '*' {
			// Just a bare '/': not a comment; undo the consume.
			s.pos--
			return false, true
		}
		s.pos++
		// Block comments nest, so track depth.
		depth := 1
		for {
			switch s.next() {
			case '*':
				if s.peek() == '/' {
					s.pos++
					depth--
					if depth == 0 {
						return true, true
					}
					continue
				}

			case '/':
				if s.peek() == '*' {
					s.pos++
					depth++
					continue
				}

			case eof:
				lval.id = ERROR
				lval.pos = int32(start)
				lval.str = "unterminated comment"
				return false, false
			}
		}
	}

	if ch == '-' {
		s.pos++
		if s.peek() != '-' {
			// Just a bare '-': not a comment; undo the consume.
			s.pos--
			return false, true
		}
		// Line comment runs to end of line (or end of input).
		for {
			switch s.next() {
			case eof, '\n':
				return true, true
			}
		}
	}

	return false, true
}

// scanIdent lexes an identifier or keyword. The first character has already
// been consumed by the caller, so it backs up one position first. The result
// is lowercased/normalized into lval.str and classified via the keyword
// table, with special handling for experimental_/testing_ prefixes.
func (s *scanner) scanIdent(lval *sqlSymType) {
	s.pos--
	start := s.pos
	isASCII := true
	isLower := true

	// Consume the scanner character by character, stopping after the last legal
	// identifier character. By the end of this function, we need to
	// lowercase and unicode normalize this identifier, which is expensive if
	// there are actual unicode characters in it. If not, it's quite cheap - and
	// if it's lowercase already, there's no work to do. Therefore, we keep track
	// of whether the string is only ASCII or only ASCII lowercase for later.
	for {
		ch := s.peek()

		if ch >= utf8.RuneSelf {
			isASCII = false
		} else if ch >= 'A' && ch <= 'Z' {
			isLower = false
		}

		if !lex.IsIdentMiddle(ch) {
			break
		}

		s.pos++
	}

	if isLower {
		// Already lowercased - nothing to do.
		lval.str = s.in[start:s.pos]
	} else if isASCII {
		// We know that the identifier we've seen so far is ASCII, so we don't need
		// to unicode normalize. Instead, just lowercase as normal.
		b := s.allocBytes(s.pos - start)
		_ = b[s.pos-start-1] // For bounds check elimination.
		for i, c := range s.in[start:s.pos] {
			if c >= 'A' && c <= 'Z' {
				c += 'a' - 'A'
			}
			b[i] = byte(c)
		}
		lval.str = *(*string)(unsafe.Pointer(&b))
	} else {
		// The string has unicode in it. No choice but to run Normalize.
		lval.str = lex.NormalizeName(s.in[start:s.pos])
	}

	isExperimental := false
	kw := lval.str
	switch {
	case strings.HasPrefix(lval.str, "experimental_"):
		kw = lval.str[13:]
		isExperimental = true
	case strings.HasPrefix(lval.str, "testing_"):
		kw = lval.str[8:]
		isExperimental = true
	}
	lval.id = lex.GetKeywordID(kw)
	if lval.id != lex.IDENT {
		if isExperimental {
			if _, ok := lex.AllowedExperimental[kw]; !ok {
				// If the parsed token is not on the whitelisted set of keywords,
				// then it might have been intended to be parsed as something else.
				// In that case, re-tokenize the original string.
				lval.id = lex.GetKeywordID(lval.str)
			} else {
				// It is a whitelisted keyword, so remember the shortened
				// keyword for further processing.
				lval.str = kw
			}
		}
	} else {
		// If the word after experimental_ or testing_ is an identifier,
		// then we might have classified it incorrectly after removing the
		// experimental_/testing_ prefix.
		lval.id = lex.GetKeywordID(lval.str)
	}
}

// scanNumber lexes a numeric literal: an integer, a decimal, a float with an
// exponent, or a 0x-prefixed hexadecimal integer. ch is the first character,
// already consumed. On success lval holds ICONST or FCONST with the parsed
// constant value; on failure lval holds ERROR.
func (s *scanner) scanNumber(lval *sqlSymType, ch int) {
	start := s.pos - 1
	isHex := false
	hasDecimal := ch == '.'
	hasExponent := false

	for {
		ch := s.peek()
		if (isHex && lex.IsHexDigit(ch)) || lex.IsDigit(ch) {
			s.pos++
			continue
		}
		if ch == 'x' || ch == 'X' {
			// "x" is only legal as the second character of a "0x" prefix.
			if isHex || s.in[start] != '0' || s.pos != start+1 {
				lval.id = ERROR
				lval.str = errInvalidHexNumeric
				return
			}
			s.pos++
			isHex = true
			continue
		}
		if isHex {
			break
		}
		if ch == '.' {
			if hasDecimal || hasExponent {
				break
			}
			s.pos++
			if s.peek() == '.' {
				// Found ".." while scanning a number: back up to the end of the
				// integer.
				s.pos--
				break
			}
			hasDecimal = true
			continue
		}
		if ch == 'e' || ch == 'E' {
			if hasExponent {
				break
			}
			hasExponent = true
			s.pos++
			ch = s.peek()
			if ch == '-' || ch == '+' {
				s.pos++
			}
			ch = s.peek()
			if !lex.IsDigit(ch) {
				lval.id = ERROR
				lval.str = "invalid floating point literal"
				return
			}
			continue
		}
		break
	}

	lval.str = s.in[start:s.pos]
	if hasDecimal || hasExponent {
		lval.id = FCONST
		floatConst := constant.MakeFromLiteral(lval.str, token.FLOAT, 0)
		if floatConst.Kind() == constant.Unknown {
			lval.id = ERROR
			lval.str = fmt.Sprintf("could not make constant float from literal %q", lval.str)
			return
		}
		lval.union.val = tree.NewNumVal(floatConst, lval.str, false /* negative */)
	} else {
		if isHex && s.pos == start+2 {
			// "0x" with no digits after it.
			lval.id = ERROR
			lval.str = errInvalidHexNumeric
			return
		}

		// Strip off leading zeros from non-hex (decimal) literals so that
		// constant.MakeFromLiteral doesn't inappropriately interpret the
		// string as an octal literal. Note: we can't use strings.TrimLeft
		// here, because it will truncate '0' to ''.
		if !isHex {
			for len(lval.str) > 1 && lval.str[0] == '0' {
				lval.str = lval.str[1:]
			}
		}

		lval.id = ICONST
		intConst := constant.MakeFromLiteral(lval.str, token.INT, 0)
		if intConst.Kind() == constant.Unknown {
			lval.id = ERROR
			lval.str = fmt.Sprintf("could not make constant int from literal %q", lval.str)
			return
		}
		lval.union.val = tree.NewNumVal(intConst, lval.str, false /* negative */)
	}
}

// scanPlaceholder lexes a numbered placeholder ($1, $2, ...). The '$' has
// already been consumed by the caller.
func (s *scanner) scanPlaceholder(lval *sqlSymType) {
	start := s.pos
	for lex.IsDigit(s.peek()) {
		s.pos++
	}
	lval.str = s.in[start:s.pos]

	placeholder, err := tree.NewPlaceholder(lval.str)
	if err != nil {
		lval.id = ERROR
		lval.str = err.Error()
		return
	}
	lval.id = PLACEHOLDER
	lval.union.val = placeholder
}

// scanHexString scans the content inside x'....'.
func (s *scanner) scanHexString(lval *sqlSymType, ch int) bool {
	buf := s.buffer()

	var curbyte byte
	bytep := 0
	const errInvalidBytesLiteral = "invalid hexadecimal bytes literal"
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			curbyte = (curbyte << 4) | byte(b-'0')
		case 'a', 'b', 'c', 'd', 'e', 'f':
			curbyte = (curbyte << 4) | byte(b-'a'+10)
		case 'A', 'B', 'C', 'D', 'E', 'F':
			curbyte = (curbyte << 4) | byte(b-'A'+10)
		default:
			lval.id = ERROR
			lval.str = errInvalidBytesLiteral
			return false
		}
		bytep++

		// Emit one output byte for every two hex digits consumed.
		if bytep > 1 {
			buf = append(buf, curbyte)
			bytep = 0
			curbyte = 0
		}
	}

	// An odd number of hex digits cannot form whole bytes.
	if bytep != 0 {
		lval.id = ERROR
		lval.str = errInvalidBytesLiteral
		return false
	}

	lval.id = BCONST
	lval.str = s.finishString(buf)
	return true
}

// scanBitString scans the content inside B'....'.
func (s *scanner) scanBitString(lval *sqlSymType, ch int) bool {
	buf := s.buffer()
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1':
			buf = append(buf, byte(b))
		default:
			lval.id = ERROR
			lval.str = fmt.Sprintf(`"%c" is not a valid binary digit`, rune(b))
			return false
		}
	}

	lval.id = BITCONST
	lval.str = s.finishString(buf)
	return true
}

// scanString scans the content inside '...'. This is used for simple
// string literals '...' but also e'....' and b'...'. For x'...', see
// scanHexString().
func (s *scanner) scanString(lval *sqlSymType, ch int, allowEscapes, requireUTF8 bool) bool {
	buf := s.buffer()
	var runeTmp [utf8.UTFMax]byte
	// start marks the beginning of the pending run of literal bytes that has
	// not yet been copied into buf.
	start := s.pos

outer:
	for {
		switch s.next() {
		case ch:
			buf = append(buf, s.in[start:s.pos-1]...)
			if s.peek() == ch {
				// Double quote is translated into a single quote that is part of the
				// string.
				start = s.pos
				s.pos++
				continue
			}

			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				start = s.pos
				continue
			}
			break outer

		case '\\':
			t := s.peek()

			if allowEscapes {
				buf = append(buf, s.in[start:s.pos-1]...)
				if t == ch {
					start = s.pos
					s.pos++
					continue
				}

				switch t {
				case 'a', 'b', 'f', 'n', 'r', 't', 'v', 'x', 'X', 'u', 'U', '\\',
					'0', '1', '2', '3', '4', '5', '6', '7':
					var tmp string
					if t == 'X' && len(s.in[s.pos:]) >= 3 {
						// UnquoteChar doesn't handle 'X' so we create a temporary string
						// for it to parse.
						tmp = "\\x" + s.in[s.pos+1:s.pos+3]
					} else {
						tmp = s.in[s.pos-1:]
					}
					v, multibyte, tail, err := strconv.UnquoteChar(tmp, byte(ch))
					if err != nil {
						lval.id = ERROR
						lval.str = err.Error()
						return false
					}
					if v < utf8.RuneSelf || !multibyte {
						buf = append(buf, byte(v))
					} else {
						n := utf8.EncodeRune(runeTmp[:], v)
						buf = append(buf, runeTmp[:n]...)
					}
					// Advance past however many bytes UnquoteChar consumed.
					s.pos += len(tmp) - len(tail) - 1
					start = s.pos
					continue
				}

				// If we end up here, it's a redundant escape - simply drop the
				// backslash. For example, e'\"' is equivalent to e'"', and
				// e'\d\b' to e'd\b'. This is what Postgres does:
				// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
				start = s.pos
			}

		case eof:
			lval.id = ERROR
			lval.str = errUnterminated
			return false
		}
	}

	if requireUTF8 && !utf8.Valid(buf) {
		lval.id = ERROR
		lval.str = errInvalidUTF8
		return false
	}

	lval.str = s.finishString(buf)
	return true
}

// scanDollarQuotedString scans for so called dollar-quoted strings, which start/end with either $$ or $tag$, where
// tag is some arbitrary string. e.g. $$a string$$ or $escaped$a string$escaped$.
func (s *scanner) scanDollarQuotedString(lval *sqlSymType) bool {
	buf := s.buffer()
	start := s.pos

	foundStartTag := false
	possibleEndTag := false
	startTagIndex := -1
	var startTag string

outer:
	for {
		ch := s.peek()
		switch ch {
		case '$':
			s.pos++
			if foundStartTag {
				if possibleEndTag {
					if len(startTag) == startTagIndex {
						// Found end tag.
						buf = append(buf, s.in[start+len(startTag)+1:s.pos-len(startTag)-2]...)
						break outer
					} else {
						// Was not the end tag but the current $ might be the start of the end tag we are looking for, so
						// just reset the startTagIndex.
						startTagIndex = 0
					}
				} else {
					possibleEndTag = true
					startTagIndex = 0
				}
			} else {
				startTag = s.in[start : s.pos-1]
				foundStartTag = true
			}

		case eof:
			if foundStartTag {
				// A start tag was found, therefore we expect an end tag before the eof, otherwise it is an error.
				lval.id = ERROR
				lval.str = errUnterminated
			} else {
				// This is not a dollar-quoted string, reset the pos back to the start.
				s.pos = start
			}
			return false

		default:
			// If we haven't found a start tag yet, check whether the current characters is a valid for a tag.
			if !foundStartTag && !lex.IsIdentStart(ch) {
				// NOTE(review): unlike the eof case above, this path returns
				// without resetting s.pos back to start, so any candidate tag
				// characters consumed so far stay consumed (e.g. for input
				// "$ab-"). Confirm whether that is intentional.
				return false
			}
			s.pos++
			if possibleEndTag {
				// Check whether this could be the end tag.
				if startTagIndex >= len(startTag) || ch != int(startTag[startTagIndex]) {
					// This is not the end tag we are looking for.
					possibleEndTag = false
					startTagIndex = -1
				} else {
					startTagIndex++
				}
			}
		}
	}

	if !utf8.Valid(buf) {
		lval.id = ERROR
		lval.str = errInvalidUTF8
		return false
	}

	lval.str = s.finishString(buf)
	return true
}

// SplitFirstStatement returns the length of the prefix of the string up to and
// including the first semicolon that separates statements. If there is no
// semicolon, returns ok=false.
func SplitFirstStatement(sql string) (pos int, ok bool) {
	s := makeScanner(sql)
	var lval sqlSymType
	for {
		s.scan(&lval)
		switch lval.id {
		case 0, ERROR:
			// Reached EOF (or a lexing error) before any semicolon.
			return 0, false
		case ';':
			return s.pos, true
		}
	}
}

// Tokens decomposes the input into lexical tokens.
func Tokens(sql string) (tokens []TokenString, ok bool) {
	s := makeScanner(sql)
	for {
		var lval sqlSymType
		s.scan(&lval)
		if lval.id == ERROR {
			return nil, false
		}
		if lval.id == 0 {
			// EOF.
			break
		}
		tokens = append(tokens, TokenString{TokenID: lval.id, Str: lval.str})
	}
	return tokens, true
}

// TokenString is the unit value returned by Tokens.
type TokenString struct {
	// TokenID is the lexical token identifier assigned by the scanner.
	TokenID int32
	// Str is the token's text as produced by the scanner.
	Str string
}

// LastLexicalToken returns the last lexical token. If the string has no lexical
// tokens, returns 0 and ok=false.
1015 func LastLexicalToken(sql string) (lastTok int, ok bool) { 1016 s := makeScanner(sql) 1017 var lval sqlSymType 1018 for { 1019 last := lval.id 1020 s.scan(&lval) 1021 if lval.id == 0 { 1022 return int(last), last != 0 1023 } 1024 } 1025 } 1026 1027 // EndsInSemicolon returns true if the last lexical token is a semicolon. 1028 func EndsInSemicolon(sql string) bool { 1029 lastTok, ok := LastLexicalToken(sql) 1030 return ok && lastTok == ';' 1031 }