github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/scanner/scan.go (about)

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package scanner
    12  
    13  import (
    14  	"fmt"
    15  	"go/constant"
    16  	"go/token"
    17  	"strconv"
    18  	"strings"
    19  	"unicode/utf8"
    20  	"unsafe"
    21  
    22  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/lexbase"
    23  )
    24  
// eof is the sentinel returned by peek/peekN/next when the input is exhausted.
const eof = -1

// Error messages recorded in lval when scanning malformed literals.
const errUnterminated = "unterminated string"
const errInvalidUTF8 = "invalid UTF-8 byte sequence"
const errInvalidHexNumeric = "invalid hexadecimal numeric literal"

// singleQuote delimits string literals; identQuote delimits quoted identifiers.
const singleQuote = '\''
const identQuote = '"'
    31  
// NewNumValFn allows us to use tree.NewNumVal without a dependency on tree.
// The default below is a stub returning an empty struct; presumably the tree
// package installs the real implementation at init time — confirm at the
// registration site.
var NewNumValFn = func(constant.Value, string, bool) interface{} {
	return struct{}{}
}

// NewPlaceholderFn allows us to use tree.NewPlaceholder without a dependency on
// tree. Like NewNumValFn, the default is a stub expected to be overridden.
var NewPlaceholderFn = func(string) (interface{}, error) {
	return struct{}{}, nil
}
    42  
// ScanSymType is the interface for accessing the fields of a yacc symType.
type ScanSymType interface {
	ID() int32   // token ID (a lexbase token constant, a character code, or 0 for EOF)
	SetID(int32)
	Pos() int32 // byte offset of the token's start within the input
	SetPos(int32)
	Str() string // token text (normalized/unquoted where applicable)
	SetStr(string)
	UnionVal() interface{} // parser-specific payload, e.g. a numeric constant
	SetUnionVal(interface{})
}
    54  
// Scanner lexes SQL statements.
type Scanner struct {
	in            string // the SQL text being scanned
	pos           int    // current byte offset into in
	bytesPrealloc []byte // scratch space handed out by allocBytes/buffer

	// Comments is the list of parsed comments from the SQL statement.
	Comments []string

	// lastAttemptedID indicates the ID of the last attempted
	// token. Used to recognize which token an error was encountered
	// on.
	lastAttemptedID int32
	// quoted indicates if the last identifier scanned was
	// quoted. Used to distinguish between quoted and non-quoted in
	// Inspect.
	quoted bool
}
    73  
// SQLScanner is a Scanner with a SQL-specific Scan function.
type SQLScanner struct {
	Scanner
}
    78  
// In returns the input string being scanned.
func (s *Scanner) In() string {
	return s.in
}
    83  
// Pos returns the current position being lexed (a byte offset into In()).
func (s *Scanner) Pos() int {
	return s.pos
}
    88  
// Init initializes a new Scanner that will process str.
//
// NOTE(review): Init does not reset Comments, lastAttemptedID, or quoted, so
// comments collected by a previous use of a reused Scanner carry over —
// confirm whether callers rely on this before changing it.
func (s *Scanner) Init(str string) {
	s.in = str
	s.pos = 0
	// Preallocate some buffer space for identifiers etc.
	s.bytesPrealloc = make([]byte, len(str))
}
    96  
// Cleanup is used to avoid holding on to memory unnecessarily (for the cases
// where we reuse a Scanner).
func (s *Scanner) Cleanup() {
	// Dropping the scratch buffer lets the GC reclaim it; a later buffer()
	// or allocBytes() call simply allocates afresh.
	s.bytesPrealloc = nil
}
   102  
   103  func (s *Scanner) allocBytes(length int) []byte {
   104  	if cap(s.bytesPrealloc) >= length {
   105  		res := s.bytesPrealloc[:length:length]
   106  		s.bytesPrealloc = s.bytesPrealloc[length:cap(s.bytesPrealloc)]
   107  		return res
   108  	}
   109  	return make([]byte, length)
   110  }
   111  
// buffer returns an empty []byte buffer that can be appended to. Any unused
// portion can be returned later using returnBuffer.
func (s *Scanner) buffer() []byte {
	// Hand the entire scratch space to the caller (length 0, full capacity)
	// and clear our reference so it cannot be handed out twice.
	buf := s.bytesPrealloc[:0]
	s.bytesPrealloc = nil
	return buf
}
   119  
   120  // returnBuffer returns the unused portion of buf to the Scanner, to be used for
   121  // future allocBytes() or buffer() calls. The caller must not use buf again.
   122  func (s *Scanner) returnBuffer(buf []byte) {
   123  	if len(buf) < cap(buf) {
   124  		s.bytesPrealloc = buf[len(buf):]
   125  	}
   126  }
   127  
