
     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package postgresql
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"strings"
    21  	"unicode"
    23  	""
    24  )
    26  const eofChar = 0x100
// Scanner is a hand-written SQL lexer: each call to Scan returns the next
// token found in buf, starting at byte offset Pos.
//
// Fields:
//   - LastToken, LastError: not read or written by the code in this file;
//     presumably maintained by the surrounding lexer/parser (TODO confirm).
//   - posVarIndex: number of '?' placeholders seen so far; used to name them
//     :v1, :v2, ... in stepBackOneChar.
//   - dialectType: compared against dialect.MYSQL in Scan to decide whether
//     "/*!" opens a MySQL-specific comment.
//   - MysqlSpecialComment: when non-nil, a nested scanner over the SQL found
//     inside a MySQL-specific comment; Scan returns its tokens first.
//   - Pos: byte offset of the next unread character in buf. It may move past
//     len(buf) (stepBackOneChar consumes even at end of input); cur and peek
//     report any such position as eofChar.
//   - buf: the SQL text being scanned.
type Scanner struct {
	LastToken           string
	LastError           error
	posVarIndex         int
	dialectType         dialect.DialectType
	MysqlSpecialComment *Scanner

	Pos int
	buf string
}
    39  func NewScanner(dialectType dialect.DialectType, sql string) *Scanner {
    41  	return &Scanner{
    42  		buf: sql,
    43  	}
    44  }
