github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/parsers/dialect/postgresql/scanner.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package postgresql 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 "unicode" 22 23 "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" 24 ) 25 26 const eofChar = 0x100 27 28 type Scanner struct { 29 LastToken string 30 LastError error 31 posVarIndex int 32 dialectType dialect.DialectType 33 MysqlSpecialComment *Scanner 34 35 Pos int 36 buf string 37 } 38 39 func NewScanner(dialectType dialect.DialectType, sql string) *Scanner { 40 41 return &Scanner{ 42 buf: sql, 43 } 44 } 45 46 func (s *Scanner) Scan() (int, string) { 47 if s.MysqlSpecialComment != nil { 48 msc := s.MysqlSpecialComment 49 tok, val := msc.Scan() 50 if tok != 0 { 51 return tok, val 52 } 53 s.MysqlSpecialComment = nil 54 } 55 56 s.skipBlank() 57 switch ch := s.cur(); { 58 case ch == '@': 59 tokenID := AT_ID 60 s.skip(1) 61 s.skipBlank() 62 if s.cur() == '@' { 63 tokenID = AT_AT_ID 64 s.skip(1) 65 } else if s.cur() == '\'' || s.cur() == '"' { 66 return int('@'), "" 67 } else if s.cur() == ',' { 68 return tokenID, "" 69 } 70 var tID int 71 var tBytes string 72 if s.cur() == '`' { 73 s.skip(1) 74 tID, tBytes = s.scanLiteralIdentifier() 75 } else if s.cur() == eofChar { 76 return LEX_ERROR, "" 77 } else { 78 tID, tBytes = s.scanIdentifier(true) 79 } 80 if tID == LEX_ERROR { 81 return tID, "" 82 } 83 return tokenID, tBytes 84 case isLetter(ch): 85 if ch == 'X' || ch == 'x' { 86 if s.peek(1) == '\'' { 87 s.skip(2) 88 return s.scanHex() 89 } 90 } 91 if ch == 'B' || ch == 'b' { 92 if s.peek(1) == '\'' { 93 s.skip(2) 94 return s.scanBitLiteral() 95 } 96 } 97 return s.scanIdentifier(false) 98 case isDigit(ch): 99 return s.scanNumber() 100 case ch == ':': 101 if s.peek(1) == '=' { 102 s.skip(2) 103 return ASSIGNMENT, "" 104 } 105 // Like mysql -h ::1 ? 106 return s.scanBindVar() 107 case ch == ';': 108 s.skip(1) 109 return ';', "" 110 case ch == '.' && isDigit(s.peek(1)): 111 return s.scanNumber() 112 case ch == '/': 113 s.skip(1) 114 switch s.cur() { 115 case '/': 116 s.skip(1) 117 id, str := s.scanCommentTypeLine(2) 118 if id == LEX_ERROR { 119 return id, str 120 } 121 return s.Scan() 122 case '*': 123 s.skip(1) 124 switch { 125 case s.cur() == '!' && s.dialectType == dialect.MYSQL: 126 // TODO: ExtractMysqlComment 127 return s.scanMySQLSpecificComment() 128 default: 129 id, str := s.scanCommentTypeBlock() 130 if id == LEX_ERROR { 131 return id, str 132 } 133 return s.Scan() 134 } 135 default: 136 return int(ch), "" 137 } 138 default: 139 return s.stepBackOneChar(ch) 140 } 141 } 142 143 func (s *Scanner) stepBackOneChar(ch uint16) (int, string) { 144 s.skip(1) 145 switch ch { 146 case eofChar: 147 return 0, "" 148 case '=', ',', '(', ')', '+', '*', '%', '^', '~': 149 return int(ch), "" 150 case '&': 151 if s.cur() == '&' { 152 s.skip(1) 153 return AND, "" 154 } 155 return int(ch), "" 156 case '|': 157 if s.cur() == '|' { 158 s.skip(1) 159 return PIPE_CONCAT, "" 160 } 161 return int(ch), "" 162 case '?': 163 // mysql's situation 164 s.posVarIndex++ 165 buf := make([]byte, 0, 8) 166 buf = append(buf, ":v"...) 167 buf = strconv.AppendInt(buf, int64(s.posVarIndex), 10) 168 return VALUE_ARG, string(buf) 169 case '.': 170 return int(ch), "" 171 case '#': 172 return s.scanCommentTypeLine(1) 173 case '-': 174 switch s.cur() { 175 case '-': 176 nextChar := s.peek(1) 177 if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar { 178 s.skip(1) 179 return s.scanCommentTypeLine(2) 180 } 181 case '>': 182 s.skip(1) 183 // TODO: 184 // JSON_UNQUOTE_EXTRACT_OP 185 // JSON_EXTRACT_OP 186 return 0, "" 187 } 188 return int(ch), "" 189 case '<': 190 switch s.cur() { 191 case '>': 192 s.skip(1) 193 return NE, "" 194 case '<': 195 s.skip(1) 196 return SHIFT_LEFT, "" 197 case '=': 198 s.skip(1) 199 switch s.cur() { 200 case '>': 201 s.skip(1) 202 return NULL_SAFE_EQUAL, "" 203 default: 204 return LE, "" 205 } 206 default: 207 return int(ch), "" 208 } 209 case '>': 210 switch s.cur() { 211 case '=': 212 s.skip(1) 213 return GE, "" 214 case '>': 215 s.skip(1) 216 return SHIFT_RIGHT, "" 217 default: 218 return int(ch), "" 219 } 220 case '!': 221 if s.cur() == '=' { 222 s.skip(1) 223 return NE, "" 224 } 225 return int(ch), "" 226 case '\'', '"': 227 return s.scanString(ch, STRING) 228 case '`': 229 return s.scanLiteralIdentifier() 230 default: 231 return LEX_ERROR, string(byte(ch)) 232 } 233 } 234 235 // scanString scans a string surrounded by the given `delim`, which can be 236 // either single or double quotes. Assumes that the given delimiter has just 237 // been scanned. If the skin contains any escape sequences, this function 238 // will fall back to scanStringSlow 239 func (s *Scanner) scanString(delim uint16, typ int) (int, string) { 240 start := s.Pos 241 242 for { 243 switch s.cur() { 244 case delim: 245 if s.peek(1) != delim { 246 s.skip(1) 247 return typ, s.buf[start : s.Pos-1] 248 } 249 fallthrough 250 251 case '\\': 252 var buffer strings.Builder 253 buffer.WriteString(s.buf[start:s.Pos]) 254 return s.scanStringSlow(&buffer, delim, typ) 255 256 case eofChar: 257 return LEX_ERROR, s.buf[start:s.Pos] 258 } 259 260 s.skip(1) 261 } 262 } 263 264 // scanString scans a string surrounded by the given `delim` and containing escape 265 // sequencse. The given `buffer` contains the contents of the string that have 266 // been scanned so far. 267 func (s *Scanner) scanStringSlow(buffer *strings.Builder, delim uint16, typ int) (int, string) { 268 for { 269 ch := s.cur() 270 if ch == eofChar { 271 // Unterminated string. 272 return LEX_ERROR, buffer.String() 273 } 274 275 if ch != delim && ch != '\\' { 276 start := s.Pos 277 for ; s.Pos < len(s.buf); s.Pos++ { 278 ch = uint16(s.buf[s.Pos]) 279 if ch == delim || ch == '\\' { 280 break 281 } 282 } 283 284 buffer.WriteString(s.buf[start:s.Pos]) 285 if s.Pos >= len(s.buf) { 286 s.skip(1) 287 continue 288 } 289 } 290 s.skip(1) 291 292 if ch == '\\' { 293 ch = s.cur() 294 switch ch { 295 case eofChar: 296 return LEX_ERROR, buffer.String() 297 case 'n': 298 ch = '\n' 299 case '0': 300 ch = '\x00' 301 case 'b': 302 ch = 8 303 case 'Z': 304 ch = 26 305 case 'r': 306 ch = '\r' 307 case 't': 308 ch = '\t' 309 case '%', '_': 310 buffer.WriteByte(byte('\\')) 311 continue 312 case '\\', delim: 313 default: 314 continue 315 } 316 } else if ch == delim && s.cur() != delim { 317 break 318 } 319 buffer.WriteByte(byte(ch)) 320 s.skip(1) 321 } 322 323 return typ, buffer.String() 324 } 325 326 // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier 327 // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier 328 // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow 329 func (s *Scanner) scanLiteralIdentifier() (int, string) { 330 start := s.Pos 331 for { 332 switch s.cur() { 333 case '`': 334 if s.peek(1) != '`' { 335 if s.Pos == start { 336 return LEX_ERROR, "" 337 } 338 s.skip(1) 339 return ID, s.buf[start : s.Pos-1] 340 } 341 342 var buf strings.Builder 343 buf.WriteString(s.buf[start:s.Pos]) 344 s.skip(1) 345 return s.scanLiteralIdentifierSlow(&buf) 346 case eofChar: 347 // Premature EOF. 348 return LEX_ERROR, s.buf[start:s.Pos] 349 default: 350 s.skip(1) 351 } 352 } 353 } 354 355 // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may 356 // contain escape sequences instead of it. This method is only called from 357 // scanLiteralIdentifier once the first escape sequence is found in the identifier. 358 // The provided `buf` contains the contents of the identifier that have been scanned 359 // so far. 360 func (s *Scanner) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) { 361 backTickSeen := true 362 for { 363 if backTickSeen { 364 if s.cur() != '`' { 365 break 366 } 367 backTickSeen = false 368 buf.WriteByte('`') 369 s.skip(1) 370 continue 371 } 372 // The previous char was not a backtick. 373 switch s.cur() { 374 case '`': 375 backTickSeen = true 376 case eofChar: 377 // Premature EOF. 378 return LEX_ERROR, buf.String() 379 default: 380 buf.WriteByte(byte(s.cur())) 381 // keep scanning 382 } 383 s.skip(1) 384 } 385 return ID, buf.String() 386 } 387 388 // scanCommentTypeBlock scans a '/*' delimited comment; 389 // assumes the opening prefix has already been scanned 390 func (s *Scanner) scanCommentTypeBlock() (int, string) { 391 start := s.Pos - 2 392 for { 393 if s.cur() == '*' { 394 s.skip(1) 395 if s.cur() == '/' { 396 s.skip(1) 397 break 398 } 399 continue 400 } 401 if s.cur() == eofChar { 402 return LEX_ERROR, s.buf[start:s.Pos] 403 } 404 s.skip(1) 405 } 406 return COMMENT, s.buf[start:s.Pos] 407 } 408 409 // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*` 410 func (s *Scanner) scanMySQLSpecificComment() (int, string) { 411 start := s.Pos - 3 412 for { 413 if s.cur() == '*' { 414 s.skip(1) 415 if s.cur() == '/' { 416 s.skip(1) 417 break 418 } 419 continue 420 } 421 if s.cur() == eofChar { 422 return LEX_ERROR, s.buf[start:s.Pos] 423 } 424 s.skip(1) 425 } 426 427 _, sql := ExtractMysqlComment(s.buf[start:s.Pos]) 428 429 s.MysqlSpecialComment = NewScanner(s.dialectType, sql) 430 431 return s.Scan() 432 } 433 434 // ExtractMysqlComment extracts the version and SQL from a comment-only query 435 // such as /*!50708 sql here */ 436 func ExtractMysqlComment(sql string) (string, string) { 437 sql = sql[3 : len(sql)-2] 438 439 digitCount := 0 440 endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool { 441 digitCount++ 442 return !unicode.IsDigit(c) || digitCount == 6 443 }) 444 if endOfVersionIndex < 0 { 445 return "", "" 446 } 447 if endOfVersionIndex < 5 { 448 endOfVersionIndex = 0 449 } 450 version := sql[0:endOfVersionIndex] 451 innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace) 452 453 return version, innerSQL 454 } 455 456 // scanCommentTypeLine scans a SQL line-comment, which is applied until the end 457 // of the line. The given prefix length varies based on whether the comment 458 // is started with '//', '--' or '#'. 459 func (s *Scanner) scanCommentTypeLine(prefixLen int) (int, string) { 460 start := s.Pos - prefixLen 461 for s.cur() != eofChar { 462 if s.cur() == '\n' { 463 s.skip(1) 464 break 465 } 466 s.skip(1) 467 } 468 return COMMENT, s.buf[start:s.Pos] 469 } 470 471 // ? 472 // scanBindVar scans a bind variable; assumes a ':' has been scanned right before 473 func (s *Scanner) scanBindVar() (int, string) { 474 start := s.Pos 475 token := VALUE_ARG 476 477 s.skip(1) 478 if s.cur() == ':' { 479 token = LIST_ARG 480 s.skip(1) 481 } 482 if !isLetter(s.cur()) { 483 return LEX_ERROR, s.buf[start:s.Pos] 484 } 485 for { 486 ch := s.cur() 487 if !isLetter(ch) && !isDigit(ch) && ch != '.' { 488 break 489 } 490 s.skip(1) 491 } 492 return token, s.buf[start:s.Pos] 493 } 494 495 // scanNumber scans any SQL numeric literal, either floating point or integer 496 func (s *Scanner) scanNumber() (int, string) { 497 start := s.Pos 498 token := INTEGRAL 499 500 if s.cur() == '.' { 501 token = FLOAT 502 s.skip(1) 503 s.scanMantissa(10) 504 goto exponent 505 } 506 507 // 0x construct. 508 if s.cur() == '0' { 509 s.skip(1) 510 if s.cur() == 'x' || s.cur() == 'X' { 511 token = HEXNUM 512 s.skip(1) 513 s.scanMantissa(16) 514 goto exit 515 } else if s.cur() == 'b' || s.cur() == 'B' { 516 token = BIT_LITERAL 517 s.skip(1) 518 s.scanMantissa(2) 519 goto exit 520 } 521 } 522 523 s.scanMantissa(10) 524 525 if s.cur() == '.' { 526 token = FLOAT 527 s.skip(1) 528 s.scanMantissa(10) 529 } 530 531 exponent: 532 if s.cur() == 'e' || s.cur() == 'E' { 533 if s.peek(1) == '+' || s.peek(1) == '-' { 534 token = FLOAT 535 s.skip(2) 536 } else if digitVal(s.peek(1)) < 10 { 537 token = FLOAT 538 s.skip(1) 539 } else { 540 goto exit 541 } 542 s.scanMantissa(10) 543 } 544 545 exit: 546 if isLetter(s.cur()) { 547 // TODO: optimize 548 token = ID 549 s.scanIdentifier(false) 550 } 551 552 return token, strings.ToLower(s.buf[start:s.Pos]) 553 } 554 555 func (s *Scanner) scanIdentifier(isVariable bool) (int, string) { 556 start := s.Pos 557 s.skip(1) 558 559 for { 560 ch := s.cur() 561 if !isLetter(ch) && !isDigit(ch) && ch != '@' && !(isVariable && isCarat(ch)) { 562 break 563 } 564 if ch == '@' { 565 isVariable = true 566 } 567 s.skip(1) 568 } 569 keywordName := s.buf[start:s.Pos] 570 lower := strings.ToLower(keywordName) 571 if keywordID, found := keywords[lower]; found { 572 return keywordID, lower 573 } 574 // dual must always be case-insensitive 575 if lower == "dual" { 576 return ID, lower 577 } 578 return ID, lower 579 } 580 581 func (s *Scanner) scanBitLiteral() (int, string) { 582 start := s.Pos 583 s.scanMantissa(2) 584 bit := s.buf[start:s.Pos] 585 if s.cur() != '\'' { 586 return LEX_ERROR, bit 587 } 588 s.skip(1) 589 return BIT_LITERAL, bit 590 } 591 592 func (s *Scanner) scanHex() (int, string) { 593 start := s.Pos 594 s.scanMantissa(16) 595 hex := s.buf[start:s.Pos] 596 if s.cur() != '\'' { 597 return LEX_ERROR, hex 598 } 599 s.skip(1) 600 if len(hex)%2 != 0 { 601 return LEX_ERROR, hex 602 } 603 return HEXNUM, hex 604 } 605 606 func (s *Scanner) scanMantissa(base int) { 607 for digitVal(s.cur()) < base { 608 s.skip(1) 609 } 610 } 611 612 // PositionedErr holds context related to parser errros 613 type PositionedErr struct { 614 Err string 615 Pos int 616 Near string 617 } 618 619 func (p PositionedErr) Error() string { 620 if p.Near != "" { 621 return fmt.Sprintf("%s at position %v near '%s';", p.Err, p.Pos, p.Near) 622 } 623 return fmt.Sprintf("%s at position %v;", p.Err, p.Pos) 624 } 625 626 func (s *Scanner) skipBlank() { 627 ch := s.cur() 628 for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 629 s.skip(1) 630 ch = s.cur() 631 } 632 } 633 634 func (s *Scanner) cur() uint16 { 635 return s.peek(0) 636 } 637 638 func (s *Scanner) skip(dist int) { 639 s.Pos += dist 640 } 641 642 func (s *Scanner) peek(dist int) uint16 { 643 if s.Pos+dist >= len(s.buf) { 644 return eofChar 645 } 646 return uint16(s.buf[s.Pos+dist]) 647 } 648 649 func isLetter(ch uint16) bool { 650 return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$' 651 } 652 653 func isCarat(ch uint16) bool { 654 return ch == '.' || ch == '"' || ch == '`' || ch == '\'' 655 } 656 657 func digitVal(ch uint16) int { 658 switch { 659 case '0' <= ch && ch <= '9': 660 return int(ch) - '0' 661 case 'a' <= ch && ch <= 'f': 662 return int(ch) - 'a' + 10 663 case 'A' <= ch && ch <= 'F': 664 return int(ch) - 'A' + 10 665 } 666 return 16 // larger than any legal digit val 667 } 668 669 func isDigit(ch uint16) bool { 670 return '0' <= ch && ch <= '9' 671 }