// finishString casts the given buffer to a string and returns the unused
// portion of the buffer. The caller must not use buf again.
func (s *Scanner) finishString(buf []byte) string {
	// Zero-copy []byte->string conversion. Safe only because the caller
	// relinquishes buf here and returnBuffer recycles solely the tail beyond
	// len(buf), so the bytes backing str are never mutated again.
	str := *(*string)(unsafe.Pointer(&buf))
	s.returnBuffer(buf)
	return str
}
   135  
// scanSetup resets per-token state, skips leading whitespace and comments,
// and consumes the first character of the next token. It returns that
// character (or eof) and a bool which, when true, tells the caller to stop:
// skipWhitespace failed (e.g. unterminated comment) and lval already holds
// the error.
func (s *Scanner) scanSetup(lval ScanSymType) (int, bool) {
	lval.SetID(0)
	lval.SetPos(int32(s.pos))
	lval.SetStr("EOF")
	s.quoted = false
	s.lastAttemptedID = 0

	if _, ok := s.skipWhitespace(lval, true); !ok {
		return 0, true
	}

	ch := s.next()
	if ch == eof {
		lval.SetPos(int32(s.pos))
		return ch, false
	}

	// Default single-character token: ID is the character code itself and the
	// text is that one byte; multi-character tokens overwrite this later.
	lval.SetID(int32(ch))
	lval.SetPos(int32(s.pos - 1))
	lval.SetStr(s.in[lval.Pos():s.pos])
	s.lastAttemptedID = int32(ch)
	return ch, false
}
   159  
