github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/parser/scan.go

     1  // Copyright 2015 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package parser
    12  
    13  import (
    14  	"fmt"
    15  	"go/constant"
    16  	"go/token"
    17  	"strconv"
    18  	"strings"
    19  	"unicode/utf8"
    20  	"unsafe"
    21  
    22  	"github.com/cockroachdb/cockroach/pkg/sql/lex"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    24  )
    25  
    26  const eof = -1
    27  const errUnterminated = "unterminated string"
    28  const errInvalidUTF8 = "invalid UTF-8 byte sequence"
    29  const errInvalidHexNumeric = "invalid hexadecimal numeric literal"
    30  const singleQuote = '\''
    31  const identQuote = '"'
    32  
    33  // scanner lexes SQL statements.
    34  type scanner struct {
    35  	in            string
    36  	pos           int
    37  	bytesPrealloc []byte
    38  }
    39  
    40  func makeScanner(str string) scanner {
    41  	var s scanner
    42  	s.init(str)
    43  	return s
    44  }
    45  
    46  func (s *scanner) init(str string) {
    47  	s.in = str
    48  	s.pos = 0
    49  	// Preallocate some buffer space for identifiers etc.
    50  	s.bytesPrealloc = make([]byte, len(str))
    51  }
    52  
    53  // cleanup is used to avoid holding on to memory unnecessarily (for the cases
    54  // where we reuse a scanner).
    55  func (s *scanner) cleanup() {
    56  	s.bytesPrealloc = nil
    57  }
    58  
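// allocBytes returns a byte slice of the given length, carved out of the
// preallocated buffer when enough space remains and freshly allocated otherwise.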
    59  func (s *scanner) allocBytes(length int) []byte {
    60  	if len(s.bytesPrealloc) >= length {
    61  		res := s.bytesPrealloc[:length:length]
    62  		s.bytesPrealloc = s.bytesPrealloc[length:]
    63  		return res
    64  	}
    65  	return make([]byte, length)
    66  }
    67  
    68  // buffer returns an empty []byte buffer that can be appended to. Any unused
    69  // portion can be returned later using returnBuffer.
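//
// The typical pattern, used by the scan*String helpers below, is:
//
//	buf := s.buffer()
//	buf = append(buf, ...)
//	lval.str = s.finishString(buf)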
    70  func (s *scanner) buffer() []byte {
    71  	buf := s.bytesPrealloc[:0]
    72  	s.bytesPrealloc = nil
    73  	return buf
    74  }
    75  
    76  // returnBuffer returns the unused portion of buf to the scanner, to be used for
    77  // future allocBytes() or buffer() calls. The caller must not use buf again.
    78  func (s *scanner) returnBuffer(buf []byte) {
    79  	if len(buf) < cap(buf) {
    80  		s.bytesPrealloc = buf[len(buf):]
    81  	}
    82  }
    83  
    84  // finishString converts the given buffer to a string without copying and hands the
    85  // unused portion of the buffer back to the scanner. The caller must not use buf again.
    86  func (s *scanner) finishString(buf []byte) string {
    87  	str := *(*string)(unsafe.Pointer(&buf))
    88  	s.returnBuffer(buf)
    89  	return str
    90  }
    91  
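// scan lexes the next token, storing its id, position and text (and, for constants
// and placeholders, its value) in lval. At end of input, lval.id is left as 0.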
    92  func (s *scanner) scan(lval *sqlSymType) {
    93  	lval.id = 0
    94  	lval.pos = int32(s.pos)
    95  	lval.str = "EOF"
    96  
    97  	if _, ok := s.skipWhitespace(lval, true); !ok {
    98  		return
    99  	}
   100  
   101  	ch := s.next()
   102  	if ch == eof {
   103  		lval.pos = int32(s.pos)
   104  		return
   105  	}
   106  
   107  	lval.id = int32(ch)
   108  	lval.pos = int32(s.pos - 1)
   109  	lval.str = s.in[lval.pos:s.pos]
   110  
   111  	switch ch {
   112  	case '$':
   113  		// placeholder? $[0-9]+
   114  		if lex.IsDigit(s.peek()) {
   115  			s.scanPlaceholder(lval)
   116  			return
   117  		} else if s.scanDollarQuotedString(lval) {
   118  			lval.id = SCONST
   119  			return
   120  		}
   121  		return
   122  
   123  	case identQuote:
   124  		// "[^"]*"
   125  		if s.scanString(lval, identQuote, false /* allowEscapes */, true /* requireUTF8 */) {
   126  			lval.id = IDENT
   127  		}
   128  		return
   129  
   130  	case singleQuote:
   131  		// '[^']*'
   132  		if s.scanString(lval, ch, false /* allowEscapes */, true /* requireUTF8 */) {
   133  			lval.id = SCONST
   134  		}
   135  		return
   136  
   137  	case 'b':
   138  		// Bytes?
   139  		if s.peek() == singleQuote {
   140  			// b'[^']*'
   141  			s.pos++
   142  			if s.scanString(lval, singleQuote, true /* allowEscapes */, false /* requireUTF8 */) {
   143  				lval.id = BCONST
   144  			}
   145  			return
   146  		}
   147  		s.scanIdent(lval)
   148  		return
   149  
   150  	case 'r', 'R':
   151  		s.scanIdent(lval)
   152  		return
   153  
   154  	case 'e', 'E':
   155  		// Escaped string?
   156  		if s.peek() == singleQuote {
   157  			// [eE]'[^']*'
   158  			s.pos++
   159  			if s.scanString(lval, singleQuote, true /* allowEscapes */, true /* requireUTF8 */) {
   160  				lval.id = SCONST
   161  			}
   162  			return
   163  		}
   164  		s.scanIdent(lval)
   165  		return
   166  
   167  	case 'B':
   168  		// Bit array literal?
   169  		if s.peek() == singleQuote {
   170  			// B'[01]*'
   171  			s.pos++
   172  			s.scanBitString(lval, singleQuote)
   173  			return
   174  		}
   175  		s.scanIdent(lval)
   176  		return
   177  
   178  	case 'x', 'X':
   179  		// Hex literal?
   180  		if s.peek() == singleQuote {
   181  			// [xX]'[0-9a-fA-F]*'
   182  			s.pos++
   183  			s.scanHexString(lval, singleQuote)
   184  			return
   185  		}
   186  		s.scanIdent(lval)
   187  		return
   188  
   189  	case '.':
   190  		switch t := s.peek(); {
   191  		case t == '.': // ..
   192  			s.pos++
   193  			lval.id = DOT_DOT
   194  			return
   195  		case lex.IsDigit(t):
   196  			s.scanNumber(lval, ch)
   197  			return
   198  		}
   199  		return
   200  
   201  	case '!':
   202  		switch s.peek() {
   203  		case '=': // !=
   204  			s.pos++
   205  			lval.id = NOT_EQUALS
   206  			return
   207  		case '~': // !~
   208  			s.pos++
   209  			switch s.peek() {
   210  			case '*': // !~*
   211  				s.pos++
   212  				lval.id = NOT_REGIMATCH
   213  				return
   214  			}
   215  			lval.id = NOT_REGMATCH
   216  			return
   217  		}
   218  		return
   219  
   220  	case '?':
   221  		switch s.peek() {
   222  		case '?': // ??
   223  			s.pos++
   224  			lval.id = HELPTOKEN
   225  			return
   226  		case '|': // ?|
   227  			s.pos++
   228  			lval.id = JSON_SOME_EXISTS
   229  			return
   230  		case '&': // ?&
   231  			s.pos++
   232  			lval.id = JSON_ALL_EXISTS
   233  			return
   234  		}
   235  		return
   236  
   237  	case '<':
   238  		switch s.peek() {
   239  		case '<': // <<
   240  			s.pos++
   241  			switch s.peek() {
   242  			case '=': // <<=
   243  				s.pos++
   244  				lval.id = INET_CONTAINED_BY_OR_EQUALS
   245  				return
   246  			}
   247  			lval.id = LSHIFT
   248  			return
   249  		case '>': // <>
   250  			s.pos++
   251  			lval.id = NOT_EQUALS
   252  			return
   253  		case '=': // <=
   254  			s.pos++
   255  			lval.id = LESS_EQUALS
   256  			return
   257  		case '@': // <@
   258  			s.pos++
   259  			lval.id = CONTAINED_BY
   260  			return
   261  		}
   262  		return
   263  
   264  	case '>':
   265  		switch s.peek() {
   266  		case '>': // >>
   267  			s.pos++
   268  			switch s.peek() {
   269  			case '=': // >>=
   270  				s.pos++
   271  				lval.id = INET_CONTAINS_OR_EQUALS
   272  				return
   273  			}
   274  			lval.id = RSHIFT
   275  			return
   276  		case '=': // >=
   277  			s.pos++
   278  			lval.id = GREATER_EQUALS
   279  			return
   280  		}
   281  		return
   282  
   283  	case ':':
   284  		switch s.peek() {
   285  		case ':': // ::
   286  			if s.peekN(1) == ':' {
   287  				// :::
   288  				s.pos += 2
   289  				lval.id = TYPEANNOTATE
   290  				return
   291  			}
   292  			s.pos++
   293  			lval.id = TYPECAST
   294  			return
   295  		}
   296  		return
   297  
   298  	case '|':
   299  		switch s.peek() {
   300  		case '|': // ||
   301  			s.pos++
   302  			switch s.peek() {
   303  			case '/': // ||/
   304  				s.pos++
   305  				lval.id = CBRT
   306  				return
   307  			}
   308  			lval.id = CONCAT
   309  			return
   310  		case '/': // |/
   311  			s.pos++
   312  			lval.id = SQRT
   313  			return
   314  		}
   315  		return
   316  
   317  	case '/':
   318  		switch s.peek() {
   319  		case '/': // //
   320  			s.pos++
   321  			lval.id = FLOORDIV
   322  			return
   323  		}
   324  		return
   325  
   326  	case '~':
   327  		switch s.peek() {
   328  		case '*': // ~*
   329  			s.pos++
   330  			lval.id = REGIMATCH
   331  			return
   332  		}
   333  		return
   334  
   335  	case '@':
   336  		switch s.peek() {
   337  		case '>': // @>
   338  			s.pos++
   339  			lval.id = CONTAINS
   340  			return
   341  		}
   342  		return
   343  
   344  	case '&':
   345  		switch s.peek() {
   346  		case '&': // &&
   347  			s.pos++
   348  			lval.id = AND_AND
   349  			return
   350  		}
   351  		return
   352  
   353  	case '-':
   354  		switch s.peek() {
   355  		case '>': // ->
   356  			if s.peekN(1) == '>' {
   357  				// ->>
   358  				s.pos += 2
   359  				lval.id = FETCHTEXT
   360  				return
   361  			}
   362  			s.pos++
   363  			lval.id = FETCHVAL
   364  			return
   365  		}
   366  		return
   367  
   368  	case '#':
   369  		switch s.peek() {
   370  		case '>': // #>
   371  			if s.peekN(1) == '>' {
   372  				// #>>
   373  				s.pos += 2
   374  				lval.id = FETCHTEXT_PATH
   375  				return
   376  			}
   377  			s.pos++
   378  			lval.id = FETCHVAL_PATH
   379  			return
   380  		case '-': // #-
   381  			s.pos++
   382  			lval.id = REMOVE_PATH
   383  			return
   384  		}
   385  		return
   386  
   387  	default:
   388  		if lex.IsDigit(ch) {
   389  			s.scanNumber(lval, ch)
   390  			return
   391  		}
   392  		if lex.IsIdentStart(ch) {
   393  			s.scanIdent(lval)
   394  			return
   395  		}
   396  	}
   397  
   398  	// Everything else is a single-character token, for which lval was already
   399  	// initialized above.
   400  }
   401  
   402  func (s *scanner) peek() int {
   403  	if s.pos >= len(s.in) {
   404  		return eof
   405  	}
   406  	return int(s.in[s.pos])
   407  }
   408  
   409  func (s *scanner) peekN(n int) int {
   410  	pos := s.pos + n
   411  	if pos >= len(s.in) {
   412  		return eof
   413  	}
   414  	return int(s.in[pos])
   415  }
   416  
   417  func (s *scanner) next() int {
   418  	ch := s.peek()
   419  	if ch != eof {
   420  		s.pos++
   421  	}
   422  	return ch
   423  }
   424  
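// skipWhitespace skips spaces, tabs, carriage returns, form feeds and newlines, as
// well as SQL comments when allowComments is set. newline reports whether at least
// one newline was skipped; ok is false if an unterminated comment was encountered.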
   425  func (s *scanner) skipWhitespace(lval *sqlSymType, allowComments bool) (newline, ok bool) {
   426  	newline = false
   427  	for {
   428  		ch := s.peek()
   429  		if ch == '\n' {
   430  			s.pos++
   431  			newline = true
   432  			continue
   433  		}
   434  		if ch == ' ' || ch == '\t' || ch == '\r' || ch == '\f' {
   435  			s.pos++
   436  			continue
   437  		}
   438  		if allowComments {
   439  			if present, cok := s.scanComment(lval); !cok {
   440  				return false, false
   441  			} else if present {
   442  				continue
   443  			}
   444  		}
   445  		break
   446  	}
   447  	return newline, true
   448  }
   449  
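// scanComment skips over a -- line comment or a /* ... */ block comment, if one
// starts at the current position. Block comments nest, so /* a /* b */ c */ is
// consumed as a single comment. present reports whether a comment was consumed;
// ok is false only for an unterminated block comment.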
   450  func (s *scanner) scanComment(lval *sqlSymType) (present, ok bool) {
   451  	start := s.pos
   452  	ch := s.peek()
   453  
   454  	if ch == '/' {
   455  		s.pos++
   456  		if s.peek() != '*' {
   457  			s.pos--
   458  			return false, true
   459  		}
   460  		s.pos++
   461  		depth := 1
   462  		for {
   463  			switch s.next() {
   464  			case '*':
   465  				if s.peek() == '/' {
   466  					s.pos++
   467  					depth--
   468  					if depth == 0 {
   469  						return true, true
   470  					}
   471  					continue
   472  				}
   473  
   474  			case '/':
   475  				if s.peek() == '*' {
   476  					s.pos++
   477  					depth++
   478  					continue
   479  				}
   480  
   481  			case eof:
   482  				lval.id = ERROR
   483  				lval.pos = int32(start)
   484  				lval.str = "unterminated comment"
   485  				return false, false
   486  			}
   487  		}
   488  	}
   489  
   490  	if ch == '-' {
   491  		s.pos++
   492  		if s.peek() != '-' {
   493  			s.pos--
   494  			return false, true
   495  		}
   496  		for {
   497  			switch s.next() {
   498  			case eof, '\n':
   499  				return true, true
   500  			}
   501  		}
   502  	}
   503  
   504  	return false, true
   505  }
   506  
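// scanIdent scans an identifier or keyword, backing up one position to include the
// character already consumed by the caller. lval.str receives the lowercased (and,
// if necessary, unicode-normalized) text, and lval.id the keyword id or IDENT.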
   507  func (s *scanner) scanIdent(lval *sqlSymType) {
   508  	s.pos--
   509  	start := s.pos
   510  	isASCII := true
   511  	isLower := true
   512  
   513  	// Consume the input character by character, stopping after the last legal
   514  	// identifier character. By the end of this function, we need to
   515  	// lowercase and unicode normalize this identifier, which is expensive if
   516  	// there are actual unicode characters in it. If not, it's quite cheap - and
   517  	// if it's lowercase already, there's no work to do. Therefore, we keep track
   518  	// of whether the string is only ASCII or only ASCII lowercase for later.
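	// For example, "foo" needs no work at all, "FooBar" only needs ASCII
	// lowercasing, and an identifier containing non-ASCII runes must go through
	// full normalization.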
   519  	for {
   520  		ch := s.peek()
   522  
   523  		if ch >= utf8.RuneSelf {
   524  			isASCII = false
   525  		} else if ch >= 'A' && ch <= 'Z' {
   526  			isLower = false
   527  		}
   528  
   529  		if !lex.IsIdentMiddle(ch) {
   530  			break
   531  		}
   532  
   533  		s.pos++
   534  	}
   536  
   537  	if isLower {
   538  		// Already lowercased - nothing to do.
   539  		lval.str = s.in[start:s.pos]
   540  	} else if isASCII {
   541  		// We know that the identifier we've seen so far is ASCII, so we don't need
   542  		// to unicode normalize. Instead, just lowercase as normal.
   543  		b := s.allocBytes(s.pos - start)
   544  		_ = b[s.pos-start-1] // For bounds check elimination.
   545  		for i, c := range s.in[start:s.pos] {
   546  			if c >= 'A' && c <= 'Z' {
   547  				c += 'a' - 'A'
   548  			}
   549  			b[i] = byte(c)
   550  		}
   551  		lval.str = *(*string)(unsafe.Pointer(&b))
   552  	} else {
   553  		// The string has unicode in it. No choice but to run Normalize.
   554  		lval.str = lex.NormalizeName(s.in[start:s.pos])
   555  	}
   556  
   557  	isExperimental := false
   558  	kw := lval.str
   559  	switch {
   560  	case strings.HasPrefix(lval.str, "experimental_"):
   561  		kw = lval.str[13:]
   562  		isExperimental = true
   563  	case strings.HasPrefix(lval.str, "testing_"):
   564  		kw = lval.str[8:]
   565  		isExperimental = true
   566  	}
   567  	lval.id = lex.GetKeywordID(kw)
   568  	if lval.id != lex.IDENT {
   569  		if isExperimental {
   570  			if _, ok := lex.AllowedExperimental[kw]; !ok {
   571  				// If the parsed token is not in the whitelisted set of keywords,
   572  				// then it might have been intended to be parsed as something else.
   573  				// In that case, re-tokenize the original string.
   574  				lval.id = lex.GetKeywordID(lval.str)
   575  			} else {
   576  				// It is a whitelisted keyword, so remember the shortened
   577  				// keyword for further processing.
   578  				lval.str = kw
   579  			}
   580  		}
   581  	} else {
   582  		// If the word after experimental_ or testing_ is an identifier,
   583  		// then we might have classified it incorrectly after removing the
   584  		// experimental_/testing_ prefix.
   585  		lval.id = lex.GetKeywordID(lval.str)
   586  	}
   587  }
   588  
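// scanNumber scans integer (ICONST), floating-point (FCONST) and hexadecimal (0x...)
// numeric literals. ch is the first character of the literal, which the caller has
// already consumed.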
   589  func (s *scanner) scanNumber(lval *sqlSymType, ch int) {
   590  	start := s.pos - 1
   591  	isHex := false
   592  	hasDecimal := ch == '.'
   593  	hasExponent := false
   594  
   595  	for {
   596  		ch := s.peek()
   597  		if (isHex && lex.IsHexDigit(ch)) || lex.IsDigit(ch) {
   598  			s.pos++
   599  			continue
   600  		}
   601  		if ch == 'x' || ch == 'X' {
   602  			if isHex || s.in[start] != '0' || s.pos != start+1 {
   603  				lval.id = ERROR
   604  				lval.str = errInvalidHexNumeric
   605  				return
   606  			}
   607  			s.pos++
   608  			isHex = true
   609  			continue
   610  		}
   611  		if isHex {
   612  			break
   613  		}
   614  		if ch == '.' {
   615  			if hasDecimal || hasExponent {
   616  				break
   617  			}
   618  			s.pos++
   619  			if s.peek() == '.' {
   620  				// Found ".." while scanning a number: back up to the end of the
   621  				// integer.
   622  				s.pos--
   623  				break
   624  			}
   625  			hasDecimal = true
   626  			continue
   627  		}
   628  		if ch == 'e' || ch == 'E' {
   629  			if hasExponent {
   630  				break
   631  			}
   632  			hasExponent = true
   633  			s.pos++
   634  			ch = s.peek()
   635  			if ch == '-' || ch == '+' {
   636  				s.pos++
   637  			}
   638  			ch = s.peek()
   639  			if !lex.IsDigit(ch) {
   640  				lval.id = ERROR
   641  				lval.str = "invalid floating point literal"
   642  				return
   643  			}
   644  			continue
   645  		}
   646  		break
   647  	}
   648  
   649  	lval.str = s.in[start:s.pos]
   650  	if hasDecimal || hasExponent {
   651  		lval.id = FCONST
   652  		floatConst := constant.MakeFromLiteral(lval.str, token.FLOAT, 0)
   653  		if floatConst.Kind() == constant.Unknown {
   654  			lval.id = ERROR
   655  			lval.str = fmt.Sprintf("could not make constant float from literal %q", lval.str)
   656  			return
   657  		}
   658  		lval.union.val = tree.NewNumVal(floatConst, lval.str, false /* negative */)
   659  	} else {
   660  		if isHex && s.pos == start+2 {
   661  			lval.id = ERROR
   662  			lval.str = errInvalidHexNumeric
   663  			return
   664  		}
   665  
   666  		// Strip off leading zeros from non-hex (decimal) literals so that
   667  		// constant.MakeFromLiteral doesn't inappropriately interpret the
   668  		// string as an octal literal. Note: we can't use strings.TrimLeft
   669  		// here, because it will truncate '0' to ''.
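		// For example, "077" becomes "77" and is parsed as decimal 77 rather than
		// octal 63, while a lone "0" is left untouched.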
   670  		if !isHex {
   671  			for len(lval.str) > 1 && lval.str[0] == '0' {
   672  				lval.str = lval.str[1:]
   673  			}
   674  		}
   675  
   676  		lval.id = ICONST
   677  		intConst := constant.MakeFromLiteral(lval.str, token.INT, 0)
   678  		if intConst.Kind() == constant.Unknown {
   679  			lval.id = ERROR
   680  			lval.str = fmt.Sprintf("could not make constant int from literal %q", lval.str)
   681  			return
   682  		}
   683  		lval.union.val = tree.NewNumVal(intConst, lval.str, false /* negative */)
   684  	}
   685  }
   686  
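// scanPlaceholder scans a numbered placeholder such as $1 or $23; the leading '$'
// has already been consumed by scan.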
   687  func (s *scanner) scanPlaceholder(lval *sqlSymType) {
   688  	start := s.pos
   689  	for lex.IsDigit(s.peek()) {
   690  		s.pos++
   691  	}
   692  	lval.str = s.in[start:s.pos]
   693  
   694  	placeholder, err := tree.NewPlaceholder(lval.str)
   695  	if err != nil {
   696  		lval.id = ERROR
   697  		lval.str = err.Error()
   698  		return
   699  	}
   700  	lval.id = PLACEHOLDER
   701  	lval.union.val = placeholder
   702  }
   703  
   704  // scanHexString scans the content inside x'....'.
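// Each pair of hex digits produces one byte of the resulting BCONST, so an odd
// number of digits is an error; e.g. x'cafe' yields the two bytes 0xca 0xfe.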
   705  func (s *scanner) scanHexString(lval *sqlSymType, ch int) bool {
   706  	buf := s.buffer()
   707  
   708  	var curbyte byte
   709  	bytep := 0
   710  	const errInvalidBytesLiteral = "invalid hexadecimal bytes literal"
   711  outer:
   712  	for {
   713  		b := s.next()
   714  		switch b {
   715  		case ch:
   716  			newline, ok := s.skipWhitespace(lval, false)
   717  			if !ok {
   718  				return false
   719  			}
   720  			// SQL allows joining adjacent strings separated by whitespace
   721  			// as long as that whitespace contains at least one
   722  			// newline. Kind of strange to require the newline, but that
   723  			// is the standard.
   724  			if s.peek() == ch && newline {
   725  				s.pos++
   726  				continue
   727  			}
   728  			break outer
   729  
   730  		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
   731  			curbyte = (curbyte << 4) | byte(b-'0')
   732  		case 'a', 'b', 'c', 'd', 'e', 'f':
   733  			curbyte = (curbyte << 4) | byte(b-'a'+10)
   734  		case 'A', 'B', 'C', 'D', 'E', 'F':
   735  			curbyte = (curbyte << 4) | byte(b-'A'+10)
   736  		default:
   737  			lval.id = ERROR
   738  			lval.str = errInvalidBytesLiteral
   739  			return false
   740  		}
   741  		bytep++
   742  
   743  		if bytep > 1 {
   744  			buf = append(buf, curbyte)
   745  			bytep = 0
   746  			curbyte = 0
   747  		}
   748  	}
   749  
   750  	if bytep != 0 {
   751  		lval.id = ERROR
   752  		lval.str = errInvalidBytesLiteral
   753  		return false
   754  	}
   755  
   756  	lval.id = BCONST
   757  	lval.str = s.finishString(buf)
   758  	return true
   759  }
   760  
   761  // scanBitString scans the content inside B'....'.
   762  func (s *scanner) scanBitString(lval *sqlSymType, ch int) bool {
   763  	buf := s.buffer()
   764  outer:
   765  	for {
   766  		b := s.next()
   767  		switch b {
   768  		case ch:
   769  			newline, ok := s.skipWhitespace(lval, false)
   770  			if !ok {
   771  				return false
   772  			}
   773  			// SQL allows joining adjacent strings separated by whitespace
   774  			// as long as that whitespace contains at least one
   775  			// newline. Kind of strange to require the newline, but that
   776  			// is the standard.
   777  			if s.peek() == ch && newline {
   778  				s.pos++
   779  				continue
   780  			}
   781  			break outer
   782  
   783  		case '0', '1':
   784  			buf = append(buf, byte(b))
   785  		default:
   786  			lval.id = ERROR
   787  			lval.str = fmt.Sprintf(`"%c" is not a valid binary digit`, rune(b))
   788  			return false
   789  		}
   790  	}
   791  
   792  	lval.id = BITCONST
   793  	lval.str = s.finishString(buf)
   794  	return true
   795  }
   796  
   797  // scanString scans the content inside '...'. This is used for simple
   798  // string literals '...' but also e'....' and b'...'. For x'...', see
   799  // scanHexString().
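//
// Within the string, a doubled quote character is collapsed into a single quote;
// when allowEscapes is set, backslash escapes such as \n, \x41 or \101 are
// interpreted as well.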
   800  func (s *scanner) scanString(lval *sqlSymType, ch int, allowEscapes, requireUTF8 bool) bool {
   801  	buf := s.buffer()
   802  	var runeTmp [utf8.UTFMax]byte
   803  	start := s.pos
   804  
   805  outer:
   806  	for {
   807  		switch s.next() {
   808  		case ch:
   809  			buf = append(buf, s.in[start:s.pos-1]...)
   810  			if s.peek() == ch {
   811  				// A doubled quote character is collapsed into a single quote that is
   812  				// part of the string.
   813  				start = s.pos
   814  				s.pos++
   815  				continue
   816  			}
   817  
   818  			newline, ok := s.skipWhitespace(lval, false)
   819  			if !ok {
   820  				return false
   821  			}
   822  			// SQL allows joining adjacent strings separated by whitespace
   823  			// as long as that whitespace contains at least one
   824  			// newline. Kind of strange to require the newline, but that
   825  			// is the standard.
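			// For example, 'foo' followed by a newline and 'bar' lexes as the single
			// string "foobar", while 'foo' 'bar' on one line stays two separate tokens.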
   826  			if s.peek() == ch && newline {
   827  				s.pos++
   828  				start = s.pos
   829  				continue
   830  			}
   831  			break outer
   832  
   833  		case '\\':
   834  			t := s.peek()
   835  
   836  			if allowEscapes {
   837  				buf = append(buf, s.in[start:s.pos-1]...)
   838  				if t == ch {
   839  					start = s.pos
   840  					s.pos++
   841  					continue
   842  				}
   843  
   844  				switch t {
   845  				case 'a', 'b', 'f', 'n', 'r', 't', 'v', 'x', 'X', 'u', 'U', '\\',
   846  					'0', '1', '2', '3', '4', '5', '6', '7':
   847  					var tmp string
   848  					if t == 'X' && len(s.in[s.pos:]) >= 3 {
   849  						// UnquoteChar doesn't handle 'X' so we create a temporary string
   850  						// for it to parse.
   851  						tmp = "\\x" + s.in[s.pos+1:s.pos+3]
   852  					} else {
   853  						tmp = s.in[s.pos-1:]
   854  					}
   855  					v, multibyte, tail, err := strconv.UnquoteChar(tmp, byte(ch))
   856  					if err != nil {
   857  						lval.id = ERROR
   858  						lval.str = err.Error()
   859  						return false
   860  					}
   861  					if v < utf8.RuneSelf || !multibyte {
   862  						buf = append(buf, byte(v))
   863  					} else {
   864  						n := utf8.EncodeRune(runeTmp[:], v)
   865  						buf = append(buf, runeTmp[:n]...)
   866  					}
   867  					s.pos += len(tmp) - len(tail) - 1
   868  					start = s.pos
   869  					continue
   870  				}
   871  
   872  				// If we end up here, it's a redundant escape - simply drop the
   873  				// backslash. For example, e'\"' is equivalent to e'"', and
   874  				// e'\d\b' to e'd\b'. This is what Postgres does:
   875  				// http://www.postgresql.org/docs/9.4/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE
   876  				start = s.pos
   877  			}
   878  
   879  		case eof:
   880  			lval.id = ERROR
   881  			lval.str = errUnterminated
   882  			return false
   883  		}
   884  	}
   885  
   886  	if requireUTF8 && !utf8.Valid(buf) {
   887  		lval.id = ERROR
   888  		lval.str = errInvalidUTF8
   889  		return false
   890  	}
   891  
   892  	lval.str = s.finishString(buf)
   893  	return true
   894  }
   895  
   896  // scanDollarQuotedString scans for so-called dollar-quoted strings, which start and end
   897  // with either $$ or $tag$, where tag is some arbitrary string, e.g. $$a string$$ or $escaped$a string$escaped$.
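// Both $$hello$$ and $tag$hello$tag$ lex as the single string "hello".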
   898  func (s *scanner) scanDollarQuotedString(lval *sqlSymType) bool {
   899  	buf := s.buffer()
   900  	start := s.pos
   901  
   902  	foundStartTag := false
   903  	possibleEndTag := false
   904  	startTagIndex := -1
   905  	var startTag string
   906  
   907  outer:
   908  	for {
   909  		ch := s.peek()
   910  		switch ch {
   911  		case '$':
   912  			s.pos++
   913  			if foundStartTag {
   914  				if possibleEndTag {
   915  					if len(startTag) == startTagIndex {
   916  						// Found end tag.
   917  						buf = append(buf, s.in[start+len(startTag)+1:s.pos-len(startTag)-2]...)
   918  						break outer
   919  					} else {
   920  						// Was not the end tag but the current $ might be the start of the end tag we are looking for, so
   921  						// just reset the startTagIndex.
   922  						startTagIndex = 0
   923  					}
   924  				} else {
   925  					possibleEndTag = true
   926  					startTagIndex = 0
   927  				}
   928  			} else {
   929  				startTag = s.in[start : s.pos-1]
   930  				foundStartTag = true
   931  			}
   932  
   933  		case eof:
   934  			if foundStartTag {
   935  				// A start tag was found, so we expect a matching end tag before EOF; otherwise it is an error.
   936  				lval.id = ERROR
   937  				lval.str = errUnterminated
   938  			} else {
   939  				// This is not a dollar-quoted string, reset the pos back to the start.
   940  				s.pos = start
   941  			}
   942  			return false
   943  
   944  		default:
   945  			// If we haven't found a start tag yet, check whether the current character is valid for a tag.
   946  			if !foundStartTag && !lex.IsIdentStart(ch) {
   947  				return false
   948  			}
   949  			s.pos++
   950  			if possibleEndTag {
   951  				// Check whether this could be the end tag.
   952  				if startTagIndex >= len(startTag) || ch != int(startTag[startTagIndex]) {
   953  					// This is not the end tag we are looking for.
   954  					possibleEndTag = false
   955  					startTagIndex = -1
   956  				} else {
   957  					startTagIndex++
   958  				}
   959  			}
   960  		}
   961  	}
   962  
   963  	if !utf8.Valid(buf) {
   964  		lval.id = ERROR
   965  		lval.str = errInvalidUTF8
   966  		return false
   967  	}
   968  
   969  	lval.str = s.finishString(buf)
   970  	return true
   971  }
   972  
   973  // SplitFirstStatement returns the length of the prefix of the string up to and
   974  // including the first semicolon that separates statements. If there is no
   975  // semicolon, returns ok=false.
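// For example, SplitFirstStatement("SELECT 1; SELECT 2") returns 9, the length of
// "SELECT 1;".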
   976  func SplitFirstStatement(sql string) (pos int, ok bool) {
   977  	s := makeScanner(sql)
   978  	var lval sqlSymType
   979  	for {
   980  		s.scan(&lval)
   981  		switch lval.id {
   982  		case 0, ERROR:
   983  			return 0, false
   984  		case ';':
   985  			return s.pos, true
   986  		}
   987  	}
   988  }
   989  
   990  // Tokens decomposes the input into lexical tokens.
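// For example, Tokens("SELECT 1") returns two tokens: the keyword "select" and the
// numeric constant "1".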
   991  func Tokens(sql string) (tokens []TokenString, ok bool) {
   992  	s := makeScanner(sql)
   993  	for {
   994  		var lval sqlSymType
   995  		s.scan(&lval)
   996  		if lval.id == ERROR {
   997  			return nil, false
   998  		}
   999  		if lval.id == 0 {
  1000  			break
  1001  		}
  1002  		tokens = append(tokens, TokenString{TokenID: lval.id, Str: lval.str})
  1003  	}
  1004  	return tokens, true
  1005  }
  1006  
  1007  // TokenString is the unit value returned by Tokens.
  1008  type TokenString struct {
  1009  	TokenID int32
  1010  	Str     string
  1011  }
  1012  
  1013  // LastLexicalToken returns the last lexical token. If the string has no lexical
  1014  // tokens, returns 0 and ok=false.
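// Trailing whitespace and comments are skipped, so for "SELECT 1 -- c" the last
// token is the numeric constant "1".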
  1015  func LastLexicalToken(sql string) (lastTok int, ok bool) {
  1016  	s := makeScanner(sql)
  1017  	var lval sqlSymType
  1018  	for {
  1019  		last := lval.id
  1020  		s.scan(&lval)
  1021  		if lval.id == 0 {
  1022  			return int(last), last != 0
  1023  		}
  1024  	}
  1025  }
  1026  
  1027  // EndsInSemicolon returns true if the last lexical token is a semicolon.
  1028  func EndsInSemicolon(sql string) bool {
  1029  	lastTok, ok := LastLexicalToken(sql)
  1030  	return ok && lastTok == ';'
  1031  }