github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/parsers/dialect/mysql/scanner.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 "unicode" 22 23 "github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect" 24 ) 25 26 const eofChar = 0x100 27 28 type Scanner struct { 29 LastToken string 30 LastError error 31 posVarIndex int 32 dialectType dialect.DialectType 33 MysqlSpecialComment *Scanner 34 35 Pos int 36 Line int 37 Col int 38 PrePos int 39 buf string 40 } 41 42 func NewScanner(dialectType dialect.DialectType, sql string) *Scanner { 43 44 return &Scanner{ 45 buf: sql, 46 } 47 } 48 49 func (s *Scanner) Scan() (int, string) { 50 if s.MysqlSpecialComment != nil { 51 msc := s.MysqlSpecialComment 52 tok, val := msc.Scan() 53 if tok != 0 { 54 return tok, val 55 } 56 s.MysqlSpecialComment = nil 57 } 58 s.PrePos = s.Pos 59 s.skipBlank() 60 switch ch := s.cur(); { 61 case ch == '@': 62 tokenID := AT_ID 63 s.inc() 64 s.skipBlank() 65 if s.cur() == '@' { 66 tokenID = AT_AT_ID 67 s.inc() 68 } else if s.cur() == '\'' || s.cur() == '"' { 69 return int('@'), "" 70 } else if s.cur() == ',' { 71 return tokenID, "" 72 } 73 var tID int 74 var tBytes string 75 if s.cur() == '`' { 76 s.inc() 77 tID, tBytes = s.scanLiteralIdentifier() 78 } else if s.cur() == eofChar { 79 return LEX_ERROR, "" 80 } else { 81 tID, tBytes = s.scanIdentifier(true) 82 } 83 if tID == LEX_ERROR { 84 return tID, "" 85 } 86 return tokenID, tBytes 87 case isLetter(ch): 88 if ch == 'X' || ch == 'x' { 89 if s.peek(1) == '\'' { 90 s.incN(2) 91 return s.scanHex() 92 } 93 } 94 if ch == 'B' || ch == 'b' { 95 if s.peek(1) == '\'' { 96 s.incN(2) 97 return s.scanBitLiteral() 98 } 99 } 100 return s.scanIdentifier(false) 101 case isDigit(ch): 102 return s.scanNumber() 103 case ch == ':': 104 if s.peek(1) == '=' { 105 s.incN(2) 106 return ASSIGNMENT, "" 107 } 108 // Like mysql -h ::1 ? 109 return s.scanBindVar() 110 case ch == ';': 111 s.inc() 112 return ';', "" 113 case ch == '.' && isDigit(s.peek(1)): 114 return s.scanNumber() 115 case ch == '/': 116 s.inc() 117 switch s.cur() { 118 case '/': 119 s.inc() 120 id, str := s.scanCommentTypeLine(2) 121 if id == LEX_ERROR { 122 return id, str 123 } 124 return s.Scan() 125 case '*': 126 s.inc() 127 switch { 128 case s.cur() == '!' && s.dialectType == dialect.MYSQL: 129 // TODO: ExtractMysqlComment 130 return s.scanMySQLSpecificComment() 131 default: 132 id, str := s.scanCommentTypeBlock() 133 if id == LEX_ERROR { 134 return id, str 135 } 136 return s.Scan() 137 } 138 default: 139 return int(ch), "" 140 } 141 default: 142 return s.stepBackOneChar(ch) 143 } 144 } 145 146 func (s *Scanner) stepBackOneChar(ch uint16) (int, string) { 147 s.inc() 148 switch ch { 149 case eofChar: 150 return 0, "" 151 case '=', ',', '(', ')', '+', '*', '%', '^', '~', '{', '}': 152 return int(ch), "" 153 case '&': 154 if s.cur() == '&' { 155 s.inc() 156 return AND, "" 157 } 158 return int(ch), "" 159 case '|': 160 if s.cur() == '|' { 161 s.inc() 162 return PIPE_CONCAT, "" 163 } 164 return int(ch), "" 165 case '?': 166 // mysql's situation 167 s.posVarIndex++ 168 buf := make([]byte, 0, 8) 169 buf = append(buf, ":v"...) 170 buf = strconv.AppendInt(buf, int64(s.posVarIndex), 10) 171 return VALUE_ARG, string(buf) 172 case '.': 173 return int(ch), "" 174 case '#': 175 return s.scanCommentTypeLine(1) 176 case '-': 177 switch s.cur() { 178 case '-': 179 nextChar := s.peek(1) 180 if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar { 181 s.inc() 182 id, str := s.scanCommentTypeLine(2) 183 if id == LEX_ERROR { 184 return id, str 185 } 186 return s.Scan() 187 } 188 case '>': 189 s.inc() 190 return ARROW, "" 191 } 192 return int(ch), "" 193 case '<': 194 switch s.cur() { 195 case '>': 196 s.inc() 197 return NE, "" 198 case '<': 199 s.inc() 200 return SHIFT_LEFT, "" 201 case '=': 202 s.inc() 203 switch s.cur() { 204 case '>': 205 s.inc() 206 return NULL_SAFE_EQUAL, "" 207 default: 208 return LE, "" 209 } 210 default: 211 return int(ch), "" 212 } 213 case '>': 214 switch s.cur() { 215 case '=': 216 s.inc() 217 return GE, "" 218 case '>': 219 s.inc() 220 return SHIFT_RIGHT, "" 221 default: 222 return int(ch), "" 223 } 224 case '!': 225 if s.cur() == '=' { 226 s.inc() 227 return NE, "" 228 } 229 return int(ch), "" 230 case '\'', '"': 231 return s.scanString(ch, STRING) 232 case '`': 233 return s.scanLiteralIdentifier() 234 default: 235 return LEX_ERROR, string(byte(ch)) 236 } 237 } 238 239 func (s *Scanner) scanString(delim uint16, typ int) (int, string) { 240 ch := s.cur() 241 buf := new(strings.Builder) 242 for s.Pos < len(s.buf) { 243 if ch == delim { 244 s.inc() 245 if s.cur() != delim { 246 return typ, buf.String() 247 } 248 } else if ch == '\\' { 249 ch = handleEscape(s, buf) 250 if ch == eofChar { 251 break 252 } 253 } 254 buf.WriteByte(byte(ch)) 255 if s.Pos < len(s.buf) { 256 s.inc() 257 ch = s.cur() 258 } 259 } 260 return LEX_ERROR, buf.String() 261 } 262 263 func handleEscape(s *Scanner, buf *strings.Builder) uint16 { 264 s.inc() 265 ch0 := s.cur() 266 switch ch0 { 267 case 'n': 268 ch0 = '\n' 269 case '0': 270 ch0 = 0 271 case 'b': 272 ch0 = 8 273 case 'Z': 274 ch0 = 26 275 case 'r': 276 ch0 = '\r' 277 case 't': 278 ch0 = '\t' 279 case '%', '_': 280 buf.WriteByte('\\') 281 } 282 return ch0 283 } 284 285 // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier 286 // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier 287 // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow 288 func (s *Scanner) scanLiteralIdentifier() (int, string) { 289 start := s.Pos 290 for { 291 switch s.cur() { 292 case '`': 293 if s.peek(1) != '`' { 294 if s.Pos == start { 295 return LEX_ERROR, "" 296 } 297 s.inc() 298 return ID, s.buf[start : s.Pos-1] 299 } 300 301 var buf strings.Builder 302 buf.WriteString(s.buf[start:s.Pos]) 303 s.inc() 304 return s.scanLiteralIdentifierSlow(&buf) 305 case eofChar: 306 // Premature EOF. 307 return LEX_ERROR, s.buf[start:s.Pos] 308 default: 309 s.inc() 310 } 311 } 312 } 313 314 // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may 315 // contain escape sequences instead of it. This method is only called from 316 // scanLiteralIdentifier once the first escape sequence is found in the identifier. 317 // The provided `buf` contains the contents of the identifier that have been scanned 318 // so far. 319 func (s *Scanner) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) { 320 backTickSeen := true 321 for { 322 if backTickSeen { 323 if s.cur() != '`' { 324 break 325 } 326 backTickSeen = false 327 buf.WriteByte('`') 328 s.inc() 329 continue 330 } 331 // The previous char was not a backtick. 332 switch s.cur() { 333 case '`': 334 backTickSeen = true 335 case eofChar: 336 // Premature EOF. 337 return LEX_ERROR, buf.String() 338 default: 339 buf.WriteByte(byte(s.cur())) 340 // keep scanning 341 } 342 s.inc() 343 } 344 return ID, buf.String() 345 } 346 347 // scanCommentTypeBlock scans a '/*' delimited comment; 348 // assumes the opening prefix has already been scanned 349 func (s *Scanner) scanCommentTypeBlock() (int, string) { 350 start := s.Pos - 2 351 for { 352 if s.cur() == '*' { 353 s.inc() 354 if s.cur() == '/' { 355 s.inc() 356 break 357 } 358 continue 359 } 360 if s.cur() == eofChar { 361 return LEX_ERROR, s.buf[start:s.Pos] 362 } 363 s.inc() 364 } 365 return COMMENT, s.buf[start:s.Pos] 366 } 367 368 // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*` 369 func (s *Scanner) scanMySQLSpecificComment() (int, string) { 370 start := s.Pos - 3 371 for { 372 if s.cur() == '*' { 373 s.inc() 374 if s.cur() == '/' { 375 s.inc() 376 break 377 } 378 continue 379 } 380 if s.cur() == eofChar { 381 return LEX_ERROR, s.buf[start:s.Pos] 382 } 383 s.inc() 384 } 385 386 _, sql := ExtractMysqlComment(s.buf[start:s.Pos]) 387 388 s.MysqlSpecialComment = NewScanner(s.dialectType, sql) 389 390 return s.Scan() 391 } 392 393 // ExtractMysqlComment extracts the version and SQL from a comment-only query 394 // such as /*!50708 sql here */ 395 func ExtractMysqlComment(sql string) (string, string) { 396 sql = sql[3 : len(sql)-2] 397 398 digitCount := 0 399 endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool { 400 digitCount++ 401 return !unicode.IsDigit(c) || digitCount == 6 402 }) 403 if endOfVersionIndex < 0 { 404 return "", "" 405 } 406 if endOfVersionIndex < 5 { 407 endOfVersionIndex = 0 408 } 409 version := sql[0:endOfVersionIndex] 410 innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace) 411 412 return version, innerSQL 413 } 414 415 // scanCommentTypeLine scans a SQL line-comment, which is applied until the end 416 // of the line. The given prefix length varies based on whether the comment 417 // is started with '//', '--' or '#'. 418 func (s *Scanner) scanCommentTypeLine(prefixLen int) (int, string) { 419 start := s.Pos - prefixLen 420 for s.cur() != eofChar { 421 if s.cur() == '\n' { 422 s.inc() 423 break 424 } 425 s.inc() 426 } 427 return COMMENT, s.buf[start:s.Pos] 428 } 429 430 // ? 431 // scanBindVar scans a bind variable; assumes a ':' has been scanned right before 432 func (s *Scanner) scanBindVar() (int, string) { 433 start := s.Pos 434 token := VALUE_ARG 435 436 s.inc() 437 if s.cur() == ':' { 438 token = LIST_ARG 439 s.inc() 440 } 441 if !isLetter(s.cur()) { 442 return LEX_ERROR, s.buf[start:s.Pos] 443 } 444 for { 445 ch := s.cur() 446 if !isLetter(ch) && !isDigit(ch) && ch != '.' { 447 break 448 } 449 s.inc() 450 } 451 return token, s.buf[start:s.Pos] 452 } 453 454 // scanNumber scans any SQL numeric literal, either floating point or integer 455 func (s *Scanner) scanNumber() (int, string) { 456 start := s.Pos 457 token := INTEGRAL 458 459 if s.cur() == '.' { 460 token = FLOAT 461 s.inc() 462 s.scanMantissa(10) 463 goto exponent 464 } 465 466 // 0x construct. 467 if s.cur() == '0' { 468 s.inc() 469 if s.cur() == 'x' || s.cur() == 'X' { 470 token = HEXNUM 471 s.inc() 472 s.scanMantissa(16) 473 goto exit 474 } else if s.cur() == 'b' || s.cur() == 'B' { 475 token = BIT_LITERAL 476 s.inc() 477 s.scanMantissa(2) 478 goto exit 479 } 480 } 481 482 s.scanMantissa(10) 483 484 if s.cur() == '.' { 485 token = FLOAT 486 s.inc() 487 s.scanMantissa(10) 488 } 489 490 exponent: 491 if s.cur() == 'e' || s.cur() == 'E' { 492 if s.peek(1) == '+' || s.peek(1) == '-' { 493 token = FLOAT 494 s.incN(2) 495 } else if digitVal(s.peek(1)) < 10 { 496 token = FLOAT 497 s.inc() 498 } else { 499 goto exit 500 } 501 s.scanMantissa(10) 502 } 503 504 exit: 505 if isLetter(s.cur()) { 506 // TODO: optimize 507 token = ID 508 s.scanIdentifier(false) 509 } 510 511 return token, strings.ToLower(s.buf[start:s.Pos]) 512 } 513 514 func (s *Scanner) scanIdentifier(isVariable bool) (int, string) { 515 start := s.Pos 516 s.inc() 517 518 for { 519 ch := s.cur() 520 if !isLetter(ch) && !isDigit(ch) && ch != '@' && !(isVariable && isCarat(ch)) { 521 break 522 } 523 if ch == '@' { 524 break 525 } 526 s.inc() 527 } 528 keywordName := s.buf[start:s.Pos] 529 lower := strings.ToLower(keywordName) 530 if keywordID, found := keywords[lower]; found { 531 return keywordID, lower 532 } 533 // dual must always be case-insensitive 534 if lower == "dual" { 535 return ID, lower 536 } 537 return ID, lower 538 } 539 540 func (s *Scanner) scanBitLiteral() (int, string) { 541 start := s.Pos 542 s.scanMantissa(2) 543 bit := s.buf[start:s.Pos] 544 if s.cur() != '\'' { 545 return LEX_ERROR, bit 546 } 547 s.inc() 548 return BIT_LITERAL, bit 549 } 550 551 func (s *Scanner) scanHex() (int, string) { 552 start := s.Pos 553 s.scanMantissa(16) 554 hex := s.buf[start:s.Pos] 555 if s.cur() != '\'' { 556 return LEX_ERROR, hex 557 } 558 s.inc() 559 if len(hex)%2 != 0 { 560 return LEX_ERROR, hex 561 } 562 return HEXNUM, hex 563 } 564 565 func (s *Scanner) scanMantissa(base int) { 566 for digitVal(s.cur()) < base { 567 s.inc() 568 } 569 } 570 571 // PositionedErr holds context related to parser errros 572 type PositionedErr struct { 573 Err string 574 Line int 575 Col int 576 Near string 577 LenStr string 578 } 579 580 func (p PositionedErr) Error() string { 581 return fmt.Sprintf("%s at line %d column %d near \"%s\"%s;", p.Err, p.Line+1, p.Col, p.Near, p.LenStr) 582 } 583 584 func (s *Scanner) skipBlank() { 585 ch := s.cur() 586 for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' { 587 s.inc() 588 ch = s.cur() 589 } 590 } 591 592 func (s *Scanner) cur() uint16 { 593 return s.peek(0) 594 } 595 596 func (s *Scanner) inc() { 597 if s.Pos >= len(s.buf) { 598 return 599 } 600 if s.buf[s.Pos] == '\n' { 601 s.Line++ 602 s.Col = 0 603 } 604 s.Pos++ 605 s.Col++ 606 } 607 608 func (s *Scanner) incN(dist int) { 609 for i := 0; i < dist; i++ { 610 s.inc() 611 } 612 } 613 614 func (s *Scanner) peek(dist int) uint16 { 615 if s.Pos+dist >= len(s.buf) { 616 return eofChar 617 } 618 return uint16(s.buf[s.Pos+dist]) 619 } 620 621 func isLetter(ch uint16) bool { 622 return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$' 623 } 624 625 func isCarat(ch uint16) bool { 626 return ch == '.' || ch == '"' || ch == '`' || ch == '\'' 627 } 628 629 func digitVal(ch uint16) int { 630 switch { 631 case '0' <= ch && ch <= '9': 632 return int(ch) - '0' 633 case 'a' <= ch && ch <= 'f': 634 return int(ch) - 'a' + 10 635 case 'A' <= ch && ch <= 'F': 636 return int(ch) - 'A' + 10 637 } 638 return 16 // larger than any legal digit val 639 } 640 641 func isDigit(ch uint16) bool { 642 return '0' <= ch && ch <= '9' 643 }