// Scan scans the next token and populates its information into lval.
// scanSetup has already stored the single-character token as a default, so
// cases that fall through without setting an ID yield that character as the
// token; multi-character operators and literals overwrite it below.
func (s *SQLScanner) Scan(lval ScanSymType) {
	ch, skipWhiteSpace := s.scanSetup(lval)

	if skipWhiteSpace {
		return
	}

	switch ch {
	case '$':
		// placeholder? $[0-9]+
		if lexbase.IsDigit(s.peek()) {
			s.scanPlaceholder(lval)
			return
		} else if s.scanDollarQuotedString(lval) {
			lval.SetID(lexbase.SCONST)
			return
		}
		return

	case identQuote:
		// "[^"]"
		s.lastAttemptedID = int32(lexbase.IDENT)
		s.quoted = true
		if s.scanString(lval, identQuote, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.SetID(lexbase.IDENT)
		}
		return

	case singleQuote:
		// '[^']'
		s.lastAttemptedID = int32(lexbase.SCONST)
		if s.scanString(lval, ch, false /* allowEscapes */, true /* requireUTF8 */) {
			lval.SetID(lexbase.SCONST)
		}
		return

	case 'b':
		// Bytes?
		if s.peek() == singleQuote {
			// b'[^']'
			s.lastAttemptedID = int32(lexbase.BCONST)
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, false /* requireUTF8 */) {
				lval.SetID(lexbase.BCONST)
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'r', 'R':
		s.scanIdent(lval)
		return

	case 'e', 'E':
		// Escaped string?
		if s.peek() == singleQuote {
			// [eE]'[^']'
			s.lastAttemptedID = int32(lexbase.SCONST)
			s.pos++
			if s.scanString(lval, singleQuote, true /* allowEscapes */, true /* requireUTF8 */) {
				lval.SetID(lexbase.SCONST)
			}
			return
		}
		s.scanIdent(lval)
		return

	case 'B':
		// Bit array literal?
		if s.peek() == singleQuote {
			// B'[01]*'
			s.pos++
			s.scanBitString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case 'x', 'X':
		// Hex literal?
		if s.peek() == singleQuote {
			// [xX]'[a-f0-9]'
			s.pos++
			s.scanHexString(lval, singleQuote)
			return
		}
		s.scanIdent(lval)
		return

	case '.':
		switch t := s.peek(); {
		case t == '.': // ..
			s.pos++
			lval.SetID(lexbase.DOT_DOT)
			return
		case lexbase.IsDigit(t):
			// Fractional literal without a leading digit, e.g. ".5".
			s.lastAttemptedID = int32(lexbase.FCONST)
			s.scanNumber(lval, ch)
			return
		}
		return

	case '!':
		switch s.peek() {
		case '=': // !=
			s.pos++
			lval.SetID(lexbase.NOT_EQUALS)
			return
		case '~': // !~
			s.pos++
			switch s.peek() {
			case '*': // !~*
				s.pos++
				lval.SetID(lexbase.NOT_REGIMATCH)
				return
			}
			lval.SetID(lexbase.NOT_REGMATCH)
			return
		}
		return

	case '?':
		switch s.peek() {
		case '?': // ??
			s.pos++
			lval.SetID(lexbase.HELPTOKEN)
			return
		case '|': // ?|
			s.pos++
			lval.SetID(lexbase.JSON_SOME_EXISTS)
			return
		case '&': // ?&
			s.pos++
			lval.SetID(lexbase.JSON_ALL_EXISTS)
			return
		}
		return

	case '<':
		switch s.peek() {
		case '<': // <<
			s.pos++
			switch s.peek() {
			case '=': // <<=
				s.pos++
				lval.SetID(lexbase.INET_CONTAINED_BY_OR_EQUALS)
				return
			}
			lval.SetID(lexbase.LSHIFT)
			return
		case '>': // <>
			s.pos++
			lval.SetID(lexbase.NOT_EQUALS)
			return
		case '=': // <=
			s.pos++
			lval.SetID(lexbase.LESS_EQUALS)
			return
		case '@': // <@
			s.pos++
			lval.SetID(lexbase.CONTAINED_BY)
			return
		}
		return

	case '>':
		switch s.peek() {
		case '>': // >>
			s.pos++
			switch s.peek() {
			case '=': // >>=
				s.pos++
				lval.SetID(lexbase.INET_CONTAINS_OR_EQUALS)
				return
			}
			lval.SetID(lexbase.RSHIFT)
			return
		case '=': // >=
			s.pos++
			lval.SetID(lexbase.GREATER_EQUALS)
			return
		}
		return

	case ':':
		switch s.peek() {
		case ':': // ::
			if s.peekN(1) == ':' {
				// :::
				s.pos += 2
				lval.SetID(lexbase.TYPEANNOTATE)
				return
			}
			s.pos++
			lval.SetID(lexbase.TYPECAST)
			return
		}
		return

	case '|':
		switch s.peek() {
		case '|': // ||
			s.pos++
			switch s.peek() {
			case '/': // ||/
				s.pos++
				lval.SetID(lexbase.CBRT)
				return
			}
			lval.SetID(lexbase.CONCAT)
			return
		case '/': // |/
			s.pos++
			lval.SetID(lexbase.SQRT)
			return
		}
		return

	case '/':
		switch s.peek() {
		case '/': // //
			s.pos++
			lval.SetID(lexbase.FLOORDIV)
			return
		}
		return

	case '~':
		switch s.peek() {
		case '*': // ~*
			s.pos++
			lval.SetID(lexbase.REGIMATCH)
			return
		}
		return

	case '@':
		switch s.peek() {
		case '>': // @>
			s.pos++
			lval.SetID(lexbase.CONTAINS)
			return
		case '@': // @@
			s.pos++
			lval.SetID(lexbase.AT_AT)
			return
		}
		return

	case '&':
		switch s.peek() {
		case '&': // &&
			s.pos++
			lval.SetID(lexbase.AND_AND)
			return
		}
		return

	case '-':
		switch s.peek() {
		case '>': // ->
			if s.peekN(1) == '>' {
				// ->>
				s.pos += 2
				lval.SetID(lexbase.FETCHTEXT)
				return
			}
			s.pos++
			lval.SetID(lexbase.FETCHVAL)
			return
		}
		return

	case '#':
		switch s.peek() {
		case '>': // #>
			if s.peekN(1) == '>' {
				// #>>
				s.pos += 2
				lval.SetID(lexbase.FETCHTEXT_PATH)
				return
			}
			s.pos++
			lval.SetID(lexbase.FETCHVAL_PATH)
			return
		case '-': // #-
			s.pos++
			lval.SetID(lexbase.REMOVE_PATH)
			return
		}
		return

	default:
		if lexbase.IsDigit(ch) {
			s.lastAttemptedID = int32(lexbase.ICONST)
			s.scanNumber(lval, ch)
			return
		}
		if lexbase.IsIdentStart(ch) {
			s.scanIdent(lval)
			return
		}
	}

	// Everything else is a single character token which we already initialized
	// lval for above.
}
   469  
   470  func (s *Scanner) peek() int {
   471  	if s.pos >= len(s.in) {
   472  		return eof
   473  	}
   474  	return int(s.in[s.pos])
   475  }
   476  
   477  func (s *Scanner) peekN(n int) int {
   478  	pos := s.pos + n
   479  	if pos >= len(s.in) {
   480  		return eof
   481  	}
   482  	return int(s.in[pos])
   483  }
   484  
   485  func (s *Scanner) next() int {
   486  	ch := s.peek()
   487  	if ch != eof {
   488  		s.pos++
   489  	}
   490  	return ch
   491  }
   492  
// skipWhitespace advances past spaces, tabs, carriage returns, form feeds and
// newlines. When allowComments is true it also consumes comments, recording
// their raw text in s.Comments. It reports whether at least one newline was
// crossed (newline) and whether scanning succeeded; ok=false means
// ScanComment found an unterminated comment and stored the error in lval.
func (s *Scanner) skipWhitespace(lval ScanSymType, allowComments bool) (newline, ok bool) {
	newline = false
	for {
		startPos := s.pos
		ch := s.peek()
		if ch == '\n' {
			s.pos++
			newline = true
			continue
		}
		if ch == ' ' || ch == '\t' || ch == '\r' || ch == '\f' {
			s.pos++
			continue
		}
		if allowComments {
			if present, cok := s.ScanComment(lval); !cok {
				return false, false
			} else if present {
				// Mark down the comments that we found.
				s.Comments = append(s.Comments, s.in[startPos:s.pos])
				continue
			}
		}
		break
	}
	return newline, true
}
   520  
// ScanComment scans the input as a comment. It handles /* ... */ block
// comments (which may nest) and -- line comments. present reports whether a
// comment was consumed; ok=false means a block comment was unterminated, in
// which case lval holds the error.
func (s *Scanner) ScanComment(lval ScanSymType) (present, ok bool) {
	start := s.pos
	ch := s.peek()

	if ch == '/' {
		s.pos++
		if s.peek() != '*' {
			// Lone '/': not a comment; restore the position.
			s.pos--
			return false, true
		}
		s.pos++
		// depth tracks nesting of block comments.
		depth := 1
		for {
			switch s.next() {
			case '*':
				if s.peek() == '/' {
					s.pos++
					depth--
					if depth == 0 {
						return true, true
					}
					continue
				}

			case '/':
				if s.peek() == '*' {
					// Nested block comment opens.
					s.pos++
					depth++
					continue
				}

			case eof:
				lval.SetID(lexbase.ERROR)
				lval.SetPos(int32(start))
				lval.SetStr("unterminated comment")
				return false, false
			}
		}
	}

	if ch == '-' {
		s.pos++
		if s.peek() != '-' {
			// Lone '-': not a comment; restore the position.
			s.pos--
			return false, true
		}
		// Line comment runs to end of line or end of input.
		for {
			switch s.next() {
			case eof, '\n':
				return true, true
			}
		}
	}

	return false, true
}
   578  
// lowerCaseAndNormalizeIdent scans an identifier whose first character the
// caller has already consumed (hence the initial s.pos--), lowercases and
// unicode-normalizes it as cheaply as possible, and stores the result via
// lval.SetStr.
func (s *Scanner) lowerCaseAndNormalizeIdent(lval ScanSymType) {
	s.lastAttemptedID = int32(lexbase.IDENT)
	s.pos--
	start := s.pos
	isASCII := true
	isLower := true

	// Consume the Scanner character by character, stopping after the last legal
	// identifier character. By the end of this function, we need to
	// lowercase and unicode normalize this identifier, which is expensive if
	// there are actual unicode characters in it. If not, it's quite cheap - and
	// if it's lowercase already, there's no work to do. Therefore, we keep track
	// of whether the string is only ASCII or only ASCII lowercase for later.
	for {
		ch := s.peek()
		if ch >= utf8.RuneSelf {
			isASCII = false
		} else if ch >= 'A' && ch <= 'Z' {
			isLower = false
		}

		if !lexbase.IsIdentMiddle(ch) {
			break
		}

		s.pos++
	}

	if isLower && isASCII {
		// Already lowercased - nothing to do.
		lval.SetStr(s.in[start:s.pos])
	} else if isASCII {
		// We know that the identifier we've seen so far is ASCII, so we don't need
		// to unicode normalize. Instead, just lowercase as normal.
		b := s.allocBytes(s.pos - start)
		_ = b[s.pos-start-1] // For bounds check elimination.
		for i, c := range s.in[start:s.pos] {
			if c >= 'A' && c <= 'Z' {
				c += 'a' - 'A'
			}
			b[i] = byte(c)
		}
		// Zero-copy cast is safe: b came from allocBytes and is never reused.
		lval.SetStr(*(*string)(unsafe.Pointer(&b)))
	} else {
		// The string has unicode in it. No choice but to run Normalize.
		lval.SetStr(lexbase.NormalizeName(s.in[start:s.pos]))
	}
}
   627  
   628  func (s *Scanner) scanIdent(lval ScanSymType) {
   629  	s.lowerCaseAndNormalizeIdent(lval)
   630  
   631  	isExperimental := false
   632  	kw := lval.Str()
   633  	switch {
   634  	case strings.HasPrefix(lval.Str(), "experimental_"):
   635  		kw = lval.Str()[13:]
   636  		isExperimental = true
   637  	case strings.HasPrefix(lval.Str(), "testing_"):
   638  		kw = lval.Str()[8:]
   639  		isExperimental = true
   640  	}
   641  	lval.SetID(lexbase.GetKeywordID(kw))
   642  	if lval.ID() != lexbase.IDENT {
   643  		if isExperimental {
   644  			if _, ok := lexbase.AllowedExperimental[kw]; !ok {
   645  				// If the parsed token is not on the allowlisted set of keywords,
   646  				// then it might have been intended to be parsed as something else.
   647  				// In that case, re-tokenize the original string.
   648  				lval.SetID(lexbase.GetKeywordID(lval.Str()))
   649  			} else {
   650  				// It is a allowlisted keyword, so remember the shortened
   651  				// keyword for further processing.
   652  				lval.SetStr(kw)
   653  			}
   654  		}
   655  	} else {
   656  		// If the word after experimental_ or testing_ is an identifier,
   657  		// then we might have classified it incorrectly after removing the
   658  		// experimental_/testing_ prefix.
   659  		lval.SetID(lexbase.GetKeywordID(lval.Str()))
   660  	}
   661  }
   662  
// scanNumber scans a numeric literal: integer, decimal, float with exponent,
// or 0x-prefixed hexadecimal. ch is the first character, already consumed
// ('.' when invoked from Scan's '.' case). On success it sets ICONST or
// FCONST and attaches a value built via NewNumValFn; on failure it sets
// lexbase.ERROR with a message.
func (s *Scanner) scanNumber(lval ScanSymType, ch int) {
	start := s.pos - 1
	isHex := false
	hasDecimal := ch == '.'
	hasExponent := false

	for {
		ch := s.peek()
		if (isHex && lexbase.IsHexDigit(ch)) || lexbase.IsDigit(ch) {
			s.pos++
			continue
		}
		if ch == 'x' || ch == 'X' {
			// An 'x' is only legal as the second character of "0x...".
			if isHex || s.in[start] != '0' || s.pos != start+1 {
				lval.SetID(lexbase.ERROR)
				lval.SetStr(errInvalidHexNumeric)
				return
			}
			s.pos++
			isHex = true
			continue
		}
		if isHex {
			break
		}
		if ch == '.' {
			if hasDecimal || hasExponent {
				break
			}
			s.pos++
			if s.peek() == '.' {
				// Found ".." while scanning a number: back up to the end of the
				// integer.
				s.pos--
				break
			}
			hasDecimal = true
			continue
		}
		if ch == 'e' || ch == 'E' {
			if hasExponent {
				break
			}
			hasExponent = true
			s.pos++
			// Optional sign, then at least one digit is required.
			ch = s.peek()
			if ch == '-' || ch == '+' {
				s.pos++
			}
			ch = s.peek()
			if !lexbase.IsDigit(ch) {
				lval.SetID(lexbase.ERROR)
				lval.SetStr("invalid floating point literal")
				return
			}
			continue
		}
		break
	}

	// Disallow identifier after numerical constants e.g. "124foo".
	if lexbase.IsIdentStart(s.peek()) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(fmt.Sprintf("trailing junk after numeric literal at or near %q", s.in[start:s.pos+1]))
		return
	}

	lval.SetStr(s.in[start:s.pos])
	if hasDecimal || hasExponent {
		lval.SetID(lexbase.FCONST)
		floatConst := constant.MakeFromLiteral(lval.Str(), token.FLOAT, 0)
		if floatConst.Kind() == constant.Unknown {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf("could not make constant float from literal %q", lval.Str()))
			return
		}
		lval.SetUnionVal(NewNumValFn(floatConst, lval.Str(), false /* negative */))
	} else {
		// "0x" with no digits after it is invalid.
		if isHex && s.pos == start+2 {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errInvalidHexNumeric)
			return
		}

		// Strip off leading zeros from non-hex (decimal) literals so that
		// constant.MakeFromLiteral doesn't inappropriately interpret the
		// string as an octal literal. Note: we can't use strings.TrimLeft
		// here, because it will truncate '0' to ''.
		if !isHex {
			for len(lval.Str()) > 1 && lval.Str()[0] == '0' {
				lval.SetStr(lval.Str()[1:])
			}
		}

		lval.SetID(lexbase.ICONST)
		intConst := constant.MakeFromLiteral(lval.Str(), token.INT, 0)
		if intConst.Kind() == constant.Unknown {
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf("could not make constant int from literal %q", lval.Str()))
			return
		}
		lval.SetUnionVal(NewNumValFn(intConst, lval.Str(), false /* negative */))
	}
}
   767  