// Scan returns the next token as a (token ID, value) pair.
//
// Tokens from a pending MySQL-specific comment (MysqlSpecialComment) are
// returned first. Otherwise leading blanks are skipped and the first
// character selects how the token is scanned:
//   - '@' / "@@": a user or system variable (AT_ID / AT_AT_ID);
//   - a letter: x'..' hex and b'..' bit literals, otherwise an identifier or
//     keyword;
//   - a digit, or '.' followed by a digit: a number;
//   - ':': ":=" (ASSIGNMENT) or a bind variable;
//   - ';': returned as itself;
//   - '/': "//" line comments and "/* */" block comments are skipped by
//     recursing into Scan (a LEX_ERROR from the comment scanner is returned
//     instead); "/*!" is handed to scanMySQLSpecificComment when dialectType
//     is dialect.MYSQL; a lone '/' is returned as itself;
//   - anything else goes to stepBackOneChar, which reports end of input as
//     token 0.
func (s *Scanner) Scan() (int, string) {
	// Drain the nested scanner created for a MySQL-specific comment first;
	// token 0 means it is exhausted and scanning resumes on the outer buffer.
	if s.MysqlSpecialComment != nil {
		msc := s.MysqlSpecialComment
		tok, val := msc.Scan()
		if tok != 0 {
			return tok, val
		}
		s.MysqlSpecialComment = nil
	}

	s.skipBlank()
	switch ch := s.cur(); {
	case ch == '@':
		// Variables: "@name" yields AT_ID and "@@name" yields AT_AT_ID.
		// Blanks are skipped after the first '@'.
		tokenID := AT_ID
		s.skip(1)
		s.skipBlank()
		if s.cur() == '@' {
			tokenID = AT_AT_ID
			s.skip(1)
		} else if s.cur() == '\'' || s.cur() == '"' {
			// A bare '@' token; the quoted text (not consumed here) is
			// scanned by the next call.
			return int('@'), ""
		} else if s.cur() == ',' {
			// "@," yields AT_ID with an empty name.
			return tokenID, ""
		}
		var tID int
		var tBytes string
		if s.cur() == '`' {
			// Backtick-quoted variable name.
			s.skip(1)
			tID, tBytes = s.scanLiteralIdentifier()
		} else if s.cur() == eofChar {
			return LEX_ERROR, ""
		} else {
			tID, tBytes = s.scanIdentifier(true)
		}
		if tID == LEX_ERROR {
			return tID, ""
		}
		// The name's own token ID (e.g. a keyword) is discarded; only the
		// variable token and the name text are returned.
		return tokenID, tBytes
	case isLetter(ch):
		// x'...' and b'...' literals start with a letter, so they are
		// recognized before falling back to an identifier.
		if ch == 'X' || ch == 'x' {
			if s.peek(1) == '\'' {
				s.skip(2)
				return s.scanHex()
			}
		}
		if ch == 'B' || ch == 'b' {
			if s.peek(1) == '\'' {
				s.skip(2)
				return s.scanBitLiteral()
			}
		}
		return s.scanIdentifier(false)
	case isDigit(ch):
		return s.scanNumber()
	case ch == ':':
		if s.peek(1) == '=' {
			s.skip(2)
			return ASSIGNMENT, ""
		}
		// Otherwise ':' starts a bind variable; Pos is still on the ':'.
		return s.scanBindVar()
	case ch == ';':
		s.skip(1)
		return ';', ""
	case ch == '.' && isDigit(s.peek(1)):
		return s.scanNumber()
	case ch == '/':
		s.skip(1)
		switch s.cur() {
		case '/':
			s.skip(1)
			id, str := s.scanCommentTypeLine(2)
			if id == LEX_ERROR {
				return id, str
			}
			// Comments are not returned as tokens; continue with the next one.
			return s.Scan()
		case '*':
			s.skip(1)
			// Pos is now just past "/*".
			switch {
			case s.cur() == '!' && s.dialectType == dialect.MYSQL:
				// TODO: ExtractMysqlComment
				// Pos is on the '!' when scanMySQLSpecificComment is called.
				return s.scanMySQLSpecificComment()
			default:
				id, str := s.scanCommentTypeBlock()
				if id == LEX_ERROR {
					return id, str
				}
				// Comments are not returned as tokens; continue with the next one.
				return s.Scan()
			}
		default:
			// A lone '/' (division); it has already been consumed.
			return int(ch), ""
		}
	default:
		return s.stepBackOneChar(ch)
	}
}
   143  func (s *Scanner) stepBackOneChar(ch uint16) (int, string) {
   144  	s.skip(1)
   145  	switch ch {
   146  	case eofChar:
   147  		return 0, ""
   148  	case '=', ',', '(', ')', '+', '*', '%', '^', '~':
   149  		return int(ch), ""
   150  	case '&':
   151  		if s.cur() == '&' {
   152  			s.skip(1)
   153  			return AND, ""
   154  		}
   155  		return int(ch), ""
   156  	case '|':
   157  		if s.cur() == '|' {
   158  			s.skip(1)
   159  			return PIPE_CONCAT, ""
   160  		}
   161  		return int(ch), ""
   162  	case '?':
   163  		// mysql's situation
   164  		s.posVarIndex++
   165  		buf := make([]byte, 0, 8)
   166  		buf = append(buf, ":v"...)
   167  		buf = strconv.AppendInt(buf, int64(s.posVarIndex), 10)
   168  		return VALUE_ARG, string(buf)
   169  	case '.':
   170  		return int(ch), ""
   171  	case '#':
   172  		return s.scanCommentTypeLine(1)
   173  	case '-':
   174  		switch s.cur() {
   175  		case '-':
   176  			nextChar := s.peek(1)
   177  			if nextChar == ' ' || nextChar == '\n' || nextChar == '\t' || nextChar == '\r' || nextChar == eofChar {
   178  				s.skip(1)
   179  				return s.scanCommentTypeLine(2)
   180  			}
   181  		case '>':
   182  			s.skip(1)
   183  			// TODO:
   185  			// JSON_EXTRACT_OP
   186  			return 0, ""
   187  		}
   188  		return int(ch), ""
   189  	case '<':
   190  		switch s.cur() {
   191  		case '>':
   192  			s.skip(1)
   193  			return NE, ""
   194  		case '<':
   195  			s.skip(1)
   196  			return SHIFT_LEFT, ""
   197  		case '=':
   198  			s.skip(1)
   199  			switch s.cur() {
   200  			case '>':
   201  				s.skip(1)
   202  				return NULL_SAFE_EQUAL, ""
   203  			default:
   204  				return LE, ""
   205  			}
   206  		default:
   207  			return int(ch), ""
   208  		}
   209  	case '>':
   210  		switch s.cur() {
   211  		case '=':
   212  			s.skip(1)
   213  			return GE, ""
   214  		case '>':
   215  			s.skip(1)
   216  			return SHIFT_RIGHT, ""
   217  		default:
   218  			return int(ch), ""
   219  		}
   220  	case '!':
   221  		if s.cur() == '=' {
   222  			s.skip(1)
   223  			return NE, ""
   224  		}
   225  		return int(ch), ""
   226  	case '\'', '"':
   227  		return s.scanString(ch, STRING)
   228  	case '`':
   229  		return s.scanLiteralIdentifier()
   230  	default:
   231  		return LEX_ERROR, string(byte(ch))
   232  	}
   233  }
   235  // scanString scans a string surrounded by the given `delim`, which can be
   236  // either single or double quotes. Assumes that the given delimiter has just
   237  // been scanned. If the skin contains any escape sequences, this function
   238  // will fall back to scanStringSlow
   239  func (s *Scanner) scanString(delim uint16, typ int) (int, string) {
   240  	start := s.Pos
   242  	for {
   243  		switch s.cur() {
   244  		case delim:
   245  			if s.peek(1) != delim {
   246  				s.skip(1)
   247  				return typ, s.buf[start : s.Pos-1]
   248  			}
   249  			fallthrough
   251  		case '\\':
   252  			var buffer strings.Builder
   253  			buffer.WriteString(s.buf[start:s.Pos])
   254  			return s.scanStringSlow(&buffer, delim, typ)
   256  		case eofChar:
   257  			return LEX_ERROR, s.buf[start:s.Pos]
   258  		}
   260  		s.skip(1)
   261  	}
   262  }
