github.com/XiaoMi/Gaea@v1.2.5/parser/yy_parser.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package parser
    15  
    16  import (
    17  	"fmt"
    18  	"math"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"unicode"
    23  
    24  	"github.com/pingcap/errors"
    25  
    26  	"github.com/XiaoMi/Gaea/mysql"
    27  	"github.com/XiaoMi/Gaea/parser/ast"
    28  	pformat "github.com/XiaoMi/Gaea/parser/format"
    29  	"github.com/XiaoMi/Gaea/parser/terror"
    30  )
    31  
    32  const (
    33  	codeErrParse  = terror.ErrCode(mysql.ErrParse)
    34  	codeErrSyntax = terror.ErrCode(mysql.ErrSyntax)
    35  )
    36  
    37  var (
    38  	// ErrSyntax returns for sql syntax error.
    39  	ErrSyntax = terror.ClassParser.New(codeErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax])
    40  	// ErrParse returns for sql parse error.
    41  	ErrParse = terror.ClassParser.New(codeErrParse, mysql.MySQLErrName[mysql.ErrParse])
    42  	// SpecFieldPattern special result field pattern
    43  	SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`)
    44  	specCodePattern  = regexp.MustCompile(`\/\*!(M?[0-9]{5,6})?([^*]|\*+[^*/])*\*+\/`)
    45  	specCodeStart    = regexp.MustCompile(`^\/\*!(M?[0-9]{5,6})?[ \t]*`)
    46  	specCodeEnd      = regexp.MustCompile(`[ \t]*\*\/$`)
    47  )
    48  
    49  func init() {
    50  	parserMySQLErrCodes := map[terror.ErrCode]uint16{
    51  		codeErrSyntax: mysql.ErrSyntax,
    52  		codeErrParse:  mysql.ErrParse,
    53  	}
    54  	terror.ErrClassToMySQLCodes[terror.ClassParser] = parserMySQLErrCodes
    55  }
    56  
    57  // TrimComment trim comment for special comment code of MySQL.
    58  func TrimComment(txt string) string {
    59  	txt = specCodeStart.ReplaceAllString(txt, "")
    60  	return specCodeEnd.ReplaceAllString(txt, "")
    61  }
    62  
    63  // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function.
    64  type Parser struct {
    65  	charset   string
    66  	collation string
    67  	result    []ast.StmtNode
    68  	src       string
    69  	lexer     Scanner
    70  
    71  	// the following fields are used by yyParse to reduce allocation.
    72  	cache  []yySymType
    73  	yylval yySymType
    74  	yyVAL  yySymType
    75  }
    76  
    77  type stmtTexter interface {
    78  	stmtText() string
    79  }
    80  
    81  // New returns a Parser object.
    82  func New() *Parser {
    83  	if ast.NewValueExpr == nil ||
    84  		ast.NewParamMarkerExpr == nil ||
    85  		ast.NewHexLiteral == nil ||
    86  		ast.NewBitLiteral == nil {
    87  		panic("no parser driver (forgotten import?) https://github.com/pingcap/parser/issues/43")
    88  	}
    89  
    90  	return &Parser{
    91  		cache: make([]yySymType, 200),
    92  	}
    93  }
    94  
    95  // Parse parses a query string to raw ast.StmtNode.
    96  // If charset or collation is "", default charset and collation will be used.
    97  func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) {
    98  	if charset == "" {
    99  		charset = mysql.DefaultCharset
   100  	}
   101  	if collation == "" {
   102  		collation = mysql.DefaultCollationName
   103  	}
   104  	parser.charset = charset
   105  	parser.collation = collation
   106  	parser.src = sql
   107  	parser.result = parser.result[:0]
   108  
   109  	var l yyLexer
   110  	parser.lexer.reset(sql)
   111  	l = &parser.lexer
   112  	yyParse(l, parser)
   113  
   114  	warns, errs := l.Errors()
   115  	if len(warns) > 0 {
   116  		warns = append([]error(nil), warns...)
   117  	} else {
   118  		warns = nil
   119  	}
   120  	if len(errs) != 0 {
   121  		return nil, warns, errors.Trace(errs[0])
   122  	}
   123  	for _, stmt := range parser.result {
   124  		ast.SetFlag(stmt)
   125  	}
   126  	return parser.result, warns, nil
   127  }
   128  
   129  func (parser *Parser) lastErrorAsWarn() {
   130  	if len(parser.lexer.errs) == 0 {
   131  		return
   132  	}
   133  	parser.lexer.warns = append(parser.lexer.warns, parser.lexer.errs[len(parser.lexer.errs)-1])
   134  	parser.lexer.errs = parser.lexer.errs[:len(parser.lexer.errs)-1]
   135  }
   136  
   137  // ParseOneStmt parses a query and returns an ast.StmtNode.
   138  // The query must have one statement, otherwise ErrSyntax is returned.
   139  func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) {
   140  	stmts, _, err := parser.Parse(sql, charset, collation)
   141  	if err != nil {
   142  		return nil, errors.Trace(err)
   143  	}
   144  	if len(stmts) != 1 {
   145  		return nil, ErrSyntax
   146  	}
   147  	ast.SetFlag(stmts[0])
   148  	return stmts[0], nil
   149  }
   150  
   151  // SetSQLMode sets the SQL mode for parser.
   152  func (parser *Parser) SetSQLMode(mode mysql.SQLMode) {
   153  	parser.lexer.SetSQLMode(mode)
   154  }
   155  
   156  // EnableWindowFunc controls whether the parser to parse syntax related with window function.
   157  func (parser *Parser) EnableWindowFunc(val bool) {
   158  	parser.lexer.EnableWindowFunc(val)
   159  }
   160  
   161  // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql.
   162  func ParseErrorWith(errstr string, lineno int) error {
   163  	if len(errstr) > mysql.ErrTextLength {
   164  		errstr = errstr[:mysql.ErrTextLength]
   165  	}
   166  	return fmt.Errorf("near '%-.80s' at line %d", errstr, lineno)
   167  }
   168  
   169  // The select statement is not at the end of the whole statement, if the last
   170  // field text was set from its offset to the end of the src string, update
   171  // the last field text.
   172  func (parser *Parser) setLastSelectFieldText(st *ast.SelectStmt, lastEnd int) {
   173  	lastField := st.Fields.Fields[len(st.Fields.Fields)-1]
   174  	if lastField.Offset+len(lastField.Text()) >= len(parser.src)-1 {
   175  		lastField.SetText(parser.src[lastField.Offset:lastEnd])
   176  	}
   177  }
   178  
   179  func (parser *Parser) startOffset(v *yySymType) int {
   180  	return v.offset
   181  }
   182  
   183  func (parser *Parser) endOffset(v *yySymType) int {
   184  	offset := v.offset
   185  	for offset > 0 && unicode.IsSpace(rune(parser.src[offset-1])) {
   186  		offset--
   187  	}
   188  	return offset
   189  }
   190  
   191  func toInt(l yyLexer, lval *yySymType, str string) int {
   192  	n, err := strconv.ParseUint(str, 10, 64)
   193  	if err != nil {
   194  		e := err.(*strconv.NumError)
   195  		if e.Err == strconv.ErrRange {
   196  			// TODO: toDecimal maybe out of range still.
   197  			// This kind of error should be throw to higher level, because truncated data maybe legal.
   198  			// For example, this SQL returns error:
   199  			// create table test (id decimal(30, 0));
   200  			// insert into test values(123456789012345678901234567890123094839045793405723406801943850);
   201  			// While this SQL:
   202  			// select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057;
   203  			// get value 99999999999999999999999999999999999999999999999999999999999999999
   204  			return toDecimal(l, lval, str)
   205  		}
   206  		l.Errorf("integer literal: %v", err)
   207  		return int(unicode.ReplacementChar)
   208  	}
   209  
   210  	switch {
   211  	case n < math.MaxInt64:
   212  		lval.item = int64(n)
   213  	default:
   214  		lval.item = n
   215  	}
   216  	return intLit
   217  }
   218  
   219  func toDecimal(l yyLexer, lval *yySymType, str string) int {
   220  	dec, err := ast.NewDecimal(str)
   221  	if err != nil {
   222  		l.Errorf("decimal literal: %v", err)
   223  	}
   224  	lval.item = dec
   225  	return decLit
   226  }
   227  
   228  func toFloat(l yyLexer, lval *yySymType, str string) int {
   229  	n, err := strconv.ParseFloat(str, 64)
   230  	if err != nil {
   231  		l.Errorf("float literal: %v", err)
   232  		return int(unicode.ReplacementChar)
   233  	}
   234  
   235  	lval.item = n
   236  	return floatLit
   237  }
   238  
   239  // See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html
   240  func toHex(l yyLexer, lval *yySymType, str string) int {
   241  	h, err := ast.NewHexLiteral(str)
   242  	if err != nil {
   243  		l.Errorf("hex literal: %v", err)
   244  		return int(unicode.ReplacementChar)
   245  	}
   246  	lval.item = h
   247  	return hexLit
   248  }
   249  
   250  // See https://dev.mysql.com/doc/refman/5.7/en/bit-type.html
   251  func toBit(l yyLexer, lval *yySymType, str string) int {
   252  	b, err := ast.NewBitLiteral(str)
   253  	if err != nil {
   254  		l.Errorf("bit literal: %v", err)
   255  		return int(unicode.ReplacementChar)
   256  	}
   257  	lval.item = b
   258  	return bitLit
   259  }
   260  
   261  func getUint64FromNUM(num interface{}) uint64 {
   262  	switch v := num.(type) {
   263  	case int64:
   264  		return uint64(v)
   265  	case uint64:
   266  		return v
   267  	}
   268  	return 0
   269  }
   270  
   271  // ParseSQL only for test, use ClientConn.parser for handling request
   272  func ParseSQL(sql string) (ast.StmtNode, error) {
   273  	ps := New()
   274  	return ps.ParseOneStmt(sql, "", "")
   275  }
   276  
   277  const resultTableNameFlag pformat.RestoreFlags = 0
   278  
   279  // NodeToStringWithoutQuote get node text
   280  func NodeToStringWithoutQuote(node ast.Node) (string, error) {
   281  	s := &strings.Builder{}
   282  	if err := node.Restore(pformat.NewRestoreCtx(resultTableNameFlag, s)); err != nil {
   283  		return "", err
   284  	}
   285  	return s.String(), nil
   286  }