// scanPlaceholder scans a $<digits> placeholder; the leading '$' has already
// been consumed and at least one digit follows. It sets PLACEHOLDER with the
// value produced by NewPlaceholderFn, or ERROR if that function rejects the
// digit string.
func (s *Scanner) scanPlaceholder(lval ScanSymType) {
	s.lastAttemptedID = int32(lexbase.PLACEHOLDER)
	start := s.pos
	for lexbase.IsDigit(s.peek()) {
		s.pos++
	}
	lval.SetStr(s.in[start:s.pos])

	placeholder, err := NewPlaceholderFn(lval.Str())
	if err != nil {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(err.Error())
		return
	}
	lval.SetID(lexbase.PLACEHOLDER)
	lval.SetUnionVal(placeholder)
}
   785  
// scanHexString scans the content inside x'....', decoding pairs of hex
// digits into bytes. ch is the closing quote character. Returns false (with
// lval set to ERROR) on invalid digits, an odd digit count, or an
// unterminated literal.
func (s *Scanner) scanHexString(lval ScanSymType, ch int) bool {
	s.lastAttemptedID = int32(lexbase.BCONST)
	buf := s.buffer()

	var curbyte byte
	bytep := 0
	const errInvalidBytesLiteral = "invalid hexadecimal bytes literal"
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
			curbyte = (curbyte << 4) | byte(b-'0')
		case 'a', 'b', 'c', 'd', 'e', 'f':
			curbyte = (curbyte << 4) | byte(b-'a'+10)
		case 'A', 'B', 'C', 'D', 'E', 'F':
			curbyte = (curbyte << 4) | byte(b-'A'+10)
		default:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errInvalidBytesLiteral)
			return false
		}
		bytep++

		// Every two hex digits produce one output byte.
		if bytep > 1 {
			buf = append(buf, curbyte)
			bytep = 0
			curbyte = 0
		}
	}

	// An odd number of hex digits is invalid.
	if bytep != 0 {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidBytesLiteral)
		return false
	}

	lval.SetID(lexbase.BCONST)
	lval.SetStr(s.finishString(buf))
	return true
}
   843  
