github.com/XiaoMi/Gaea@v1.2.5/parser/yy_parser.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package parser 15 16 import ( 17 "fmt" 18 "math" 19 "regexp" 20 "strconv" 21 "strings" 22 "unicode" 23 24 "github.com/pingcap/errors" 25 26 "github.com/XiaoMi/Gaea/mysql" 27 "github.com/XiaoMi/Gaea/parser/ast" 28 pformat "github.com/XiaoMi/Gaea/parser/format" 29 "github.com/XiaoMi/Gaea/parser/terror" 30 ) 31 32 const ( 33 codeErrParse = terror.ErrCode(mysql.ErrParse) 34 codeErrSyntax = terror.ErrCode(mysql.ErrSyntax) 35 ) 36 37 var ( 38 // ErrSyntax returns for sql syntax error. 39 ErrSyntax = terror.ClassParser.New(codeErrSyntax, mysql.MySQLErrName[mysql.ErrSyntax]) 40 // ErrParse returns for sql parse error. 41 ErrParse = terror.ClassParser.New(codeErrParse, mysql.MySQLErrName[mysql.ErrParse]) 42 // SpecFieldPattern special result field pattern 43 SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`) 44 specCodePattern = regexp.MustCompile(`\/\*!(M?[0-9]{5,6})?([^*]|\*+[^*/])*\*+\/`) 45 specCodeStart = regexp.MustCompile(`^\/\*!(M?[0-9]{5,6})?[ \t]*`) 46 specCodeEnd = regexp.MustCompile(`[ \t]*\*\/$`) 47 ) 48 49 func init() { 50 parserMySQLErrCodes := map[terror.ErrCode]uint16{ 51 codeErrSyntax: mysql.ErrSyntax, 52 codeErrParse: mysql.ErrParse, 53 } 54 terror.ErrClassToMySQLCodes[terror.ClassParser] = parserMySQLErrCodes 55 } 56 57 // TrimComment trim comment for special comment code of MySQL. 58 func TrimComment(txt string) string { 59 txt = specCodeStart.ReplaceAllString(txt, "") 60 return specCodeEnd.ReplaceAllString(txt, "") 61 } 62 63 // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. 64 type Parser struct { 65 charset string 66 collation string 67 result []ast.StmtNode 68 src string 69 lexer Scanner 70 71 // the following fields are used by yyParse to reduce allocation. 72 cache []yySymType 73 yylval yySymType 74 yyVAL yySymType 75 } 76 77 type stmtTexter interface { 78 stmtText() string 79 } 80 81 // New returns a Parser object. 82 func New() *Parser { 83 if ast.NewValueExpr == nil || 84 ast.NewParamMarkerExpr == nil || 85 ast.NewHexLiteral == nil || 86 ast.NewBitLiteral == nil { 87 panic("no parser driver (forgotten import?) https://github.com/pingcap/parser/issues/43") 88 } 89 90 return &Parser{ 91 cache: make([]yySymType, 200), 92 } 93 } 94 95 // Parse parses a query string to raw ast.StmtNode. 96 // If charset or collation is "", default charset and collation will be used. 97 func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) { 98 if charset == "" { 99 charset = mysql.DefaultCharset 100 } 101 if collation == "" { 102 collation = mysql.DefaultCollationName 103 } 104 parser.charset = charset 105 parser.collation = collation 106 parser.src = sql 107 parser.result = parser.result[:0] 108 109 var l yyLexer 110 parser.lexer.reset(sql) 111 l = &parser.lexer 112 yyParse(l, parser) 113 114 warns, errs := l.Errors() 115 if len(warns) > 0 { 116 warns = append([]error(nil), warns...) 117 } else { 118 warns = nil 119 } 120 if len(errs) != 0 { 121 return nil, warns, errors.Trace(errs[0]) 122 } 123 for _, stmt := range parser.result { 124 ast.SetFlag(stmt) 125 } 126 return parser.result, warns, nil 127 } 128 129 func (parser *Parser) lastErrorAsWarn() { 130 if len(parser.lexer.errs) == 0 { 131 return 132 } 133 parser.lexer.warns = append(parser.lexer.warns, parser.lexer.errs[len(parser.lexer.errs)-1]) 134 parser.lexer.errs = parser.lexer.errs[:len(parser.lexer.errs)-1] 135 } 136 137 // ParseOneStmt parses a query and returns an ast.StmtNode. 138 // The query must have one statement, otherwise ErrSyntax is returned. 139 func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) { 140 stmts, _, err := parser.Parse(sql, charset, collation) 141 if err != nil { 142 return nil, errors.Trace(err) 143 } 144 if len(stmts) != 1 { 145 return nil, ErrSyntax 146 } 147 ast.SetFlag(stmts[0]) 148 return stmts[0], nil 149 } 150 151 // SetSQLMode sets the SQL mode for parser. 152 func (parser *Parser) SetSQLMode(mode mysql.SQLMode) { 153 parser.lexer.SetSQLMode(mode) 154 } 155 156 // EnableWindowFunc controls whether the parser to parse syntax related with window function. 157 func (parser *Parser) EnableWindowFunc(val bool) { 158 parser.lexer.EnableWindowFunc(val) 159 } 160 161 // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql. 162 func ParseErrorWith(errstr string, lineno int) error { 163 if len(errstr) > mysql.ErrTextLength { 164 errstr = errstr[:mysql.ErrTextLength] 165 } 166 return fmt.Errorf("near '%-.80s' at line %d", errstr, lineno) 167 } 168 169 // The select statement is not at the end of the whole statement, if the last 170 // field text was set from its offset to the end of the src string, update 171 // the last field text. 172 func (parser *Parser) setLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { 173 lastField := st.Fields.Fields[len(st.Fields.Fields)-1] 174 if lastField.Offset+len(lastField.Text()) >= len(parser.src)-1 { 175 lastField.SetText(parser.src[lastField.Offset:lastEnd]) 176 } 177 } 178 179 func (parser *Parser) startOffset(v *yySymType) int { 180 return v.offset 181 } 182 183 func (parser *Parser) endOffset(v *yySymType) int { 184 offset := v.offset 185 for offset > 0 && unicode.IsSpace(rune(parser.src[offset-1])) { 186 offset-- 187 } 188 return offset 189 } 190 191 func toInt(l yyLexer, lval *yySymType, str string) int { 192 n, err := strconv.ParseUint(str, 10, 64) 193 if err != nil { 194 e := err.(*strconv.NumError) 195 if e.Err == strconv.ErrRange { 196 // TODO: toDecimal maybe out of range still. 197 // This kind of error should be throw to higher level, because truncated data maybe legal. 198 // For example, this SQL returns error: 199 // create table test (id decimal(30, 0)); 200 // insert into test values(123456789012345678901234567890123094839045793405723406801943850); 201 // While this SQL: 202 // select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057; 203 // get value 99999999999999999999999999999999999999999999999999999999999999999 204 return toDecimal(l, lval, str) 205 } 206 l.Errorf("integer literal: %v", err) 207 return int(unicode.ReplacementChar) 208 } 209 210 switch { 211 case n < math.MaxInt64: 212 lval.item = int64(n) 213 default: 214 lval.item = n 215 } 216 return intLit 217 } 218 219 func toDecimal(l yyLexer, lval *yySymType, str string) int { 220 dec, err := ast.NewDecimal(str) 221 if err != nil { 222 l.Errorf("decimal literal: %v", err) 223 } 224 lval.item = dec 225 return decLit 226 } 227 228 func toFloat(l yyLexer, lval *yySymType, str string) int { 229 n, err := strconv.ParseFloat(str, 64) 230 if err != nil { 231 l.Errorf("float literal: %v", err) 232 return int(unicode.ReplacementChar) 233 } 234 235 lval.item = n 236 return floatLit 237 } 238 239 // See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html 240 func toHex(l yyLexer, lval *yySymType, str string) int { 241 h, err := ast.NewHexLiteral(str) 242 if err != nil { 243 l.Errorf("hex literal: %v", err) 244 return int(unicode.ReplacementChar) 245 } 246 lval.item = h 247 return hexLit 248 } 249 250 // See https://dev.mysql.com/doc/refman/5.7/en/bit-type.html 251 func toBit(l yyLexer, lval *yySymType, str string) int { 252 b, err := ast.NewBitLiteral(str) 253 if err != nil { 254 l.Errorf("bit literal: %v", err) 255 return int(unicode.ReplacementChar) 256 } 257 lval.item = b 258 return bitLit 259 } 260 261 func getUint64FromNUM(num interface{}) uint64 { 262 switch v := num.(type) { 263 case int64: 264 return uint64(v) 265 case uint64: 266 return v 267 } 268 return 0 269 } 270 271 // ParseSQL only for test, use ClientConn.parser for handling request 272 func ParseSQL(sql string) (ast.StmtNode, error) { 273 ps := New() 274 return ps.ParseOneStmt(sql, "", "") 275 } 276 277 const resultTableNameFlag pformat.RestoreFlags = 0 278 279 // NodeToStringWithoutQuote get node text 280 func NodeToStringWithoutQuote(node ast.Node) (string, error) { 281 s := &strings.Builder{} 282 if err := node.Restore(pformat.NewRestoreCtx(resultTableNameFlag, s)); err != nil { 283 return "", err 284 } 285 return s.String(), nil 286 }