github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/parsers/dialect/mysql/scanner.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"strings"
    21  	"unicode"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
    24  )
    25  
    26  const eofChar = 0x100
    27  
    28  type Scanner struct {
    29  	LastToken           string
    30  	LastError           error
    31  	posVarIndex         int
    32  	dialectType         dialect.DialectType
    33  	MysqlSpecialComment *Scanner
    34  
    35  	Pos    int
    36  	Line   int
    37  	Col    int
    38  	PrePos int
    39  	buf    string
    40  }
    41  
    42  func NewScanner(dialectType dialect.DialectType, sql string) *Scanner {
    43  
    44  	return &Scanner{
    45  		buf: sql,
    46  	}
    47  }
    48  
    49  func (s *Scanner) Scan() (int, string) {
    50  	if s.MysqlSpecialComment != nil {
    51  		msc := s.MysqlSpecialComment
    52  		tok, val := msc.Scan()
    53  		if tok != 0 {
    54  			return tok, val
    55  		}
    56  		s.MysqlSpecialComment = nil
    57  	}
    58  	s.PrePos = s.Pos
    59  	s.skipBlank()
    60  	switch ch := s.cur(); {
    61  	case ch == '@':
    62  		tokenID := AT_ID
    63  		s.inc()
    64  		s.skipBlank()
    65  		if s.cur() == '@' {
    66  			tokenID = AT_AT_ID
    67  			s.inc()
    68  		} else if s.cur() == '\'' || s.cur() == '"' {
    69  			return int('@'), ""
    70  		} else if s.cur() == ',' {
    71  			return tokenID, ""
    72  		}
    73  		var tID int
    74  		var tBytes string
    75  		if s.cur() == '`' {
    76  			s.inc()
    77  			tID, tBytes = s.scanLiteralIdentifier()
    78  		} else if s.cur() == eofChar {
    79  			return LEX_ERROR, ""
    80  		} else {
    81  			tID, tBytes = s.scanIdentifier(true)
    82  		}
    83  		if tID == LEX_ERROR {
    84  			return tID, ""
    85  		}
    86  		return tokenID, tBytes
    87  	case isLetter(ch):
    88  		if ch == 'X' || ch == 'x' {
    89  			if s.peek(1) == '\'' {
    90  				s.incN(2)
    91  				return s.scanHex()
    92  			}
    93  		}
    94  		if ch == 'B' || ch == 'b' {
    95  			if s.peek(1) == '\'' {
    96  				s.incN(2)
    97  				return s.scanBitLiteral()
    98  			}
    99  		}
   100  		return s.scanIdentifier(false)
   101  	case isDigit(ch):
   102  		return s.scanNumber()
   103  	case ch == ':':
   104  		if s.peek(1) == '=' {
   105  			s.incN(2)
   106  			return ASSIGNMENT, ""
   107  		}
   108  		// Like mysql -h ::1 ?
   109  		return s.scanBindVar()
   110  	case ch == ';':
   111  		s.inc()
   112  		return ';', ""
   113  	case ch == '.' && isDigit(s.peek(1)):
   114  		return s.scanNumber()
   115  	case ch == '/':
   116  		s.inc()
   117  		switch s.cur() {
   118  		case '/':
   119  			s.inc()
   120  			id, str := s.scanCommentTypeLine(2)
   121  			if id == LEX_ERROR {
   122  				return id, str
   123  			}
   124  			return s.Scan()
   125  		case '*':
   126  			s.inc()
   127  			switch {
   128  			case s.cur() == '!' && s.dialectType == dialect.MYSQL:
   129  				// TODO: ExtractMysqlComment
   130  				return s.scanMySQLSpecificComment()
   131  			default:
   132  				id, str := s.scanCommentTypeBlock()
   133  				if id == LEX_ERROR {
   134  					return id, str
   135  				}
   136  				return s.Scan()
   137  			}
   138  		default:
   139  			return int(ch), ""
   140  		}
   141  	default:
   142  		return s.stepBackOneChar(ch)
   143  	}
   144  }
   145  
   146  func (s *Scanner) stepBackOneChar(ch uint16) (int, string) {
   147  	s.inc()
   148  	switch ch {
   149  	case eofChar:
   150  		return 0, ""
   151  	case '=', ',', '(', ')', '+', '*', '%', '^', '~', '{', '}':
   152  		return int(ch), ""
   153  	case '&':
   154  		if s.cur() == '&' {
   155  			s.inc()
   156  			return AND, ""
   157  		}
   158  		return int(ch), ""
   159  	case '|':
   160  		if s.cur() == '|' {
   161  			s.inc()
   162  			return PIPE_CONCAT, ""
   163  		}
   164  		return int(ch), ""
   165  	case '?':
   166  		// mysql's situation
   167  		s.posVarIndex++
   168  		buf := make([]byte, 0, 8)
   169  		buf = append(buf, ":v"...)
   170  		buf = strconv.AppendInt(buf, int64(s.posVarIndex), 10)
   171  		return VALUE_ARG, string(buf)
   172  	case '.':
   173  		return int(ch), ""
   174  	case '#':
   175  		return s.scanCommentTypeLine(1)
   176  	case '-':
   177  		switch s.cur() {
   178  		case '-':
   179  			nextChar := s.peek(1)
   180  			if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar {
   181  				s.inc()
   182  				id, str := s.scanCommentTypeLine(2)
   183  				if id == LEX_ERROR {
   184  					return id, str
   185  				}
   186  				return s.Scan()
   187  			}
   188  		case '>':
   189  			s.inc()
   190  			return ARROW, ""
   191  		}
   192  		return int(ch), ""
   193  	case '<':
   194  		switch s.cur() {
   195  		case '>':
   196  			s.inc()
   197  			return NE, ""
   198  		case '<':
   199  			s.inc()
   200  			return SHIFT_LEFT, ""
   201  		case '=':
   202  			s.inc()
   203  			switch s.cur() {
   204  			case '>':
   205  				s.inc()
   206  				return NULL_SAFE_EQUAL, ""
   207  			default:
   208  				return LE, ""
   209  			}
   210  		default:
   211  			return int(ch), ""
   212  		}
   213  	case '>':
   214  		switch s.cur() {
   215  		case '=':
   216  			s.inc()
   217  			return GE, ""
   218  		case '>':
   219  			s.inc()
   220  			return SHIFT_RIGHT, ""
   221  		default:
   222  			return int(ch), ""
   223  		}
   224  	case '!':
   225  		if s.cur() == '=' {
   226  			s.inc()
   227  			return NE, ""
   228  		}
   229  		return int(ch), ""
   230  	case '\'', '"':
   231  		return s.scanString(ch, STRING)
   232  	case '`':
   233  		return s.scanLiteralIdentifier()
   234  	default:
   235  		return LEX_ERROR, string(byte(ch))
   236  	}
   237  }
   238  
   239  func (s *Scanner) scanString(delim uint16, typ int) (int, string) {
   240  	ch := s.cur()
   241  	buf := new(strings.Builder)
   242  	for s.Pos < len(s.buf) {
   243  		if ch == delim {
   244  			s.inc()
   245  			if s.cur() != delim {
   246  				return typ, buf.String()
   247  			}
   248  		} else if ch == '\\' {
   249  			ch = handleEscape(s, buf)
   250  			if ch == eofChar {
   251  				break
   252  			}
   253  		}
   254  		buf.WriteByte(byte(ch))
   255  		if s.Pos < len(s.buf) {
   256  			s.inc()
   257  			ch = s.cur()
   258  		}
   259  	}
   260  	return LEX_ERROR, buf.String()
   261  }
   262  
   263  func handleEscape(s *Scanner, buf *strings.Builder) uint16 {
   264  	s.inc()
   265  	ch0 := s.cur()
   266  	switch ch0 {
   267  	case 'n':
   268  		ch0 = '\n'
   269  	case '0':
   270  		ch0 = 0
   271  	case 'b':
   272  		ch0 = 8
   273  	case 'Z':
   274  		ch0 = 26
   275  	case 'r':
   276  		ch0 = '\r'
   277  	case 't':
   278  		ch0 = '\t'
   279  	case '%', '_':
   280  		buf.WriteByte('\\')
   281  	}
   282  	return ch0
   283  }
   284  
   285  // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier
   286  // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier
   287  // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow
   288  func (s *Scanner) scanLiteralIdentifier() (int, string) {
   289  	start := s.Pos
   290  	for {
   291  		switch s.cur() {
   292  		case '`':
   293  			if s.peek(1) != '`' {
   294  				if s.Pos == start {
   295  					return LEX_ERROR, ""
   296  				}
   297  				s.inc()
   298  				return ID, s.buf[start : s.Pos-1]
   299  			}
   300  
   301  			var buf strings.Builder
   302  			buf.WriteString(s.buf[start:s.Pos])
   303  			s.inc()
   304  			return s.scanLiteralIdentifierSlow(&buf)
   305  		case eofChar:
   306  			// Premature EOF.
   307  			return LEX_ERROR, s.buf[start:s.Pos]
   308  		default:
   309  			s.inc()
   310  		}
   311  	}
   312  }
   313  
   314  // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may
   315  // contain escape sequences instead of it. This method is only called from
   316  // scanLiteralIdentifier once the first escape sequence is found in the identifier.
   317  // The provided `buf` contains the contents of the identifier that have been scanned
   318  // so far.
   319  func (s *Scanner) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) {
   320  	backTickSeen := true
   321  	for {
   322  		if backTickSeen {
   323  			if s.cur() != '`' {
   324  				break
   325  			}
   326  			backTickSeen = false
   327  			buf.WriteByte('`')
   328  			s.inc()
   329  			continue
   330  		}
   331  		// The previous char was not a backtick.
   332  		switch s.cur() {
   333  		case '`':
   334  			backTickSeen = true
   335  		case eofChar:
   336  			// Premature EOF.
   337  			return LEX_ERROR, buf.String()
   338  		default:
   339  			buf.WriteByte(byte(s.cur()))
   340  			// keep scanning
   341  		}
   342  		s.inc()
   343  	}
   344  	return ID, buf.String()
   345  }
   346  
   347  // scanCommentTypeBlock scans a '/*' delimited comment;
   348  // assumes the opening prefix has already been scanned
   349  func (s *Scanner) scanCommentTypeBlock() (int, string) {
   350  	start := s.Pos - 2
   351  	for {
   352  		if s.cur() == '*' {
   353  			s.inc()
   354  			if s.cur() == '/' {
   355  				s.inc()
   356  				break
   357  			}
   358  			continue
   359  		}
   360  		if s.cur() == eofChar {
   361  			return LEX_ERROR, s.buf[start:s.Pos]
   362  		}
   363  		s.inc()
   364  	}
   365  	return COMMENT, s.buf[start:s.Pos]
   366  }
   367  
   368  // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*`
   369  func (s *Scanner) scanMySQLSpecificComment() (int, string) {
   370  	start := s.Pos - 3
   371  	for {
   372  		if s.cur() == '*' {
   373  			s.inc()
   374  			if s.cur() == '/' {
   375  				s.inc()
   376  				break
   377  			}
   378  			continue
   379  		}
   380  		if s.cur() == eofChar {
   381  			return LEX_ERROR, s.buf[start:s.Pos]
   382  		}
   383  		s.inc()
   384  	}
   385  
   386  	_, sql := ExtractMysqlComment(s.buf[start:s.Pos])
   387  
   388  	s.MysqlSpecialComment = NewScanner(s.dialectType, sql)
   389  
   390  	return s.Scan()
   391  }
   392  
   393  // ExtractMysqlComment extracts the version and SQL from a comment-only query
   394  // such as /*!50708 sql here */
   395  func ExtractMysqlComment(sql string) (string, string) {
   396  	sql = sql[3 : len(sql)-2]
   397  
   398  	digitCount := 0
   399  	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
   400  		digitCount++
   401  		return !unicode.IsDigit(c) || digitCount == 6
   402  	})
   403  	if endOfVersionIndex < 0 {
   404  		return "", ""
   405  	}
   406  	if endOfVersionIndex < 5 {
   407  		endOfVersionIndex = 0
   408  	}
   409  	version := sql[0:endOfVersionIndex]
   410  	innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
   411  
   412  	return version, innerSQL
   413  }
   414  
   415  // scanCommentTypeLine scans a SQL line-comment, which is applied until the end
   416  // of the line. The given prefix length varies based on whether the comment
   417  // is started with '//', '--' or '#'.
   418  func (s *Scanner) scanCommentTypeLine(prefixLen int) (int, string) {
   419  	start := s.Pos - prefixLen
   420  	for s.cur() != eofChar {
   421  		if s.cur() == '\n' {
   422  			s.inc()
   423  			break
   424  		}
   425  		s.inc()
   426  	}
   427  	return COMMENT, s.buf[start:s.Pos]
   428  }
   429  
   430  // ?
   431  // scanBindVar scans a bind variable; assumes a ':' has been scanned right before
   432  func (s *Scanner) scanBindVar() (int, string) {
   433  	start := s.Pos
   434  	token := VALUE_ARG
   435  
   436  	s.inc()
   437  	if s.cur() == ':' {
   438  		token = LIST_ARG
   439  		s.inc()
   440  	}
   441  	if !isLetter(s.cur()) {
   442  		return LEX_ERROR, s.buf[start:s.Pos]
   443  	}
   444  	for {
   445  		ch := s.cur()
   446  		if !isLetter(ch) && !isDigit(ch) && ch != '.' {
   447  			break
   448  		}
   449  		s.inc()
   450  	}
   451  	return token, s.buf[start:s.Pos]
   452  }
   453  
   454  // scanNumber scans any SQL numeric literal, either floating point or integer
   455  func (s *Scanner) scanNumber() (int, string) {
   456  	start := s.Pos
   457  	token := INTEGRAL
   458  
   459  	if s.cur() == '.' {
   460  		token = FLOAT
   461  		s.inc()
   462  		s.scanMantissa(10)
   463  		goto exponent
   464  	}
   465  
   466  	// 0x construct.
   467  	if s.cur() == '0' {
   468  		s.inc()
   469  		if s.cur() == 'x' || s.cur() == 'X' {
   470  			token = HEXNUM
   471  			s.inc()
   472  			s.scanMantissa(16)
   473  			goto exit
   474  		} else if s.cur() == 'b' || s.cur() == 'B' {
   475  			token = BIT_LITERAL
   476  			s.inc()
   477  			s.scanMantissa(2)
   478  			goto exit
   479  		}
   480  	}
   481  
   482  	s.scanMantissa(10)
   483  
   484  	if s.cur() == '.' {
   485  		token = FLOAT
   486  		s.inc()
   487  		s.scanMantissa(10)
   488  	}
   489  
   490  exponent:
   491  	if s.cur() == 'e' || s.cur() == 'E' {
   492  		if s.peek(1) == '+' || s.peek(1) == '-' {
   493  			token = FLOAT
   494  			s.incN(2)
   495  		} else if digitVal(s.peek(1)) < 10 {
   496  			token = FLOAT
   497  			s.inc()
   498  		} else {
   499  			goto exit
   500  		}
   501  		s.scanMantissa(10)
   502  	}
   503  
   504  exit:
   505  	if isLetter(s.cur()) {
   506  		// TODO: optimize
   507  		token = ID
   508  		s.scanIdentifier(false)
   509  	}
   510  
   511  	return token, strings.ToLower(s.buf[start:s.Pos])
   512  }
   513  
   514  func (s *Scanner) scanIdentifier(isVariable bool) (int, string) {
   515  	start := s.Pos
   516  	s.inc()
   517  
   518  	for {
   519  		ch := s.cur()
   520  		if !isLetter(ch) && !isDigit(ch) && ch != '@' && !(isVariable && isCarat(ch)) {
   521  			break
   522  		}
   523  		if ch == '@' {
   524  			break
   525  		}
   526  		s.inc()
   527  	}
   528  	keywordName := s.buf[start:s.Pos]
   529  	lower := strings.ToLower(keywordName)
   530  	if keywordID, found := keywords[lower]; found {
   531  		return keywordID, lower
   532  	}
   533  	// dual must always be case-insensitive
   534  	if lower == "dual" {
   535  		return ID, lower
   536  	}
   537  	return ID, lower
   538  }
   539  
   540  func (s *Scanner) scanBitLiteral() (int, string) {
   541  	start := s.Pos
   542  	s.scanMantissa(2)
   543  	bit := s.buf[start:s.Pos]
   544  	if s.cur() != '\'' {
   545  		return LEX_ERROR, bit
   546  	}
   547  	s.inc()
   548  	return BIT_LITERAL, bit
   549  }
   550  
   551  func (s *Scanner) scanHex() (int, string) {
   552  	start := s.Pos
   553  	s.scanMantissa(16)
   554  	hex := s.buf[start:s.Pos]
   555  	if s.cur() != '\'' {
   556  		return LEX_ERROR, hex
   557  	}
   558  	s.inc()
   559  	if len(hex)%2 != 0 {
   560  		return LEX_ERROR, hex
   561  	}
   562  	return HEXNUM, hex
   563  }
   564  
   565  func (s *Scanner) scanMantissa(base int) {
   566  	for digitVal(s.cur()) < base {
   567  		s.inc()
   568  	}
   569  }
   570  
   571  // PositionedErr holds context related to parser errros
   572  type PositionedErr struct {
   573  	Err    string
   574  	Line   int
   575  	Col    int
   576  	Near   string
   577  	LenStr string
   578  }
   579  
   580  func (p PositionedErr) Error() string {
   581  	return fmt.Sprintf("%s at line %d column %d near \"%s\"%s;", p.Err, p.Line+1, p.Col, p.Near, p.LenStr)
   582  }
   583  
   584  func (s *Scanner) skipBlank() {
   585  	ch := s.cur()
   586  	for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
   587  		s.inc()
   588  		ch = s.cur()
   589  	}
   590  }
   591  
   592  func (s *Scanner) cur() uint16 {
   593  	return s.peek(0)
   594  }
   595  
   596  func (s *Scanner) inc() {
   597  	if s.Pos >= len(s.buf) {
   598  		return
   599  	}
   600  	if s.buf[s.Pos] == '\n' {
   601  		s.Line++
   602  		s.Col = 0
   603  	}
   604  	s.Pos++
   605  	s.Col++
   606  }
   607  
   608  func (s *Scanner) incN(dist int) {
   609  	for i := 0; i < dist; i++ {
   610  		s.inc()
   611  	}
   612  }
   613  
   614  func (s *Scanner) peek(dist int) uint16 {
   615  	if s.Pos+dist >= len(s.buf) {
   616  		return eofChar
   617  	}
   618  	return uint16(s.buf[s.Pos+dist])
   619  }
   620  
   621  func isLetter(ch uint16) bool {
   622  	return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$'
   623  }
   624  
   625  func isCarat(ch uint16) bool {
   626  	return ch == '.' || ch == '"' || ch == '`' || ch == '\''
   627  }
   628  
   629  func digitVal(ch uint16) int {
   630  	switch {
   631  	case '0' <= ch && ch <= '9':
   632  		return int(ch) - '0'
   633  	case 'a' <= ch && ch <= 'f':
   634  		return int(ch) - 'a' + 10
   635  	case 'A' <= ch && ch <= 'F':
   636  		return int(ch) - 'A' + 10
   637  	}
   638  	return 16 // larger than any legal digit val
   639  }
   640  
   641  func isDigit(ch uint16) bool {
   642  	return '0' <= ch && ch <= '9'
   643  }