// scanBitString scans the content inside B'....', accepting only '0' and '1'
// digits. ch is the closing quote character. Returns false (with lval set to
// ERROR) on any other digit or an unterminated literal.
func (s *Scanner) scanBitString(lval ScanSymType, ch int) bool {
	s.lastAttemptedID = int32(lexbase.BITCONST)
	buf := s.buffer()
outer:
	for {
		b := s.next()
		switch b {
		case ch:
			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}
			// SQL allows joining adjacent strings separated by whitespace
			// as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that
			// is the standard.
			if s.peek() == ch && newline {
				s.pos++
				continue
			}
			break outer

		case '0', '1':
			buf = append(buf, byte(b))
		default:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(fmt.Sprintf(`"%c" is not a valid binary digit`, rune(b)))
			return false
		}
	}

	lval.SetID(lexbase.BITCONST)
	lval.SetStr(s.finishString(buf))
	return true
}
   880  
// scanString scans the content inside '...'. This is used for simple
// string literals '...' but also e'....' and b'...'. For x'...', see
// scanHexString(). ch is the closing quote character; allowEscapes enables
// backslash escape processing (e'...'/b'...'); requireUTF8 rejects invalid
// UTF-8 in the result. Returns false with lval set to ERROR on failure.
func (s *Scanner) scanString(lval ScanSymType, ch int, allowEscapes, requireUTF8 bool) bool {
	buf := s.buffer()
	var runeTmp [utf8.UTFMax]byte
	// start marks the beginning of the raw span not yet copied into buf.
	start := s.pos
outer:
	for {
		switch s.next() {
		case ch:
			buf = append(buf, s.in[start:s.pos-1]...)
			if s.peek() == ch {
				// Double quote is translated into a single quote that is part of the
				// string.
				start = s.pos
				s.pos++
				continue
			}

			newline, ok := s.skipWhitespace(lval, false)
			if !ok {
				return false
			}

			// SQL allows joining adjacent single-quoted strings separated by
			// whitespace as long as that whitespace contains at least one
			// newline. Kind of strange to require the newline, but that is the
			// standard.
			if ch == singleQuote && s.peek() == singleQuote && newline {
				s.pos++
				start = s.pos
				continue
			}
			break outer

		case '\\':
			t := s.peek()

			if allowEscapes {
				buf = append(buf, s.in[start:s.pos-1]...)
				if t == ch {
					// Backslash-escaped quote character.
					start = s.pos
					s.pos++
					continue
				}

				switch t {
				case 'a', 'b', 'f', 'n', 'r', 't', 'v', 'x', 'X', 'u', 'U', '\\',
					'0', '1', '2', '3', '4', '5', '6', '7':
					var tmp string
					if t == 'X' && len(s.in[s.pos:]) >= 3 {
						// UnquoteChar doesn't handle 'X' so we create a temporary string
						// for it to parse.
						tmp = "\\x" + s.in[s.pos+1:s.pos+3]
					} else {
						tmp = s.in[s.pos-1:]
					}
					v, multibyte, tail, err := strconv.UnquoteChar(tmp, byte(ch))
					if err != nil {
						lval.SetID(lexbase.ERROR)
						lval.SetStr(err.Error())
						return false
					}
					if v < utf8.RuneSelf || !multibyte {
						buf = append(buf, byte(v))
					} else {
						n := utf8.EncodeRune(runeTmp[:], v)
						buf = append(buf, runeTmp[:n]...)
					}
					// Advance past however many bytes UnquoteChar consumed.
					s.pos += len(tmp) - len(tail) - 1
					start = s.pos
					continue
				}

				// If we end up here, it's a redundant escape - simply drop the
				// backslash. For example, e'\"' is equivalent to e'"', and
				// e'\d\b' to e'd\b'. This is what Postgres does:
				// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
				start = s.pos
			}

		case eof:
			lval.SetID(lexbase.ERROR)
			lval.SetStr(errUnterminated)
			return false
		}
	}

	if requireUTF8 && !utf8.Valid(buf) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidUTF8)
		return false
	}

	if ch == identQuote {
		lval.SetStr(lexbase.NormalizeString(s.finishString(buf)))
	} else {
		lval.SetStr(s.finishString(buf))
	}
	return true
}
   983  
