github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/parser/lexer.go (about) 1 // Copyright 2022 zGraph Authors. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package parser 16 17 import ( 18 "bytes" 19 "encoding/hex" 20 "fmt" 21 "math" 22 "strconv" 23 "strings" 24 "unicode" 25 26 "github.com/vescale/zgraph/datum" 27 ) 28 29 var _ = yyLexer(&Lexer{}) 30 31 // Pos represents the position of a token. 32 type Pos struct { 33 Line int 34 Col int 35 Offset int 36 } 37 38 // Lexer implements the yyLexer interface. 39 type Lexer struct { 40 r reader 41 buf bytes.Buffer 42 43 errs []error 44 warns []error 45 stmtStartPos int 46 47 // inBangComment is true if we are inside a `/*! ... */` block. 48 // It is used to ignore a stray `*/` when scanning. 49 inBangComment bool 50 51 // Whether record the original text keyword position to the AST node. 52 skipPositionRecording bool 53 54 // lastScanOffset indicates last offset returned by scan(). 55 // It's used to substring sql in syntax error message. 56 lastScanOffset int 57 58 // lastKeyword records the previous keyword returned by scan(). 59 // determine whether an optimizer hint should be parsed or ignored. 60 lastKeyword int 61 // lastKeyword2 records the keyword before lastKeyword, it is used 62 // to disambiguate hint after for update, which should be ignored. 63 lastKeyword2 int 64 // lastKeyword3 records the keyword before lastKeyword2, it is used 65 // to disambiguate hint after create binding for update, which should 66 // be pertained. 67 lastKeyword3 int 68 69 // hintPos records the start position of the previous optimizer hint. 70 lastHintPos Pos 71 72 // true if a dot follows an identifier 73 identifierDot bool 74 } 75 76 // Errors returns the errors and warns during a scan. 77 func (l *Lexer) Errors() (warns []error, errs []error) { 78 return l.warns, l.errs 79 } 80 81 // reset resets the sql string to be scanned. 82 func (l *Lexer) reset(sql string) { 83 l.r = reader{s: sql, p: Pos{Line: 1}, l: len(sql)} 84 l.buf.Reset() 85 l.errs = l.errs[:0] 86 l.warns = l.warns[:0] 87 l.stmtStartPos = 0 88 l.inBangComment = false 89 l.lastKeyword = 0 90 } 91 92 func (l *Lexer) stmtText() string { 93 endPos := l.r.pos().Offset 94 if l.r.s[endPos-1] == '\n' { 95 endPos = endPos - 1 // trim new line 96 } 97 if l.r.s[l.stmtStartPos] == '\n' { 98 l.stmtStartPos++ 99 } 100 101 text := l.r.s[l.stmtStartPos:endPos] 102 103 l.stmtStartPos = endPos 104 return text 105 } 106 107 // Errorf tells scanner something is wrong. 108 // Lexer satisfies yyLexer interface which need this function. 109 func (l *Lexer) Errorf(format string, a ...interface{}) (err error) { 110 str := fmt.Sprintf(format, a...) 111 val := l.r.s[l.lastScanOffset:] 112 var lenStr = "" 113 if len(val) > 2048 { 114 lenStr = "(total length " + strconv.Itoa(len(val)) + ")" 115 val = val[:2048] 116 } 117 err = fmt.Errorf("line %d column %d near \"%s\"%s %s", 118 l.r.p.Line, l.r.p.Col, val, str, lenStr) 119 return 120 } 121 122 // AppendError sets error into scanner. 123 // Lexer satisfies yyLexer interface which need this function. 124 func (l *Lexer) AppendError(err error) { 125 if err == nil { 126 return 127 } 128 l.errs = append(l.errs, err) 129 } 130 131 // AppendWarn sets warning into scanner. 132 func (l *Lexer) AppendWarn(err error) { 133 if err == nil { 134 return 135 } 136 l.warns = append(l.warns, err) 137 } 138 139 // Lex returns a token and store the token value in v. 140 // Lexer satisfies yyLexer interface. 141 // 0 and invalid are special token id this function would return: 142 // return 0 tells parser that scanner meets EOF, 143 // return invalid tells parser that scanner meets illegal character. 144 func (l *Lexer) Lex(v *yySymType) int { 145 tok, pos, lit := l.scan() 146 l.lastScanOffset = pos.Offset 147 l.lastKeyword3 = l.lastKeyword2 148 l.lastKeyword2 = l.lastKeyword 149 l.lastKeyword = 0 150 v.offset = pos.Offset 151 v.ident = lit 152 if tok == identifier { 153 if tok1 := l.isTokenIdentifier(lit, pos.Offset); tok1 != 0 { 154 tok = tok1 155 l.lastKeyword = tok1 156 } 157 } 158 159 switch tok { 160 case intLit: 161 return toInt(l, v, lit) 162 case floatLit: 163 return toFloat(l, v, lit) 164 case decLit: 165 return toDecimal(l, v, lit) 166 case hexLit: 167 return toHex(l, v, lit) 168 case singleAtIdentifier, doubleAtIdentifier, cast, extract: 169 v.item = lit 170 return tok 171 case null: 172 v.item = nil 173 case quotedIdentifier, identifier: 174 tok = identifier 175 l.identifierDot = l.r.peek() == '.' 176 v.ident = lit 177 case stringLit: 178 v.ident = lit 179 } 180 181 return tok 182 } 183 184 func toInt(l yyLexer, lval *yySymType, str string) int { 185 n, err := strconv.ParseUint(str, 10, 64) 186 if err != nil { 187 e := err.(*strconv.NumError) 188 if e.Err == strconv.ErrRange { 189 // TODO: toDecimal maybe out of range still. 190 // This kind of error should be throw to higher level, because truncated data maybe legal. 191 // For example, this SQL returns error: 192 // create table test (id decimal(30, 0)); 193 // insert into test values(123456789012345678901234567890123094839045793405723406801943850); 194 // While this SQL: 195 // select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057; 196 // get value 99999999999999999999999999999999999999999999999999999999999999999 197 return toDecimal(l, lval, str) 198 } 199 l.AppendError(fmt.Errorf("integer literal: %v", err)) 200 return invalid 201 } 202 203 switch { 204 case n <= math.MaxInt64: 205 lval.item = int64(n) 206 default: 207 lval.item = n 208 } 209 return intLit 210 } 211 212 func toDecimal(l yyLexer, lval *yySymType, str string) int { 213 dec, err := datum.ParseDecimal(str) 214 if err != nil { 215 l.AppendError(fmt.Errorf("decimal literal: %v", err)) 216 return invalid 217 } 218 lval.item = dec 219 return decLit 220 } 221 222 func toFloat(l yyLexer, lval *yySymType, str string) int { 223 n, err := strconv.ParseFloat(str, 64) 224 if err != nil { 225 l.AppendError(l.Errorf("float literal: %v", err)) 226 return invalid 227 } 228 229 lval.item = n 230 return floatLit 231 } 232 233 func toHex(l yyLexer, lval *yySymType, str string) int { 234 str = strings.TrimPrefix(str, "0x") 235 buf, err := hex.DecodeString(str) 236 if err != nil { 237 l.AppendError(l.Errorf("hex literal: %v", err)) 238 return invalid 239 } 240 lval.item = datum.NewBytes(buf) 241 return hexLit 242 } 243 244 // LexLiteral returns the value of the converted literal 245 func (l *Lexer) LexLiteral() interface{} { 246 symType := &yySymType{} 247 l.Lex(symType) 248 if symType.item == nil { 249 return symType.ident 250 } 251 return symType.item 252 } 253 254 // InheritScanner returns a new scanner object which inherits configurations from the parent scanner. 255 func (l *Lexer) InheritScanner(sql string) *Lexer { 256 return &Lexer{ 257 r: reader{s: sql}, 258 } 259 } 260 261 // NewLexer returns a new scanner object. 262 func NewLexer(s string) *Lexer { 263 lexer := &Lexer{r: reader{s: s}} 264 lexer.reset(s) 265 return lexer 266 } 267 268 func (l *Lexer) skipWhitespace() byte { 269 return l.r.incAsLongAs(func(b byte) bool { 270 return unicode.IsSpace(rune(b)) 271 }) 272 } 273 274 func (l *Lexer) scan() (tok int, pos Pos, lit string) { 275 ch0 := l.r.peek() 276 if unicode.IsSpace(rune(ch0)) { 277 ch0 = l.skipWhitespace() 278 } 279 pos = l.r.pos() 280 if l.r.eof() { 281 // when scanner meets EOF, the returned token should be 0, 282 // because 0 is a special token id to remind the parser that stream is end. 283 return 0, pos, "" 284 } 285 286 if isIdentExtend(ch0) { 287 return scanIdentifier(l) 288 } 289 290 // search a trie to get a token. 291 node := &ruleTable 292 for !(node.childs[ch0] == nil || l.r.eof()) { 293 node = node.childs[ch0] 294 if node.fn != nil { 295 return node.fn(l) 296 } 297 l.r.inc() 298 ch0 = l.r.peek() 299 } 300 301 tok, lit = node.token, l.r.data(&pos) 302 return 303 } 304 305 func startWithSharp(s *Lexer) (tok int, pos Pos, lit string) { 306 s.r.incAsLongAs(func(ch byte) bool { 307 return ch != '\n' 308 }) 309 return s.scan() 310 } 311 312 func startWithSlash(s *Lexer) (tok int, pos Pos, lit string) { 313 pos = s.r.pos() 314 s.r.inc() 315 if ch := s.r.peek(); ch != '*' { 316 if ch != '-' { 317 tok = int('/') 318 lit = "/" 319 return 320 } 321 s.r.inc() 322 if ch = s.r.peek(); ch == '>' { 323 tok = reachOutgoingRight 324 s.r.inc() 325 } else { 326 tok = reachIncomingRight 327 } 328 return 329 } 330 331 currentCharIsStar := false 332 333 s.r.inc() // we see '/*' so far. 334 switch s.r.readByte() { 335 case '!': // '/*!' MySQL-specific comments 336 // See http://dev.mysql.com/doc/refman/5.7/en/comments.html 337 // in '/*!', which we always recognize regardless of version. 338 s.scanVersionDigits(5, 5) 339 s.inBangComment = true 340 return s.scan() 341 342 case 'M': // '/*M' maybe MariaDB-specific comments 343 // no special treatment for now. 344 break 345 346 case '*': // '/**' if the next char is '/' it would close the comment. 347 currentCharIsStar = true 348 349 default: 350 break 351 } 352 353 // standard C-like comment. read until we see '*/' then drop it. 354 for { 355 if currentCharIsStar || s.r.incAsLongAs(func(ch byte) bool { return ch != '*' }) == '*' { 356 switch s.r.readByte() { 357 case '/': 358 return s.scan() 359 case '*': 360 currentCharIsStar = true 361 continue 362 default: 363 currentCharIsStar = false 364 continue 365 } 366 } 367 // unclosed comment or other errors. 368 s.errs = append(s.errs, parseErrorWith(s.r.data(&pos), s.r.p.Line)) 369 return 370 } 371 } 372 373 const errTextLength = 80 374 375 // parseErrorWith returns "You have a syntax error near..." error message compatible with mysql. 376 func parseErrorWith(errstr string, lineno int) error { 377 if len(errstr) > errTextLength { 378 errstr = errstr[:errTextLength] 379 } 380 return fmt.Errorf("near '%-.80s' at line %d", errstr, lineno) 381 } 382 383 func startWithStar(s *Lexer) (tok int, pos Pos, lit string) { 384 pos = s.r.pos() 385 s.r.inc() 386 387 // skip and exit '/*!' if we see '*/' 388 if s.inBangComment && s.r.peek() == '/' { 389 s.inBangComment = false 390 s.r.inc() 391 return s.scan() 392 } 393 // otherwise it is just a normal star. 394 s.identifierDot = false 395 return '*', pos, "*" 396 } 397 398 func startWithAt(s *Lexer) (tok int, pos Pos, lit string) { 399 pos = s.r.pos() 400 s.r.inc() 401 402 tok, lit = scanIdentifierOrString(s) 403 switch tok { 404 case '@': 405 s.r.inc() 406 stream := s.r.s[pos.Offset+2:] 407 var prefix string 408 for _, v := range []string{"global.", "session.", "local."} { 409 if len(v) > len(stream) { 410 continue 411 } 412 if strings.EqualFold(stream[:len(v)], v) { 413 prefix = v 414 s.r.incN(len(v)) 415 break 416 } 417 } 418 tok, lit = scanIdentifierOrString(s) 419 switch tok { 420 case stringLit, quotedIdentifier: 421 tok, lit = doubleAtIdentifier, "@@"+prefix+lit 422 case identifier: 423 tok, lit = doubleAtIdentifier, s.r.data(&pos) 424 } 425 case invalid: 426 return 427 default: 428 tok = singleAtIdentifier 429 } 430 431 return 432 } 433 434 func scanIdentifier(s *Lexer) (int, Pos, string) { 435 pos := s.r.pos() 436 s.r.incAsLongAs(isIdentChar) 437 return identifier, pos, s.r.data(&pos) 438 } 439 440 func scanIdentifierOrString(s *Lexer) (tok int, lit string) { 441 ch1 := s.r.peek() 442 switch ch1 { 443 case '\'', '"': 444 tok, _, lit = startString(s) 445 case '`': 446 tok, _, lit = scanQuotedIdent(s) 447 default: 448 if isUserVarChar(ch1) { 449 pos := s.r.pos() 450 s.r.incAsLongAs(isUserVarChar) 451 tok, lit = identifier, s.r.data(&pos) 452 } else { 453 tok = int(ch1) 454 } 455 } 456 return 457 } 458 459 var ( 460 quotedIdentifier = -identifier 461 ) 462 463 func scanQuotedIdent(s *Lexer) (tok int, pos Pos, lit string) { 464 pos = s.r.pos() 465 s.r.inc() 466 s.buf.Reset() 467 for !s.r.eof() { 468 ch := s.r.readByte() 469 if ch == '`' { 470 if s.r.peek() != '`' { 471 // don't return identifier in case that it's interpreted as keyword token later. 472 tok, lit = quotedIdentifier, s.buf.String() 473 return 474 } 475 s.r.inc() 476 } 477 s.buf.WriteByte(ch) 478 } 479 tok = invalid 480 return 481 } 482 483 func startString(s *Lexer) (tok int, pos Pos, lit string) { 484 return s.scanString() 485 } 486 487 func (l *Lexer) scanString() (tok int, pos Pos, lit string) { 488 tok, pos = stringLit, l.r.pos() 489 ending := l.r.readByte() 490 l.buf.Reset() 491 for !l.r.eof() { 492 ch0 := l.r.readByte() 493 if ch0 == ending { 494 if l.r.peek() != ending { 495 lit = l.buf.String() 496 return 497 } 498 l.r.inc() 499 l.buf.WriteByte(ch0) 500 } else if ch0 == '\\' { 501 if l.r.eof() { 502 break 503 } 504 l.handleEscape(l.r.peek(), &l.buf) 505 l.r.inc() 506 } else { 507 l.buf.WriteByte(ch0) 508 } 509 } 510 511 tok = invalid 512 return 513 } 514 515 // handleEscape handles the case in scanString when previous char is '\'. 516 func (*Lexer) handleEscape(b byte, buf *bytes.Buffer) { 517 var ch0 byte 518 /* 519 \" \' \\ \n \0 \b \Z \r \t ==> escape to one char 520 \% \_ ==> preserve both char 521 other ==> remove \ 522 */ 523 switch b { 524 case 'n': 525 ch0 = '\n' 526 case '0': 527 ch0 = 0 528 case 'b': 529 ch0 = 8 530 case 'Z': 531 ch0 = 26 532 case 'r': 533 ch0 = '\r' 534 case 't': 535 ch0 = '\t' 536 case '%', '_': 537 buf.WriteByte('\\') 538 ch0 = b 539 default: 540 ch0 = b 541 } 542 buf.WriteByte(ch0) 543 } 544 545 func startWithNumber(s *Lexer) (tok int, pos Pos, lit string) { 546 if s.identifierDot { 547 return scanIdentifier(s) 548 } 549 pos = s.r.pos() 550 tok = intLit 551 ch0 := s.r.readByte() 552 if ch0 == '0' { 553 tok = intLit 554 ch1 := s.r.peek() 555 switch { 556 case ch1 >= '0' && ch1 <= '7': 557 s.r.inc() 558 s.scanOct() 559 case ch1 == 'x' || ch1 == 'X': 560 s.r.inc() 561 p1 := s.r.pos() 562 s.scanHex() 563 p2 := s.r.pos() 564 // 0x, 0x7fz3 are identifier 565 if p1 == p2 || isDigit(s.r.peek()) { 566 s.r.incAsLongAs(isIdentChar) 567 return identifier, pos, s.r.data(&pos) 568 } 569 tok = hexLit 570 case ch1 == '.': 571 return s.scanFloat(&pos) 572 case ch1 == 'B': 573 s.r.incAsLongAs(isIdentChar) 574 return identifier, pos, s.r.data(&pos) 575 } 576 } 577 578 s.scanDigits() 579 ch0 = s.r.peek() 580 if ch0 == '.' || ch0 == 'e' || ch0 == 'E' { 581 return s.scanFloat(&pos) 582 } 583 584 // Identifiers may begin with a digit but unless quoted may not consist solely of digits. 585 if !s.r.eof() && isIdentChar(ch0) { 586 s.r.incAsLongAs(isIdentChar) 587 return identifier, pos, s.r.data(&pos) 588 } 589 lit = s.r.data(&pos) 590 return 591 } 592 593 func startWithDot(s *Lexer) (tok int, pos Pos, lit string) { 594 pos = s.r.pos() 595 s.r.inc() 596 if s.identifierDot { 597 return int('.'), pos, "." 598 } 599 if isDigit(s.r.peek()) { 600 tok, p, l := s.scanFloat(&pos) 601 if tok == identifier { 602 return invalid, p, l 603 } 604 return tok, p, l 605 } 606 tok, lit = int('.'), "." 607 return 608 } 609 610 func (l *Lexer) scanOct() { 611 l.r.incAsLongAs(func(ch byte) bool { 612 return ch >= '0' && ch <= '7' 613 }) 614 } 615 616 func (l *Lexer) scanHex() { 617 l.r.incAsLongAs(func(ch byte) bool { 618 return ch >= '0' && ch <= '9' || 619 ch >= 'a' && ch <= 'f' || 620 ch >= 'A' && ch <= 'F' 621 }) 622 } 623 624 func (l *Lexer) scanBit() { 625 l.r.incAsLongAs(func(ch byte) bool { 626 return ch == '0' || ch == '1' 627 }) 628 } 629 630 func (l *Lexer) scanFloat(beg *Pos) (tok int, pos Pos, lit string) { 631 l.r.updatePos(*beg) 632 // float = D1 . D2 e D3 633 l.scanDigits() 634 ch0 := l.r.peek() 635 if ch0 == '.' { 636 l.r.inc() 637 l.scanDigits() 638 ch0 = l.r.peek() 639 } 640 if ch0 == 'e' || ch0 == 'E' { 641 l.r.inc() 642 ch0 = l.r.peek() 643 if ch0 == '-' || ch0 == '+' { 644 l.r.inc() 645 } 646 if isDigit(l.r.peek()) { 647 l.scanDigits() 648 tok = floatLit 649 } else { 650 // D1 . D2 e XX when XX is not D3, parse the result to an identifier. 651 // 9e9e = 9e9(float) + e(identifier) 652 // 9est = 9est(identifier) 653 l.r.updatePos(*beg) 654 l.r.incAsLongAs(isIdentChar) 655 tok = identifier 656 } 657 } else { 658 tok = decLit 659 } 660 pos, lit = *beg, l.r.data(beg) 661 return 662 } 663 664 func (l *Lexer) scanDigits() string { 665 pos := l.r.pos() 666 l.r.incAsLongAs(isDigit) 667 return l.r.data(&pos) 668 } 669 670 // scanVersionDigits scans for `min` to `max` digits (range inclusive) used in 671 // `/*!12345 ... */` comments. 672 func (l *Lexer) scanVersionDigits(min, max int) { 673 pos := l.r.pos() 674 for i := 0; i < max; i++ { 675 ch := l.r.peek() 676 if isDigit(ch) { 677 l.r.inc() 678 } else if i < min { 679 l.r.updatePos(pos) 680 return 681 } else { 682 break 683 } 684 } 685 } 686 687 func (l *Lexer) lastErrorAsWarn() { 688 if len(l.errs) == 0 { 689 return 690 } 691 l.warns = append(l.warns, l.errs[len(l.errs)-1]) 692 l.errs = l.errs[:len(l.errs)-1] 693 } 694 695 type reader struct { 696 s string 697 p Pos 698 l int 699 } 700 701 func (r *reader) eof() bool { 702 return r.p.Offset >= r.l 703 } 704 705 // peek() peeks a rune from underlying reader. 706 // if reader meets EOF, it will return 0. to distinguish from 707 // the real 0, the caller should call r.eof() again to check. 708 func (r *reader) peek() byte { 709 if r.eof() { 710 return 0 711 } 712 return r.s[r.p.Offset] 713 } 714 715 // inc increase the position offset of the reader. 716 // peek must be called before calling inc! 717 func (r *reader) inc() { 718 if r.s[r.p.Offset] == '\n' { 719 r.p.Line++ 720 r.p.Col = 0 721 } 722 r.p.Offset++ 723 r.p.Col++ 724 } 725 726 func (r *reader) incN(n int) { 727 for i := 0; i < n; i++ { 728 r.inc() 729 } 730 } 731 732 func (r *reader) readByte() (ch byte) { 733 ch = r.peek() 734 if r.eof() { 735 return 736 } 737 r.inc() 738 return 739 } 740 741 func (r *reader) pos() Pos { 742 return r.p 743 } 744 745 func (r *reader) updatePos(pos Pos) { 746 r.p = pos 747 } 748 749 func (r *reader) data(from *Pos) string { 750 return r.s[from.Offset:r.p.Offset] 751 } 752 753 func (r *reader) incAsLongAs(fn func(b byte) bool) byte { 754 for { 755 ch := r.peek() 756 if !fn(ch) { 757 return ch 758 } 759 if r.eof() { 760 return 0 761 } 762 r.inc() 763 } 764 }