github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/parser/yy_parser.go (about)

     1  // Copyright 2015 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package parser
    15  
import (
	"math"
	"regexp"
	"strconv"
	"unicode"
	"unicode/utf8"

	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/ast"
	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/mysql"
	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/terror"
	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types"
	"github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/util/hack"
	"github.com/juju/errors"
)
    29  
// Parser error codes, each converted from the MySQL error number of the same
// name.
const (
	codeErrParse  = terror.ErrCode(mysql.ErrParse)
	codeErrSyntax = terror.ErrCode(mysql.ErrSyntax)
)
    34  
var (
	// ErrSyntax is returned for SQL syntax errors.
	ErrSyntax = terror.ClassParser.New(codeErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax])
	// ErrParse is returned for SQL parse errors.
	ErrParse = terror.ClassParser.New(codeErrParse, mysql.MySQLErrName[mysql.ErrParse])
	// SpecFieldPattern matches either the opening marker of a MySQL special
	// comment ("/*!" followed by an optional version: an optional "M" and
	// five or six digits) or the closing marker "*/".
	SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`)
	// specCodePattern matches a whole special comment, from "/*!" (with the
	// optional version) through its closing "*/".
	specCodePattern  = regexp.MustCompile(`\/\*!(M?[0-9]{5,6})?([^*]|\*+[^*/])*\*+\/`)
	// specCodeStart matches the opening marker anchored at the start of the
	// text, together with any spaces or tabs that follow it.
	specCodeStart    = regexp.MustCompile(`^\/\*!(M?[0-9]{5,6})?[ \t]*`)
	// specCodeEnd matches the closing "*/" anchored at the end of the text,
	// together with any spaces or tabs that precede it.
	specCodeEnd      = regexp.MustCompile(`[ \t]*\*\/$`)
)
    46  
    47  func init() {
    48  	parserMySQLErrCodes := map[terror.ErrCode]uint16{
    49  		codeErrSyntax: mysql.ErrSyntax,
    50  		codeErrParse:  mysql.ErrParse,
    51  	}
    52  	terror.ErrClassToMySQLCodes[terror.ClassParser] = parserMySQLErrCodes
    53  }
    54  
    55  // TrimComment trim comment for special comment code of MySQL.
    56  func TrimComment(txt string) string {
    57  	txt = specCodeStart.ReplaceAllString(txt, "")
    58  	return specCodeEnd.ReplaceAllString(txt, "")
    59  }
    60  
// Parser represents a parser instance. Some temporary objects are stored in
// it to reduce object allocation during the Parse function.
type Parser struct {
	// charset and collation are the values in effect for the current Parse
	// call (package defaults are substituted for empty arguments).
	charset   string
	collation string
	// result holds the statements of the current Parse call; Parse truncates
	// it and reuses its backing array on every call.
	result    []ast.StmtNode
	// src is the SQL text passed to the current Parse call.
	src       string
	// lexer is the scanner that Parse resets and hands to yyParse.
	lexer     Scanner

	// the following fields are used by yyParse to reduce allocation.
	cache  []yySymType
	yylval yySymType
	yyVAL  yySymType
}
    74  
// stmtTexter is implemented by values that expose a stmtText method.
// NOTE(review): presumably stmtText returns the source text of a statement;
// confirm against the implementers.
type stmtTexter interface {
	stmtText() string
}
    78  
    79  // New returns a Parser object.
    80  func New() *Parser {
    81  	return &Parser{
    82  		cache: make([]yySymType, 200),
    83  	}
    84  }
    85  
    86  // Parse parses a query string to raw ast.StmtNode.
    87  // If charset or collation is "", default charset and collation will be used.
    88  func (parser *Parser) Parse(sql, charset, collation string) ([]ast.StmtNode, error) {
    89  	if charset == "" {
    90  		charset = mysql.DefaultCharset
    91  	}
    92  	if collation == "" {
    93  		collation = mysql.DefaultCollationName
    94  	}
    95  	parser.charset = charset
    96  	parser.collation = collation
    97  	parser.src = sql
    98  	parser.result = parser.result[:0]
    99  
   100  	var l yyLexer
   101  	parser.lexer.reset(sql)
   102  	l = &parser.lexer
   103  	yyParse(l, parser)
   104  
   105  	if len(l.Errors()) != 0 {
   106  		return nil, errors.Trace(l.Errors()[0])
   107  	}
   108  	for _, stmt := range parser.result {
   109  		ast.SetFlag(stmt)
   110  	}
   111  	return parser.result, nil
   112  }
   113  
   114  // ParseOneStmt parses a query and returns an ast.StmtNode.
   115  // The query must have one statement, otherwise ErrSyntax is returned.
   116  func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) {
   117  	stmts, err := parser.Parse(sql, charset, collation)
   118  	if err != nil {
   119  		return nil, errors.Trace(err)
   120  	}
   121  	if len(stmts) != 1 {
   122  		return nil, ErrSyntax
   123  	}
   124  	ast.SetFlag(stmts[0])
   125  	return stmts[0], nil
   126  }
   127  
// SetSQLMode sets the SQL mode for the parser by forwarding mode to the
// parser's lexer.
func (parser *Parser) SetSQLMode(mode mysql.SQLMode) {
	parser.lexer.SetSQLMode(mode)
}
   132  
   133  // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql.
   134  func ParseErrorWith(errstr string, lineno int) *terror.Error {
   135  	if len(errstr) > mysql.ErrTextLength {
   136  		errstr = errstr[:mysql.ErrTextLength]
   137  	}
   138  	return ErrParse.GenByArgs(mysql.MySQLErrName[mysql.ErrSyntax], errstr, lineno)
   139  }
   140  
   141  // The select statement is not at the end of the whole statement, if the last
   142  // field text was set from its offset to the end of the src string, update
   143  // the last field text.
   144  func (parser *Parser) setLastSelectFieldText(st *ast.SelectStmt, lastEnd int) {
   145  	lastField := st.Fields.Fields[len(st.Fields.Fields)-1]
   146  	if lastField.Offset+len(lastField.Text()) >= len(parser.src)-1 {
   147  		lastField.SetText(parser.src[lastField.Offset:lastEnd])
   148  	}
   149  }
   150  
// startOffset returns the offset recorded in the symbol v, unchanged.
func (parser *Parser) startOffset(v *yySymType) int {
	return v.offset
}
   154  
   155  func (parser *Parser) endOffset(v *yySymType) int {
   156  	offset := v.offset
   157  	for offset > 0 && unicode.IsSpace(rune(parser.src[offset-1])) {
   158  		offset--
   159  	}
   160  	return offset
   161  }
   162  
   163  func toInt(l yyLexer, lval *yySymType, str string) int {
   164  	n, err := strconv.ParseUint(str, 10, 64)
   165  	if err != nil {
   166  		e := err.(*strconv.NumError)
   167  		if e.Err == strconv.ErrRange {
   168  			// TODO: toDecimal maybe out of range still.
   169  			// This kind of error should be throw to higher level, because truncated data maybe legal.
   170  			// For example, this SQL returns error:
   171  			// create table test (id decimal(30, 0));
   172  			// insert into test values(123456789012345678901234567890123094839045793405723406801943850);
   173  			// While this SQL:
   174  			// select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057;
   175  			// get value 99999999999999999999999999999999999999999999999999999999999999999
   176  			return toDecimal(l, lval, str)
   177  		}
   178  		l.Errorf("integer literal: %v", err)
   179  		return int(unicode.ReplacementChar)
   180  	}
   181  
   182  	switch {
   183  	case n < math.MaxInt64:
   184  		lval.item = int64(n)
   185  	default:
   186  		lval.item = n
   187  	}
   188  	return intLit
   189  }
   190  
   191  func toDecimal(l yyLexer, lval *yySymType, str string) int {
   192  	dec := new(types.MyDecimal)
   193  	err := dec.FromString(hack.Slice(str))
   194  	if err != nil {
   195  		l.Errorf("decimal literal: %v", err)
   196  	}
   197  	lval.item = dec
   198  	return decLit
   199  }
   200  
   201  func toFloat(l yyLexer, lval *yySymType, str string) int {
   202  	n, err := strconv.ParseFloat(str, 64)
   203  	if err != nil {
   204  		l.Errorf("float literal: %v", err)
   205  		return int(unicode.ReplacementChar)
   206  	}
   207  
   208  	lval.item = n
   209  	return floatLit
   210  }
   211  
   212  // See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html
   213  func toHex(l yyLexer, lval *yySymType, str string) int {
   214  	h, err := types.NewHexLiteral(str)
   215  	if err != nil {
   216  		l.Errorf("hex literal: %v", err)
   217  		return int(unicode.ReplacementChar)
   218  	}
   219  	lval.item = h
   220  	return hexLit
   221  }
   222  
   223  // See https://dev.mysql.com/doc/refman/5.7/en/bit-type.html
   224  func toBit(l yyLexer, lval *yySymType, str string) int {
   225  	b, err := types.NewBitLiteral(str)
   226  	if err != nil {
   227  		l.Errorf("bit literal: %v", err)
   228  		return int(unicode.ReplacementChar)
   229  	}
   230  	lval.item = b
   231  	return bitLit
   232  }
   233  
   234  func getUint64FromNUM(num interface{}) uint64 {
   235  	switch v := num.(type) {
   236  	case int64:
   237  		return uint64(v)
   238  	case uint64:
   239  		return v
   240  	}
   241  	return 0
   242  }