// scanStringSlow continues scanning a string literal delimited by delim after
// scanString has found a backslash or a doubled delimiter. buffer holds the
// literal's contents read so far, and Pos is on that backslash or delimiter.
//
// Decoding rules:
//   - "\n", "\0", "\b", "\Z", "\r", "\t" become newline, NUL, backspace (8),
//     Ctrl-Z (26), carriage return and tab;
//   - "\\" and a backslash before delim yield the escaped character;
//   - "\%" and "\_" are kept together with their backslash;
//   - any other escaped character is kept and the backslash is dropped;
//   - two consecutive delimiters yield a single delimiter.
//
// It returns typ and the decoded contents. If input ends before the closing
// delimiter (including in the middle of an escape), it returns LEX_ERROR and
// the contents decoded so far.
func (s *Scanner) scanStringSlow(buffer *strings.Builder, delim uint16, typ int) (int, string) {
	for {
		ch := s.cur()
		if ch == eofChar {
			// Unterminated string.
			return LEX_ERROR, buffer.String()
		}

		if ch != delim && ch != '\\' {
			// Copy a run of ordinary characters in one go, stopping on the
			// next delimiter or backslash (left in ch, with Pos on it).
			start := s.Pos
			for ; s.Pos < len(s.buf); s.Pos++ {
				ch = uint16(s.buf[s.Pos])
				if ch == delim || ch == '\\' {
					break
				}
			}

			buffer.WriteString(s.buf[start:s.Pos])
			if s.Pos >= len(s.buf) {
				// Ran off the end: step past the input so the next
				// iteration reports the unterminated string.
				s.skip(1)
				continue
			}
		}
		// Consume the backslash or delimiter held in ch.
		s.skip(1)

		if ch == '\\' {
			ch = s.cur()
			switch ch {
			case eofChar:
				// Input ends in the middle of an escape sequence.
				return LEX_ERROR, buffer.String()
			case 'n':
				ch = '\n'
			case '0':
				ch = '\x00'
			case 'b':
				ch = 8
			case 'Z':
				ch = 26
			case 'r':
				ch = '\r'
			case 't':
				ch = '\t'
			case '%', '_':
				// Keep the backslash; the '%' or '_' itself is copied as
				// an ordinary character on the next iteration.
				buffer.WriteByte(byte('\\'))
				continue
			case '\\', delim:
				// Escaped backslash or delimiter: emitted as-is below.
			default:
				// Unknown escape: drop the backslash; the character is
				// copied as an ordinary character on the next iteration.
				continue
			}
		} else if ch == delim && s.cur() != delim {
			// A single delimiter closes the string. A doubled one falls
			// through and is emitted once.
			break
		}
		// Emit the decoded character and step past it (the escaped
		// character, or the second delimiter of a doubled pair).
		buffer.WriteByte(byte(ch))
		s.skip(1)
	}

	return typ, buffer.String()
}
   326  // scanLiteralIdentifier scans an identifier enclosed by backticks. If the identifier
   327  // is a simple literal, it'll be returned as a slice of the input buffer. If the identifier
   328  // contains escape sequences, this function will fall back to scanLiteralIdentifierSlow
   329  func (s *Scanner) scanLiteralIdentifier() (int, string) {
   330  	start := s.Pos
   331  	for {
   332  		switch s.cur() {
   333  		case '`':
   334  			if s.peek(1) != '`' {
   335  				if s.Pos == start {
   336  					return LEX_ERROR, ""
   337  				}
   338  				s.skip(1)
   339  				return ID, s.buf[start : s.Pos-1]
   340  			}
   342  			var buf strings.Builder
   343  			buf.WriteString(s.buf[start:s.Pos])
   344  			s.skip(1)
   345  			return s.scanLiteralIdentifierSlow(&buf)
   346  		case eofChar:
   347  			// Premature EOF.
   348  			return LEX_ERROR, s.buf[start:s.Pos]
   349  		default:
   350  			s.skip(1)
   351  		}
   352  	}
   353  }
   355  // scanLiteralIdentifierSlow scans an identifier surrounded by backticks which may
   356  // contain escape sequences instead of it. This method is only called from
   357  // scanLiteralIdentifier once the first escape sequence is found in the identifier.
   358  // The provided `buf` contains the contents of the identifier that have been scanned
   359  // so far.
   360  func (s *Scanner) scanLiteralIdentifierSlow(buf *strings.Builder) (int, string) {
   361  	backTickSeen := true
   362  	for {
   363  		if backTickSeen {
   364  			if s.cur() != '`' {
   365  				break
   366  			}
   367  			backTickSeen = false
   368  			buf.WriteByte('`')
   369  			s.skip(1)
   370  			continue
   371  		}
   372  		// The previous char was not a backtick.
   373  		switch s.cur() {
   374  		case '`':
   375  			backTickSeen = true
   376  		case eofChar:
   377  			// Premature EOF.
   378  			return LEX_ERROR, buf.String()
   379  		default:
   380  			buf.WriteByte(byte(s.cur()))
   381  			// keep scanning
   382  		}
   383  		s.skip(1)
   384  	}
   385  	return ID, buf.String()
   386  }
   388  // scanCommentTypeBlock scans a '/*' delimited comment;
   389  // assumes the opening prefix has already been scanned
   390  func (s *Scanner) scanCommentTypeBlock() (int, string) {
   391  	start := s.Pos - 2
   392  	for {
   393  		if s.cur() == '*' {
   394  			s.skip(1)
   395  			if s.cur() == '/' {
   396  				s.skip(1)
   397  				break
   398  			}
   399  			continue
   400  		}
   401  		if s.cur() == eofChar {
   402  			return LEX_ERROR, s.buf[start:s.Pos]
   403  		}
   404  		s.skip(1)
   405  	}
   406  	return COMMENT, s.buf[start:s.Pos]
   407  }
   409  // scanMySQLSpecificComment scans a MySQL comment pragma, which always starts with '//*`
   410  func (s *Scanner) scanMySQLSpecificComment() (int, string) {
   411  	start := s.Pos - 3
   412  	for {
   413  		if s.cur() == '*' {
   414  			s.skip(1)
   415  			if s.cur() == '/' {
   416  				s.skip(1)
   417  				break
   418  			}
   419  			continue
   420  		}
   421  		if s.cur() == eofChar {
   422  			return LEX_ERROR, s.buf[start:s.Pos]
   423  		}
   424  		s.skip(1)
   425  	}
   427  	_, sql := ExtractMysqlComment(s.buf[start:s.Pos])
   429  	s.MysqlSpecialComment = NewScanner(s.dialectType, sql)
   431  	return s.Scan()
   432  }
   434  // ExtractMysqlComment extracts the version and SQL from a comment-only query
   435  // such as /*!50708 sql here */
   436  func ExtractMysqlComment(sql string) (string, string) {
   437  	sql = sql[3 : len(sql)-2]
   439  	digitCount := 0
   440  	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
   441  		digitCount++
   442  		return !unicode.IsDigit(c) || digitCount == 6
   443  	})
   444  	if endOfVersionIndex < 0 {
   445  		return "", ""
   446  	}
   447  	if endOfVersionIndex < 5 {
   448  		endOfVersionIndex = 0
   449  	}
   450  	version := sql[0:endOfVersionIndex]
   451  	innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
   453  	return version, innerSQL
   454  }
   456  // scanCommentTypeLine scans a SQL line-comment, which is applied until the end
   457  // of the line. The given prefix length varies based on whether the comment
   458  // is started with '//', '--' or '#'.
   459  func (s *Scanner) scanCommentTypeLine(prefixLen int) (int, string) {
   460  	start := s.Pos - prefixLen
   461  	for s.cur() != eofChar {
   462  		if s.cur() == '\n' {
   463  			s.skip(1)
   464  			break
   465  		}
   466  		s.skip(1)
   467  	}
   468  	return COMMENT, s.buf[start:s.Pos]
   469  }
   471  // ?
   472  // scanBindVar scans a bind variable; assumes a ':' has been scanned right before
   473  func (s *Scanner) scanBindVar() (int, string) {
   474  	start := s.Pos
   475  	token := VALUE_ARG
   477  	s.skip(1)
   478  	if s.cur() == ':' {
   479  		token = LIST_ARG
   480  		s.skip(1)
   481  	}
   482  	if !isLetter(s.cur()) {
   483  		return LEX_ERROR, s.buf[start:s.Pos]
   484  	}
   485  	for {
   486  		ch := s.cur()
   487  		if !isLetter(ch) && !isDigit(ch) && ch != '.' {
   488  			break
   489  		}
   490  		s.skip(1)
   491  	}
   492  	return token, s.buf[start:s.Pos]
   493  }
   495  // scanNumber scans any SQL numeric literal, either floating point or integer
   496  func (s *Scanner) scanNumber() (int, string) {
   497  	start := s.Pos
   498  	token := INTEGRAL
   500  	if s.cur() == '.' {
   501  		token = FLOAT
   502  		s.skip(1)
   503  		s.scanMantissa(10)
   504  		goto exponent
   505  	}
   507  	// 0x construct.
   508  	if s.cur() == '0' {
   509  		s.skip(1)
   510  		if s.cur() == 'x' || s.cur() == 'X' {
   511  			token = HEXNUM
   512  			s.skip(1)
   513  			s.scanMantissa(16)
   514  			goto exit
   515  		} else if s.cur() == 'b' || s.cur() == 'B' {
   516  			token = BIT_LITERAL
   517  			s.skip(1)
   518  			s.scanMantissa(2)
   519  			goto exit
   520  		}
   521  	}
   523  	s.scanMantissa(10)
   525  	if s.cur() == '.' {
   526  		token = FLOAT
   527  		s.skip(1)
   528  		s.scanMantissa(10)
   529  	}
   531  exponent:
   532  	if s.cur() == 'e' || s.cur() == 'E' {
   533  		if s.peek(1) == '+' || s.peek(1) == '-' {
   534  			token = FLOAT
   535  			s.skip(2)
   536  		} else if digitVal(s.peek(1)) < 10 {
   537  			token = FLOAT
   538  			s.skip(1)
   539  		} else {
   540  			goto exit
   541  		}
   542  		s.scanMantissa(10)
   543  	}
   545  exit:
   546  	if isLetter(s.cur()) {
   547  		// TODO: optimize
   548  		token = ID
   549  		s.scanIdentifier(false)
   550  	}
   552  	return token, strings.ToLower(s.buf[start:s.Pos])
   553  }
   555  func (s *Scanner) scanIdentifier(isVariable bool) (int, string) {
   556  	start := s.Pos
   557  	s.skip(1)
   559  	for {
   560  		ch := s.cur()
   561  		if !isLetter(ch) && !isDigit(ch) && ch != '@' && !(isVariable && isCarat(ch)) {
   562  			break
   563  		}
   564  		if ch == '@' {
   565  			isVariable = true
   566  		}
   567  		s.skip(1)
   568  	}
   569  	keywordName := s.buf[start:s.Pos]
   570  	lower := strings.ToLower(keywordName)
   571  	if keywordID, found := keywords[lower]; found {
   572  		return keywordID, lower
   573  	}
   574  	// dual must always be case-insensitive
   575  	if lower == "dual" {
   576  		return ID, lower
   577  	}
   578  	return ID, lower
   579  }
   581  func (s *Scanner) scanBitLiteral() (int, string) {
   582  	start := s.Pos
   583  	s.scanMantissa(2)
   584  	bit := s.buf[start:s.Pos]
   585  	if s.cur() != '\'' {
   586  		return LEX_ERROR, bit
   587  	}
   588  	s.skip(1)
   589  	return BIT_LITERAL, bit
   590  }
   592  func (s *Scanner) scanHex() (int, string) {
   593  	start := s.Pos
   594  	s.scanMantissa(16)
   595  	hex := s.buf[start:s.Pos]
   596  	if s.cur() != '\'' {
   597  		return LEX_ERROR, hex
   598  	}
   599  	s.skip(1)
   600  	if len(hex)%2 != 0 {
   601  		return LEX_ERROR, hex
   602  	}
   603  	return HEXNUM, hex
   604  }
   606  func (s *Scanner) scanMantissa(base int) {
   607  	for digitVal(s.cur()) < base {
   608  		s.skip(1)
   609  	}
   610  }
   612  // PositionedErr holds context related to parser errros
   613  type PositionedErr struct {
   614  	Err  string
   615  	Pos  int
   616  	Near string
   617  }
   619  func (p PositionedErr) Error() string {
   620  	if p.Near != "" {
   621  		return fmt.Sprintf("%s at position %v near '%s';", p.Err, p.Pos, p.Near)
   622  	}
   623  	return fmt.Sprintf("%s at position %v;", p.Err, p.Pos)
   624  }
   626  func (s *Scanner) skipBlank() {
   627  	ch := s.cur()
   628  	for ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t' {
   629  		s.skip(1)
   630  		ch = s.cur()
   631  	}
   632  }
   634  func (s *Scanner) cur() uint16 {
   635  	return s.peek(0)
   636  }
   638  func (s *Scanner) skip(dist int) {
   639  	s.Pos += dist
   640  }
   642  func (s *Scanner) peek(dist int) uint16 {
   643  	if s.Pos+dist >= len(s.buf) {
   644  		return eofChar
   645  	}
   646  	return uint16(s.buf[s.Pos+dist])
   647  }
   649  func isLetter(ch uint16) bool {
   650  	return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' || ch == '$'
   651  }
   653  func isCarat(ch uint16) bool {
   654  	return ch == '.' || ch == '"' || ch == '`' || ch == '\''
   655  }
   657  func digitVal(ch uint16) int {
   658  	switch {
   659  	case '0' <= ch && ch <= '9':
   660  		return int(ch) - '0'
   661  	case 'a' <= ch && ch <= 'f':
   662  		return int(ch) - 'a' + 10
   663  	case 'A' <= ch && ch <= 'F':
   664  		return int(ch) - 'A' + 10
   665  	}
   666  	return 16 // larger than any legal digit val
   667  }
   669  func isDigit(ch uint16) bool {
   670  	return '0' <= ch && ch <= '9'
   671  }