github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/yy_parser.go (about) 1 // Copyright 2015 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package parser 15 16 import ( 17 "fmt" 18 "math" 19 "regexp" 20 "strconv" 21 "unicode" 22 23 "github.com/pingcap/errors" 24 "github.com/pingcap/tidb/parser/ast" 25 "github.com/pingcap/tidb/parser/auth" 26 "github.com/pingcap/tidb/parser/charset" 27 "github.com/pingcap/tidb/parser/mysql" 28 "github.com/pingcap/tidb/parser/terror" 29 "github.com/pingcap/tidb/parser/types" 30 ) 31 32 var ( 33 // ErrSyntax returns for sql syntax error. 34 ErrSyntax = terror.ClassParser.NewStd(mysql.ErrSyntax) 35 // ErrParse returns for sql parse error. 36 ErrParse = terror.ClassParser.NewStd(mysql.ErrParse) 37 // ErrUnknownCharacterSet returns for no character set found error. 38 ErrUnknownCharacterSet = terror.ClassParser.NewStd(mysql.ErrUnknownCharacterSet) 39 // ErrInvalidYearColumnLength returns for illegal column length for year type. 40 ErrInvalidYearColumnLength = terror.ClassParser.NewStd(mysql.ErrInvalidYearColumnLength) 41 // ErrWrongArguments returns for illegal argument. 42 ErrWrongArguments = terror.ClassParser.NewStd(mysql.ErrWrongArguments) 43 // ErrWrongFieldTerminators returns for illegal field terminators. 44 ErrWrongFieldTerminators = terror.ClassParser.NewStd(mysql.ErrWrongFieldTerminators) 45 // ErrTooBigDisplayWidth returns for data display width exceed limit . 46 ErrTooBigDisplayWidth = terror.ClassParser.NewStd(mysql.ErrTooBigDisplaywidth) 47 // ErrTooBigPrecision returns for data precision exceed limit. 48 ErrTooBigPrecision = terror.ClassParser.NewStd(mysql.ErrTooBigPrecision) 49 // ErrUnknownAlterLock returns for no alter lock type found error. 50 ErrUnknownAlterLock = terror.ClassParser.NewStd(mysql.ErrUnknownAlterLock) 51 // ErrUnknownAlterAlgorithm returns for no alter algorithm found error. 52 ErrUnknownAlterAlgorithm = terror.ClassParser.NewStd(mysql.ErrUnknownAlterAlgorithm) 53 // ErrWrongValue returns for wrong value 54 ErrWrongValue = terror.ClassParser.NewStd(mysql.ErrWrongValue) 55 // ErrWarnDeprecatedSyntax return when the syntax was deprecated 56 ErrWarnDeprecatedSyntax = terror.ClassParser.NewStd(mysql.ErrWarnDeprecatedSyntax) 57 // ErrWarnDeprecatedSyntaxNoReplacement return when the syntax was deprecated and there is no replacement. 58 ErrWarnDeprecatedSyntaxNoReplacement = terror.ClassParser.NewStd(mysql.ErrWarnDeprecatedSyntaxNoReplacement) 59 // ErrWrongUsage returns for incorrect usages. 60 ErrWrongUsage = terror.ClassParser.NewStd(mysql.ErrWrongUsage) 61 // SpecFieldPattern special result field pattern 62 SpecFieldPattern = regexp.MustCompile(`(\/\*!(M?[0-9]{5,6})?|\*\/)`) 63 specCodeStart = regexp.MustCompile(`^\/\*!(M?[0-9]{5,6})?[ \t]*`) 64 specCodeEnd = regexp.MustCompile(`[ \t]*\*\/$`) 65 ) 66 67 // TrimComment trim comment for special comment code of MySQL. 68 func TrimComment(txt string) string { 69 txt = specCodeStart.ReplaceAllString(txt, "") 70 return specCodeEnd.ReplaceAllString(txt, "") 71 } 72 73 //revive:disable:exported 74 75 // ParserConfig is the parser config. 76 type ParserConfig struct { 77 EnableWindowFunction bool 78 EnableStrictDoubleTypeCheck bool 79 SkipPositionRecording bool 80 } 81 82 //revive:enable:exported 83 84 // Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function. 85 type Parser struct { 86 charset string 87 collation string 88 result []ast.StmtNode 89 src string 90 lexer Scanner 91 hintParser *hintParser 92 93 explicitCharset bool 94 strictDoubleFieldType bool 95 96 // the following fields are used by yyParse to reduce allocation. 97 cache []yySymType 98 yylval yySymType 99 yyVAL *yySymType 100 } 101 102 func yySetOffset(yyVAL *yySymType, offset int) { 103 if yyVAL.expr != nil { 104 yyVAL.expr.SetOriginTextPosition(offset) 105 } 106 } 107 108 func yyhintSetOffset(_ *yyhintSymType, _ int) { 109 } 110 111 type stmtTexter interface { 112 stmtText() string 113 } 114 115 // New returns a Parser object with default SQL mode. 116 func New() *Parser { 117 if ast.NewValueExpr == nil || 118 ast.NewParamMarkerExpr == nil || 119 ast.NewHexLiteral == nil || 120 ast.NewBitLiteral == nil { 121 panic("no parser driver (forgotten import?) https://github.com/pingcap/parser/issues/43") 122 } 123 124 p := &Parser{ 125 cache: make([]yySymType, 200), 126 } 127 p.EnableWindowFunc(true) 128 p.SetStrictDoubleTypeCheck(true) 129 mode, _ := mysql.GetSQLMode(mysql.DefaultSQLMode) 130 p.SetSQLMode(mode) 131 return p 132 } 133 134 // SetStrictDoubleTypeCheck enables/disables strict double type check. 135 func (parser *Parser) SetStrictDoubleTypeCheck(val bool) { 136 parser.strictDoubleFieldType = val 137 } 138 139 // SetParserConfig sets the parser config. 140 func (parser *Parser) SetParserConfig(config ParserConfig) { 141 parser.EnableWindowFunc(config.EnableWindowFunction) 142 parser.SetStrictDoubleTypeCheck(config.EnableStrictDoubleTypeCheck) 143 parser.lexer.skipPositionRecording = config.SkipPositionRecording 144 } 145 146 // ParseSQL parses a query string to raw ast.StmtNode. 147 func (parser *Parser) ParseSQL(sql string, params ...ParseParam) (stmt []ast.StmtNode, warns []error, err error) { 148 resetParams(parser) 149 parser.lexer.reset(sql) 150 for _, p := range params { 151 if err := p.ApplyOn(parser); err != nil { 152 return nil, nil, err 153 } 154 } 155 parser.src = sql 156 parser.result = parser.result[:0] 157 158 var l yyLexer = &parser.lexer 159 yyParse(l, parser) 160 161 warns, errs := l.Errors() 162 if len(warns) > 0 { 163 warns = append([]error(nil), warns...) 164 } else { 165 warns = nil 166 } 167 if len(errs) != 0 { 168 return nil, warns, errors.Trace(errs[0]) 169 } 170 for _, stmt := range parser.result { 171 ast.SetFlag(stmt) 172 } 173 return parser.result, warns, nil 174 } 175 176 // Parse parses a query string to raw ast.StmtNode. 177 // If charset or collation is "", default charset and collation will be used. 178 func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) { 179 return parser.ParseSQL(sql, CharsetConnection(charset), CollationConnection(collation)) 180 } 181 182 func (parser *Parser) lastErrorAsWarn() { 183 parser.lexer.lastErrorAsWarn() 184 } 185 186 // ParseOneStmt parses a query and returns an ast.StmtNode. 187 // The query must have one statement, otherwise ErrSyntax is returned. 188 func (parser *Parser) ParseOneStmt(sql, charset, collation string) (ast.StmtNode, error) { 189 stmts, _, err := parser.ParseSQL(sql, CharsetConnection(charset), CollationConnection(collation)) 190 if err != nil { 191 return nil, errors.Trace(err) 192 } 193 if len(stmts) != 1 { 194 return nil, ErrSyntax 195 } 196 ast.SetFlag(stmts[0]) 197 return stmts[0], nil 198 } 199 200 // SetSQLMode sets the SQL mode for parser. 201 func (parser *Parser) SetSQLMode(mode mysql.SQLMode) { 202 parser.lexer.SetSQLMode(mode) 203 } 204 205 // EnableWindowFunc controls whether the parser to parse syntax related with window function. 206 func (parser *Parser) EnableWindowFunc(val bool) { 207 parser.lexer.EnableWindowFunc(val) 208 } 209 210 // ParseErrorWith returns "You have a syntax error near..." error message compatible with mysql. 211 func ParseErrorWith(errstr string, lineno int) error { 212 if len(errstr) > mysql.ErrTextLength { 213 errstr = errstr[:mysql.ErrTextLength] 214 } 215 return fmt.Errorf("near '%-.80s' at line %d", errstr, lineno) 216 } 217 218 // The select statement is not at the end of the whole statement, if the last 219 // field text was set from its offset to the end of the src string, update 220 // the last field text. 221 func (parser *Parser) setLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { 222 if st.Kind != ast.SelectStmtKindSelect { 223 return 224 } 225 lastField := st.Fields.Fields[len(st.Fields.Fields)-1] 226 if lastField.Offset+len(lastField.OriginalText()) >= len(parser.src)-1 { 227 lastField.SetText(parser.lexer.client, parser.src[lastField.Offset:lastEnd]) 228 } 229 } 230 231 func (*Parser) startOffset(v *yySymType) int { 232 return v.offset 233 } 234 235 func (parser *Parser) endOffset(v *yySymType) int { 236 offset := v.offset 237 for offset > 0 && unicode.IsSpace(rune(parser.src[offset-1])) { 238 offset-- 239 } 240 return offset 241 } 242 243 func (parser *Parser) parseHint(input string) ([]*ast.TableOptimizerHint, []error) { 244 if parser.hintParser == nil { 245 parser.hintParser = newHintParser() 246 } 247 return parser.hintParser.parse(input, parser.lexer.GetSQLMode(), parser.lexer.lastHintPos) 248 } 249 250 func toInt(l yyLexer, lval *yySymType, str string) int { 251 n, err := strconv.ParseUint(str, 10, 64) 252 if err != nil { 253 e := err.(*strconv.NumError) 254 if e.Err == strconv.ErrRange { 255 // TODO: toDecimal maybe out of range still. 256 // This kind of error should be throw to higher level, because truncated data maybe legal. 257 // For example, this SQL returns error: 258 // create table test (id decimal(30, 0)); 259 // insert into test values(123456789012345678901234567890123094839045793405723406801943850); 260 // While this SQL: 261 // select 1234567890123456789012345678901230948390457934057234068019438509023041874359081325875128590860234789847359871045943057; 262 // get value 99999999999999999999999999999999999999999999999999999999999999999 263 return toDecimal(l, lval, str) 264 } 265 l.AppendError(l.Errorf("integer literal: %v", err)) 266 return invalid 267 } 268 269 switch { 270 case n <= math.MaxInt64: 271 lval.item = int64(n) 272 default: 273 lval.item = n 274 } 275 return intLit 276 } 277 278 func toDecimal(l yyLexer, lval *yySymType, str string) int { 279 dec, err := ast.NewDecimal(str) 280 if err != nil { 281 if terror.ErrorEqual(err, types.ErrDataOutOfRange) { 282 l.AppendWarn(types.ErrTruncatedWrongValue.FastGenByArgs("DECIMAL", dec)) 283 dec, _ = ast.NewDecimal(mysql.DefaultDecimal) 284 } else { 285 l.AppendError(l.Errorf("decimal literal: %v", err)) 286 } 287 } 288 lval.item = dec 289 return decLit 290 } 291 292 func toFloat(l yyLexer, lval *yySymType, str string) int { 293 n, err := strconv.ParseFloat(str, 64) 294 if err != nil { 295 e := err.(*strconv.NumError) 296 if e.Err == strconv.ErrRange { 297 l.AppendError(types.ErrIllegalValueForType.GenWithStackByArgs("double", str)) 298 return invalid 299 } 300 l.AppendError(l.Errorf("float literal: %v", err)) 301 return invalid 302 } 303 304 lval.item = n 305 return floatLit 306 } 307 308 // See https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html 309 func toHex(l yyLexer, lval *yySymType, str string) int { 310 h, err := ast.NewHexLiteral(str) 311 if err != nil { 312 l.AppendError(l.Errorf("hex literal: %v", err)) 313 return invalid 314 } 315 lval.item = h 316 return hexLit 317 } 318 319 // See https://dev.mysql.com/doc/refman/5.7/en/bit-type.html 320 func toBit(l yyLexer, lval *yySymType, str string) int { 321 b, err := ast.NewBitLiteral(str) 322 if err != nil { 323 l.AppendError(l.Errorf("bit literal: %v", err)) 324 return invalid 325 } 326 lval.item = b 327 return bitLit 328 } 329 330 func getUint64FromNUM(num interface{}) uint64 { 331 switch v := num.(type) { 332 case int64: 333 return uint64(v) 334 case uint64: 335 return v 336 } 337 return 0 338 } 339 340 func getInt64FromNUM(num interface{}) (val int64, errMsg string) { 341 switch v := num.(type) { 342 case int64: 343 return v, "" 344 default: 345 return -1, fmt.Sprintf("%d is out of range [–9223372036854775808,9223372036854775807]", num) 346 } 347 } 348 349 func isRevokeAllGrant(roleOrPrivList []*ast.RoleOrPriv) bool { 350 if len(roleOrPrivList) != 2 { 351 return false 352 } 353 priv, err := roleOrPrivList[0].ToPriv() 354 if err != nil { 355 return false 356 } 357 if priv.Priv != mysql.AllPriv { 358 return false 359 } 360 priv, err = roleOrPrivList[1].ToPriv() 361 if err != nil { 362 return false 363 } 364 if priv.Priv != mysql.GrantPriv { 365 return false 366 } 367 return true 368 } 369 370 // convertToRole tries to convert elements of roleOrPrivList to RoleIdentity 371 func convertToRole(roleOrPrivList []*ast.RoleOrPriv) ([]*auth.RoleIdentity, error) { 372 var roles []*auth.RoleIdentity 373 for _, elem := range roleOrPrivList { 374 role, err := elem.ToRole() 375 if err != nil { 376 return nil, err 377 } 378 roles = append(roles, role) 379 } 380 return roles, nil 381 } 382 383 // convertToPriv tries to convert elements of roleOrPrivList to PrivElem 384 func convertToPriv(roleOrPrivList []*ast.RoleOrPriv) ([]*ast.PrivElem, error) { 385 var privileges []*ast.PrivElem 386 for _, elem := range roleOrPrivList { 387 priv, err := elem.ToPriv() 388 if err != nil { 389 return nil, err 390 } 391 privileges = append(privileges, priv) 392 } 393 return privileges, nil 394 } 395 396 var ( 397 _ ParseParam = CharsetConnection("") 398 _ ParseParam = CollationConnection("") 399 _ ParseParam = CharsetClient("") 400 ) 401 402 func resetParams(p *Parser) { 403 p.charset = mysql.DefaultCharset 404 p.collation = mysql.DefaultCollationName 405 } 406 407 // ParseParam represents the parameter of parsing. 408 type ParseParam interface { 409 ApplyOn(*Parser) error 410 } 411 412 // CharsetConnection is used for literals specified without a character set. 413 type CharsetConnection string 414 415 // ApplyOn implements ParseParam interface. 416 func (c CharsetConnection) ApplyOn(p *Parser) error { 417 if c == "" { 418 p.charset = mysql.DefaultCharset 419 } else { 420 p.charset = string(c) 421 } 422 p.lexer.connection = charset.FindEncoding(string(c)) 423 return nil 424 } 425 426 // CollationConnection is used for literals specified without a collation. 427 type CollationConnection string 428 429 // ApplyOn implements ParseParam interface. 430 func (c CollationConnection) ApplyOn(p *Parser) error { 431 if c == "" { 432 p.collation = mysql.DefaultCollationName 433 } else { 434 p.collation = string(c) 435 } 436 return nil 437 } 438 439 // CharsetClient specifies the charset of a SQL. 440 // This is used to decode the SQL into a utf-8 string. 441 type CharsetClient string 442 443 // ApplyOn implements ParseParam interface. 444 func (c CharsetClient) ApplyOn(p *Parser) error { 445 p.lexer.client = charset.FindEncoding(string(c)) 446 return nil 447 }