// scanDollarQuotedString scans for so called dollar-quoted strings, which start/end with either $$ or $tag$, where
// tag is some arbitrary string.  e.g. $$a string$$ or $escaped$a string$escaped$.
// The leading '$' has already been consumed. Returns false (restoring s.pos
// when no tag was found) if the input is not a dollar-quoted string, or with
// lval set to ERROR if it is unterminated or not valid UTF-8.
func (s *Scanner) scanDollarQuotedString(lval ScanSymType) bool {
	s.lastAttemptedID = int32(lexbase.SCONST)
	buf := s.buffer()
	start := s.pos

	foundStartTag := false
	possibleEndTag := false
	// startTagIndex tracks how many characters of a candidate end tag have
	// matched startTag so far (-1 when no candidate is in progress).
	startTagIndex := -1
	var startTag string

outer:
	for {
		ch := s.peek()
		switch ch {
		case '$':
			s.pos++
			if foundStartTag {
				if possibleEndTag {
					if len(startTag) == startTagIndex {
						// Found end tag.
						buf = append(buf, s.in[start+len(startTag)+1:s.pos-len(startTag)-2]...)
						break outer
					} else {
						// Was not the end tag but the current $ might be the start of the end tag we are looking for, so
						// just reset the startTagIndex.
						startTagIndex = 0
					}
				} else {
					possibleEndTag = true
					startTagIndex = 0
				}
			} else {
				startTag = s.in[start : s.pos-1]
				foundStartTag = true
			}

		case eof:
			if foundStartTag {
				// A start tag was found, therefore we expect an end tag before the eof, otherwise it is an error.
				lval.SetID(lexbase.ERROR)
				lval.SetStr(errUnterminated)
			} else {
				// This is not a dollar-quoted string, reset the pos back to the start.
				s.pos = start
			}
			return false

		default:
			// If we haven't found a start tag yet, check whether the current character is valid for a tag.
			if !foundStartTag && !lexbase.IsIdentStart(ch) && !lexbase.IsDigit(ch) {
				return false
			}
			s.pos++
			if possibleEndTag {
				// Check whether this could be the end tag.
				if startTagIndex >= len(startTag) || ch != int(startTag[startTagIndex]) {
					// This is not the end tag we are looking for.
					possibleEndTag = false
					startTagIndex = -1
				} else {
					startTagIndex++
				}
			}
		}
	}

	if !utf8.Valid(buf) {
		lval.SetID(lexbase.ERROR)
		lval.SetStr(errInvalidUTF8)
		return false
	}

	lval.SetStr(s.finishString(buf))
	return true
}
  1061  
  1062  // HasMultipleStatements returns true if the sql string contains more than one
  1063  // statements. An error is returned if an invalid token was encountered.
  1064  func HasMultipleStatements(sql string) (multipleStmt bool, err error) {
  1065  	var s SQLScanner
  1066  	var lval fakeSym
  1067  	s.Init(sql)
  1068  	count := 0
  1069  	for {
  1070  		done, hasToks, err := s.scanOne(&lval)
  1071  		if err != nil {
  1072  			return false, err
  1073  		}
  1074  		if hasToks {
  1075  			count++
  1076  		}
  1077  		if done || count > 1 {
  1078  			break
  1079  		}
  1080  	}
  1081  	return count > 1, nil
  1082  }
  1083  
  1084  // scanOne is a simplified version of (*Parser).scanOneStmt() for use
  1085  // by HasMultipleStatements().
  1086  func (s *SQLScanner) scanOne(lval *fakeSym) (done, hasToks bool, err error) {
  1087  	// Scan the first token.
  1088  	for {
  1089  		s.Scan(lval)
  1090  		if lval.id == 0 {
  1091  			return true, false, nil
  1092  		}
  1093  		if lval.id != ';' {
  1094  			break
  1095  		}
  1096  	}
  1097  
  1098  	var preValID int32
  1099  	// This is used to track the degree of nested `BEGIN ATOMIC ... END` function
  1100  	// body context. When greater than zero, it means that we're scanning through
  1101  	// the function body of a `CREATE FUNCTION` statement. ';' character is only
  1102  	// a separator of sql statements within the body instead of a finishing line
  1103  	// of the `CREATE FUNCTION` statement.
  1104  	curFuncBodyCnt := 0
  1105  	for {
  1106  		if lval.id == lexbase.ERROR {
  1107  			return true, true, fmt.Errorf("scan error: %s", lval.s)
  1108  		}
  1109  		preValID = lval.id
  1110  		s.Scan(lval)
  1111  		if preValID == lexbase.BEGIN && lval.id == lexbase.ATOMIC {
  1112  			curFuncBodyCnt++
  1113  		}
  1114  		if curFuncBodyCnt > 0 && lval.id == lexbase.END {
  1115  			curFuncBodyCnt--
  1116  		}
  1117  		if lval.id == 0 || (curFuncBodyCnt == 0 && lval.id == ';') {
  1118  			return (lval.id == 0), true, nil
  1119  		}
  1120  	}
  1121  }
  1122  
  1123  // LastLexicalToken returns the last lexical token. If the string has no lexical
  1124  // tokens, returns 0 and ok=false.
  1125  func LastLexicalToken(sql string) (lastTok int, ok bool) {
  1126  	var s SQLScanner
  1127  	var lval fakeSym
  1128  	s.Init(sql)
  1129  	for {
  1130  		last := lval.ID()
  1131  		s.Scan(&lval)
  1132  		if lval.ID() == 0 {
  1133  			return int(last), last != 0
  1134  		}
  1135  	}
  1136  }
  1137  
  1138  // FirstLexicalToken returns the first lexical token.
  1139  // Returns 0 if there is no token.
  1140  func FirstLexicalToken(sql string) (tok int) {
  1141  	var s SQLScanner
  1142  	var lval fakeSym
  1143  	s.Init(sql)
  1144  	s.Scan(&lval)
  1145  	id := lval.ID()
  1146  	return int(id)
  1147  }
  1148  
