github.com/waldiirawan/apm-agent-go/v2@v2.2.2/sqlutil/scanner.go (about)

     1  // Licensed to Elasticsearch B.V. under one or more contributor
     2  // license agreements. See the NOTICE file distributed with
     3  // this work for additional information regarding copyright
     4  // ownership. Elasticsearch B.V. licenses this file to you under
     5  // the Apache License, Version 2.0 (the "License"); you may
     6  // not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing,
    12  // software distributed under the License is distributed on an
    13  // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    14  // KIND, either express or implied.  See the License for the
    15  // specific language governing permissions and limitations
    16  // under the License.
    17  
    18  package sqlutil // import "github.com/waldiirawan/apm-agent-go/v2/sqlutil"
    19  
    20  import (
    21  	"strings"
    22  	"unicode"
    23  	"unicode/utf8"
    24  )
    25  
    26  // Scanner is the struct used to generate SQL
    27  // tokens for the parser.
    28  type Scanner struct {
    29  	input string
    30  	start int // text start pos in bytes
    31  	end   int // text end pos in bytes
    32  	pos   int // read pos in bytes
    33  	tok   Token
    34  }
    35  
    36  // NewScanner creates a new Scanner for sql.
    37  func NewScanner(sql string) *Scanner {
    38  	return &Scanner{input: sql}
    39  }
    40  
    41  // Token returns the most recently scanned token.
    42  func (s *Scanner) Token() Token {
    43  	return s.tok
    44  }
    45  
    46  // Text returns the portion of the string that relates to
    47  // the most recently scanned token.
    48  func (s *Scanner) Text() string {
    49  	return s.input[s.start:s.end]
    50  }
    51  
    52  // Scan scans for the next token and returns true if one was
    53  // found, false if the end of the input stream was reached.
    54  // When Scan returns true, the token type can be obtained by
    55  // calling the Token() method, and the text can be obtained
    56  // by calling the Text() method.
    57  func (s *Scanner) Scan() bool {
    58  	s.tok = s.scan()
    59  	return s.tok != eof
    60  }
    61  
    62  func (s *Scanner) scan() Token {
    63  	r, ok := s.next()
    64  	if !ok {
    65  		return eof
    66  	}
    67  	for unicode.IsSpace(r) {
    68  		if r, ok = s.next(); !ok {
    69  			return eof
    70  		}
    71  	}
    72  	s.start = s.pos - utf8.RuneLen(r)
    73  
    74  	if r == '_' || unicode.IsLetter(r) {
    75  		return s.scanKeywordOrIdentifier(r != '_')
    76  	} else if unicode.IsDigit(r) {
    77  		return s.scanNumericLiteral()
    78  	}
    79  
    80  	switch r {
    81  	case '\'':
    82  		// Standard string literal.
    83  		return s.scanStringLiteral()
    84  	case '"':
    85  		// Standard double-quoted identifier.
    86  		//
    87  		// NOTE(axw) MySQL will treat " as a
    88  		// string literal delimiter by default,
    89  		// but we assume standard SQL and treat
    90  		// it as a identifier delimiter.
    91  		return s.scanQuotedIdentifier('"')
    92  	case '[':
    93  		// T-SQL bracket-quoted identifier.
    94  		return s.scanQuotedIdentifier(']')
    95  	case '`':
    96  		// MySQL-style backtick-quoted identifier.
    97  		return s.scanQuotedIdentifier('`')
    98  	case '(':
    99  		return LPAREN
   100  	case ')':
   101  		return RPAREN
   102  	case '-':
   103  		if next, ok := s.peek(); ok && next == '-' {
   104  			// -- comment
   105  			s.next()
   106  			return s.scanSimpleComment()
   107  		}
   108  		return OTHER
   109  	case '/':
   110  		if next, ok := s.peek(); ok {
   111  			switch next {
   112  			case '*':
   113  				// /* comment */
   114  				s.next()
   115  				return s.scanBracketedComment()
   116  			case '/':
   117  				// // comment
   118  				s.next()
   119  				return s.scanSimpleComment()
   120  			}
   121  		}
   122  		return OTHER
   123  	case '.':
   124  		return PERIOD
   125  	case '$':
   126  		next, ok := s.peek()
   127  		if !ok {
   128  			break
   129  		}
   130  		if unicode.IsDigit(next) {
   131  			// This is a variable like "$1".
   132  			for {
   133  				if next, ok := s.peek(); !ok || !unicode.IsDigit(next) {
   134  					break
   135  				}
   136  				s.next()
   137  			}
   138  			return OTHER
   139  		} else if next == '$' || next == '_' || unicode.IsLetter(next) {
   140  			// PostgreSQL supports dollar-quoted string literal syntax,
   141  			// like $foo$...$foo$. The tag (foo in this case) is optional,
   142  			// and if present follows identifier rules.
   143  			for {
   144  				r, ok := s.next()
   145  				if !ok {
   146  					// Unknown token starting with $ until
   147  					// EOF, just ignore it.
   148  					return OTHER
   149  				}
   150  				switch {
   151  				case r == '$':
   152  					// This marks the end of the initial $foo$.
   153  					tag := s.Text()
   154  					if i := strings.Index(s.input[s.pos:], tag); i >= 0 {
   155  						s.end += i + len(tag)
   156  						s.pos += i + len(tag)
   157  						return STRING
   158  					}
   159  				case unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_':
   160  					// Identifier rune, consume.
   161  				case unicode.IsSpace(r):
   162  					// Unknown token starting with $,
   163  					// consume runes until space.
   164  					s.end -= utf8.RuneLen(r)
   165  					return OTHER
   166  				}
   167  			}
   168  		}
   169  		return OTHER
   170  	}
   171  	return OTHER
   172  }
   173  
   174  func (s *Scanner) scanKeywordOrIdentifier(maybeKeyword bool) Token {
   175  loop:
   176  	for {
   177  		r, ok := s.peek()
   178  		if !ok {
   179  			break loop
   180  		}
   181  		switch {
   182  		case unicode.IsLetter(r):
   183  		case unicode.IsDigit(r) || r == '_' || r == '$':
   184  			maybeKeyword = false
   185  		default:
   186  			break loop
   187  		}
   188  		s.next()
   189  	}
   190  	if !maybeKeyword {
   191  		return IDENT
   192  	}
   193  	text := s.Text()
   194  	if len(text) >= len(keywords) {
   195  		return IDENT
   196  	}
   197  	for _, token := range keywords[len(text)] {
   198  		if strings.EqualFold(text, token.String()) {
   199  			return token
   200  		}
   201  	}
   202  	return IDENT
   203  }
   204  
   205  func (s *Scanner) scanQuotedIdentifier(delim rune) Token {
   206  loop:
   207  	for {
   208  		r, ok := s.next()
   209  		if !ok {
   210  			return eof
   211  		}
   212  		if r == delim {
   213  			if delim == '"' {
   214  				if r, ok := s.peek(); ok && r == delim {
   215  					// Skip escaped double quotes,
   216  					// e.g. "He said ""great""".
   217  					s.next()
   218  					continue loop
   219  				}
   220  			}
   221  			break
   222  		}
   223  	}
   224  	// Remove quotes from identifier.
   225  	s.start++
   226  	s.end--
   227  	return IDENT
   228  }
   229  
   230  func (s *Scanner) scanNumericLiteral() Token {
   231  	var havePeriod bool
   232  	var haveExponent bool
   233  	for {
   234  		r, ok := s.peek()
   235  		if !ok {
   236  			return NUMBER
   237  		}
   238  		if unicode.IsDigit(r) {
   239  			s.next()
   240  			continue
   241  		}
   242  		switch r {
   243  		case '.':
   244  			if havePeriod {
   245  				return NUMBER
   246  			}
   247  			s.next()
   248  			havePeriod = true
   249  		case 'e', 'E':
   250  			if haveExponent {
   251  				return NUMBER
   252  			}
   253  			s.next()
   254  			haveExponent = true
   255  			if r, ok := s.peek(); ok && (r == '+' || r == '-') {
   256  				s.next()
   257  			}
   258  		default:
   259  			return NUMBER
   260  		}
   261  	}
   262  }
   263  
   264  func (s *Scanner) scanStringLiteral() Token {
   265  	const delim = '\''
   266  	for {
   267  		r, ok := s.next()
   268  		if !ok {
   269  			return eof
   270  		}
   271  		if r == '\\' {
   272  			// Skip escaped character, e.g. 'what\'s up?'
   273  			s.next()
   274  			continue
   275  		}
   276  		if r != delim {
   277  			continue
   278  		}
   279  		if r, ok := s.peek(); !ok || r != delim {
   280  			return STRING
   281  		}
   282  		// Two ' characters next to each other
   283  		// are collapsed in a string literal,
   284  		// rather than escaping the string. We
   285  		// don't care about string values, so
   286  		// we don't collapse.
   287  		s.next()
   288  	}
   289  }
   290  
   291  func (s *Scanner) scanSimpleComment() Token {
   292  	for {
   293  		if r, ok := s.next(); !ok || r == '\n' {
   294  			return COMMENT
   295  		}
   296  	}
   297  }
   298  
   299  func (s *Scanner) scanBracketedComment() Token {
   300  	nesting := 1
   301  	for {
   302  		r, ok := s.next()
   303  		if !ok {
   304  			return eof
   305  		}
   306  		switch r {
   307  		case '/':
   308  			r, ok := s.peek()
   309  			if ok && r == '*' {
   310  				s.next()
   311  				nesting++
   312  			}
   313  		case '*':
   314  			r, ok := s.peek()
   315  			if ok && r == '/' {
   316  				s.next()
   317  				nesting--
   318  				if nesting == 0 {
   319  					return COMMENT
   320  				}
   321  			}
   322  		}
   323  	}
   324  }
   325  
   326  // next returns the next rune if there is one, and advances
   327  // the scanner position, or returns utf8.RuneError if there
   328  // is no valid next rune. The bool result indicates whether
   329  // a valid rune is returned.
   330  func (s *Scanner) next() (rune, bool) {
   331  	r, rlen := s.peekLen()
   332  	if r != utf8.RuneError {
   333  		s.pos += rlen
   334  		s.end = s.pos
   335  		return r, true
   336  	}
   337  	return r, false
   338  }
   339  
   340  // peek returns the next rune if there is one, or
   341  // utf8.RuneError if not. The bool result indicates
   342  // whether a valid rune is returned.
   343  func (s *Scanner) peek() (rune, bool) {
   344  	r, _ := s.peekLen()
   345  	if r == utf8.RuneError {
   346  		return utf8.RuneError, false
   347  	}
   348  	return r, true
   349  }
   350  
   351  // peekLen returns the next rune (if there is one)
   352  // and its length. If there is no next valid rune,
   353  // utf8.RuneError and a length of -1 are returned.
   354  func (s *Scanner) peekLen() (rune, int) {
   355  	if s.pos >= len(s.input) {
   356  		return utf8.RuneError, -1
   357  	}
   358  	return utf8.DecodeRuneInString(s.input[s.pos:])
   359  }