github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqlparse/tidbparser/parser/yy_parser.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package parser 15 16 import ( 17 "math" 18 "regexp" 19 "strconv" 20 "unicode" 21 22 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/ast" 23 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/mysql" 24 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/terror" 25 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/types" 26 "github.com/bingoohuang/gg/pkg/sqlparse/tidbparser/dependency/util/hack" 27 "github.com/juju/errors" 28 ) 29 30 const ( 31 codeErrParse = terror.ErrCode(mysql.ErrParse) 32 codeErrSyntax = terror.ErrCode(mysql.ErrSyntax) 33 ) 34 35 var ( 36 // ErrSyntax returns for sql syntax error. 37 ErrSyntax = terror.ClassParser.New(codeErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax]) 38 // ErrParse returns for sql parse error. 39 ErrParse = terror.ClassParser.New(codeErrParse, mysql.MySQLErrName[mysql.ErrParse]) 40 // SpecFieldPattern special result field pattern 41 SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`) 42 specCodePattern = regexp.MustCompile(`\/\*!(M?[0-9]{5,6})?([^*]|\*+[^*/])*\*+\/`) 43 specCodeStart = regexp.MustCompile(`^\/\*!(M?[0-9]{5,6})?[ \t]*`) 44 specCodeEnd = regexp.MustCompile(`[ \t]*\*\/$`) 45 ) 46 47 func init() { 48 parserMySQLErrCodes := map[terror.ErrCode]uint16{ 49 codeErrSyntax: mysql.ErrSyntax, 50 codeErrParse: mysql.ErrParse, 51 } 52 terror.ErrClassToMySQLCodes[terror.ClassParser] = parserMySQLErrCodes 53 } 54 55 // TrimComment trim comment for special comment code of MySQL. 56 func TrimComment(txt string) string { 57 txt = specCodeStart.ReplaceAllString(txt, "") 58 return specCodeEnd.ReplaceAllString(txt, "") 59 } 60 61 // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. 62 type Parser struct { 63 charset string 64 collation string 65 result []ast.StmtNode 66 src string 67 lexer Scanner 68 69 // the following fields are used by yyParse to reduce allocation. 70 cache []yySymType 71 yylval yySymType 72 yyVAL yySymType 73 } 74 75 type stmtTexter interface { 76 stmtText() string 77 } 78 79 // New returns a Parser object. 80 func New() *Parser { 81 return &Parser{ 82 cache: make([]yySymType, 200), 83 } 84 } 85 86 // Parse parses a query string to raw ast.StmtNode. 87 // If charset or collation is "", default charset and collation will be used. 88 func (parser *Parser) Parse(sql, charset, collation string) ([]ast.StmtNode, error) { 89 if charset == "" { 90 charset = mysql.DefaultCharset 91 } 92 if collation == "" { 93 collation = mysql.DefaultCollationName 94 } 95 parser.charset = charset 96 parser.collation = collation 97 parser.src = sql 98 parser.result = parser.result[:0] 99 100 var l yyLexer 101 parser.lexer.reset(sql) 102 l = &parser.lexer 103 yyParse(l, parser) 104 105 if len(l.Errors()) != 0 { 106 return nil, errors.Trace(l.Errors()[0]) 107 } 108 for _, stmt := range parser.result { 109 ast.SetFlag(stmt) 110 } 111 return parser.result, nil 112 } 113 114 // ParseOneStmt parses a query and returns an ast.StmtNode. 115 // The query must have one statement, otherwise ErrSyntax is returned. 116 func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) { 117 stmts, err := parser.Parse(sql, charset, collation) 118 if err != nil { 119 return nil, errors.Trace(err) 120 } 121 if len(stmts) != 1 { 122 return nil, ErrSyntax 123 } 124 ast.SetFlag(stmts[0]) 125 return stmts[0], nil 126 } 127 128 // SetSQLMode sets the SQL mode for parser. 129 func (parser *Parser) SetSQLMode(mode mysql.SQLMode) { 130 parser.lexer.SetSQLMode(mode) 131 } 132 133 // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql. 134 func ParseErrorWith(errstr string, lineno int) *terror.Error { 135 if len(errstr) > mysql.ErrTextLength { 136 errstr = errstr[:mysql.ErrTextLength] 137 } 138 return ErrParse.GenByArgs(mysql.MySQLErrName[mysql.ErrSyntax], errstr, lineno) 139 } 140 141 // The select statement is not at the end of the whole statement, if the last 142 // field text was set from its offset to the end of the src string, update 143 // the last field text. 144 func (parser *Parser) setLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { 145 lastField := st.Fields.Fields[len(st.Fields.Fields)-1] 146 if lastField.Offset+len(lastField.Text()) >= len(parser.src)-1 { 147 lastField.SetText(parser.src[lastField.Offset:lastEnd]) 148 } 149 } 150 151 func (parser *Parser) startOffset(v *yySymType) int { 152 return v.offset 153 } 154 155 func (parser *Parser) endOffset(v *yySymType) int { 156 offset := v.offset 157 for offset > 0 && unicode.IsSpace(rune(parser.src[offset-1])) { 158 offset-- 159 } 160 return offset 161 } 162 163 func toInt(l yyLexer, lval *yySymType, str string) int { 164 n, err := strconv.ParseUint(str, 10, 64) 165 if err != nil { 166 e := err.(*strconv.NumError) 167 if e.Err == strconv.ErrRange { 168 // TODO: toDecimal maybe out of range still. 169 // This kind of error should be throw to higher level, because truncated data maybe legal. 170 // For example, this SQL returns error: 171 // create table test (id decimal(30, 0)); 172 // insert into test values(123456789012345678901234567890123094839045793405723406801943850); 173 // While this SQL: 174 // select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057; 175 // get value 99999999999999999999999999999999999999999999999999999999999999999 176 return toDecimal(l, lval, str) 177 } 178 l.Errorf("integer literal: %v", err) 179 return int(unicode.ReplacementChar) 180 } 181 182 switch { 183 case n < math.MaxInt64: 184 lval.item = int64(n) 185 default: 186 lval.item = n 187 } 188 return intLit 189 } 190 191 func toDecimal(l yyLexer, lval *yySymType, str string) int { 192 dec := new(types.MyDecimal) 193 err := dec.FromString(hack.Slice(str)) 194 if err != nil { 195 l.Errorf("decimal literal: %v", err) 196 } 197 lval.item = dec 198 return decLit 199 } 200 201 func toFloat(l yyLexer, lval *yySymType, str string) int { 202 n, err := strconv.ParseFloat(str, 64) 203 if err != nil { 204 l.Errorf("float literal: %v", err) 205 return int(unicode.ReplacementChar) 206 } 207 208 lval.item = n 209 return floatLit 210 } 211 212 // See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html 213 func toHex(l yyLexer, lval *yySymType, str string) int { 214 h, err := types.NewHexLiteral(str) 215 if err != nil { 216 l.Errorf("hex literal: %v", err) 217 return int(unicode.ReplacementChar) 218 } 219 lval.item = h 220 return hexLit 221 } 222 223 // See https://dev.mysql.com/doc/refman/5.7/en/bit-type.html 224 func toBit(l yyLexer, lval *yySymType, str string) int { 225 b, err := types.NewBitLiteral(str) 226 if err != nil { 227 l.Errorf("bit literal: %v", err) 228 return int(unicode.ReplacementChar) 229 } 230 lval.item = b 231 return bitLit 232 } 233 234 func getUint64FromNUM(num interface{}) uint64 { 235 switch v := num.(type) { 236 case int64: 237 return uint64(v) 238 case uint64: 239 return v 240 } 241 return 0 242 }