// fakeSym is a simplified symbol type for use by
// HasMultipleStatements.
type fakeSym struct {
	id  int32  // lexical token ID (0 means end of input)
	pos int32  // byte offset of the token in the input string
	s   string // token text (or error message when id is lexbase.ERROR)
}
  1156  
  1157  var _ ScanSymType = (*fakeSym)(nil)
  1158  
  1159  func (s fakeSym) ID() int32                 { return s.id }
  1160  func (s *fakeSym) SetID(id int32)           { s.id = id }
  1161  func (s fakeSym) Pos() int32                { return s.pos }
  1162  func (s *fakeSym) SetPos(p int32)           { s.pos = p }
  1163  func (s fakeSym) Str() string               { return s.s }
  1164  func (s *fakeSym) SetStr(v string)          { s.s = v }
  1165  func (s fakeSym) UnionVal() interface{}     { return nil }
  1166  func (s fakeSym) SetUnionVal(v interface{}) {}
  1167  
// InspectToken is the type of token that can be scanned by Inspect.
type InspectToken struct {
	ID      int32  // token ID produced by the scan (lexbase.ERROR on syntax error)
	MaybeID int32  // the token ID the scanner was attempting when it stopped
	Start   int32  // byte offset where the token starts in the input
	End     int32  // byte offset just past the token in the input
	Str     string // token text (or error message when ID is lexbase.ERROR)
	Quoted  bool   // whether the token was quoted in the input
}
  1177  
  1178  // Inspect analyses the string and returns the tokens found in it. If
  1179  // an incomplete token was encountered at the end, an InspectToken
  1180  // entry with ID -1 is appended.
  1181  //
  1182  // If a syntax error was encountered, it is returned as a token with
  1183  // type ERROR.
  1184  //
  1185  // See TestInspect and the examples in testdata/inspect for more details.
  1186  func Inspect(sql string) []InspectToken {
  1187  	var s SQLScanner
  1188  	var lval fakeSym
  1189  	var tokens []InspectToken
  1190  	s.Init(sql)
  1191  	for {
  1192  		s.Scan(&lval)
  1193  		tok := InspectToken{
  1194  			ID:      lval.id,
  1195  			MaybeID: s.lastAttemptedID,
  1196  			Str:     lval.s,
  1197  			Start:   lval.pos,
  1198  			End:     int32(s.pos),
  1199  			Quoted:  s.quoted,
  1200  		}
  1201  
  1202  		// A special affordance for unterminated quoted identifiers: try
  1203  		// to find the normalized text of the identifier found so far.
  1204  		if lval.id == lexbase.ERROR && s.lastAttemptedID == lexbase.IDENT && s.quoted {
  1205  			maybeIdent := sql[tok.Start:tok.End] + "\""
  1206  			var si SQLScanner
  1207  			si.Init(maybeIdent)
  1208  			si.Scan(&lval)
  1209  			if lval.id == lexbase.IDENT {
  1210  				tok.Str = lval.s
  1211  			}
  1212  		}
  1213  
  1214  		tokens = append(tokens, tok)
  1215  		if lval.id == 0 || lval.id == lexbase.ERROR {
  1216  			return tokens
  1217  		}
  1218  	}
  1219  }