// Extracted from:
// github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/parser/parse.go

// Copyright 2012, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in licenses/BSD-vitess.txt.

// Portions of this file are additionally subject to the following
// license and copyright.
//
// Copyright 2015 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

// This code was derived from https://github.com/youtube/vitess.

package parser

import (
	"fmt"
	"go/constant"
	"strings"

	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser/statements"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/scanner"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
	"github.com/cockroachdb/errors"
)

// init installs constructor hooks on the scanner package so that the scanner
// can build tree values for numeric literals and placeholders.
// NOTE(review): presumably this indirection exists so that scanner does not
// need to import tree directly — confirm.
func init() {
	// Numeric literals become *tree.NumVal via tree.NewNumVal.
	scanner.NewNumValFn = func(a constant.Value, s string, b bool) interface{} { return tree.NewNumVal(a, s, b) }
	// Placeholders are built from their text via tree.NewPlaceholder.
	scanner.NewPlaceholderFn = func(s string) (interface{}, error) { return tree.NewPlaceholder(s) }
}

// Parser wraps a scanner, parser and other utilities present in the parser
// package.
//
// tokBuf is the initial backing storage for the token slice assembled by
// scanOneStmt (longer statements make append allocate a bigger slice), and
// stmtBuf is the initial backing storage for the Statements produced by
// parseWithDepth. The methods of Parser write to these fields without any
// synchronization.
type Parser struct {
	scanner    scanner.SQLScanner
	lexer      lexer
	parserImpl sqlParserImpl
	tokBuf     [8]sqlSymType
	stmtBuf    [1]statements.Statement[tree.Statement]
}

// INT8 is the historical interpretation of INT.
// This should be left alone in the future, since there are many sql
// fragments stored in various descriptors. Any user input that was created
// after INT := INT4 will simply use INT4 in any resulting code.
var defaultNakedIntType = types.Int

// NakedIntTypeFromDefaultIntSize given the size in bits or bytes (preferred)
// of how a "naked" INT type should be parsed returns the corresponding integer
// type.
//
// 4 (bytes) and 32 (bits) select types.Int4; every other value, including
// unrecognized sizes, falls back to types.Int.
func NakedIntTypeFromDefaultIntSize(defaultIntSize int32) *types.T {
	switch defaultIntSize {
	case 4, 32:
		return types.Int4
	default:
		return types.Int
	}
}

// Parse parses the sql and returns a list of statements.
//
// INT is interpreted as defaultNakedIntType. NOTE(review): the returned slice
// starts out backed by p.stmtBuf (see parseWithDepth), so a single-statement
// result is overwritten by a later call on the same Parser.
func (p *Parser) Parse(sql string) (statements.Statements, error) {
	return p.parseWithDepth(1, sql, defaultNakedIntType)
}

// ParseWithInt parses a sql statement string and returns a list of
// Statements. The INT token will result in the specified TInt type.
// The same buffer-aliasing caveat as Parse applies to the result.
func (p *Parser) ParseWithInt(sql string, nakedIntType *types.T) (statements.Statements, error) {
	return p.parseWithDepth(1, sql, nakedIntType)
}

// parseOneWithInt parses sql, interpreting INT as nakedIntType, and returns
// its only statement. An assertion-failure error is returned when sql does
// not contain exactly one statement.
func (p *Parser) parseOneWithInt(
	sql string, nakedIntType *types.T,
) (statements.Statement[tree.Statement], error) {
	stmts, err := p.parseWithDepth(1, sql, nakedIntType)
	if err != nil {
		return statements.Statement[tree.Statement]{}, err
	}
	if len(stmts) != 1 {
		return statements.Statement[tree.Statement]{}, errors.AssertionFailedf("expected 1 statement, but found %d", len(stmts))
	}
	return stmts[0], nil
}

// scanOneStmt scans the next statement from p.scanner.
//
// Leading ';' tokens are skipped. Tokens are then collected until a ';' that
// is not inside a BEGIN ATOMIC ... END function body, or until end of input.
// It returns:
//   - sql: the raw text of the statement, without the terminating ';';
//   - tokens: the statement's tokens (terminator excluded), with positions
//     rebased so they index into sql;
//   - done: true when end of input was reached (or a scan ERROR token was
//     seen), meaning no further statements follow.
//
// When the input holds no further statement it returns ("", nil, true).
// On an ERROR token, sql is the remainder of the input from the statement
// start and tokens ends with the ERROR token.
func (p *Parser) scanOneStmt() (sql string, tokens []sqlSymType, done bool) {
	tokens = p.tokBuf[:0]
	tokens = append(tokens, sqlSymType{})
	// The append above stays within tokBuf's capacity, so tokens[0] and
	// p.tokBuf[0] are the same element.
	lval := &p.tokBuf[0]

	// Scan the first token.
	for {
		p.scanner.Scan(lval)
		if lval.id == 0 {
			return "", nil, true
		}
		if lval.id != ';' {
			break
		}
	}

	startPos := lval.pos
	// We make the resulting token positions match the returned string.
	lval.pos = 0
	// preValID holds the id of the previous token; it is used to recognize the
	// two-token sequence BEGIN ATOMIC.
	var preValID int32
	// This is used to track the degree of nested `BEGIN ATOMIC ... END` function
	// body context. When greater than zero, it means that we're scanning through
	// the function body of a `CREATE FUNCTION` statement. ';' character is only
	// a separator of sql statements within the body instead of a finishing line
	// of the `CREATE FUNCTION` statement.
	// NOTE(review): every END token decrements this counter while it is
	// positive, including an END that closes some other construct inside the
	// body (e.g. CASE ... END) — confirm that case is handled elsewhere.
	curFuncBodyCnt := 0
	for {
		if lval.id == ERROR {
			return p.scanner.In()[startPos:], tokens, true
		}
		preValID = lval.id
		tokens = append(tokens, sqlSymType{})
		// Re-take the pointer after every append: once tokBuf's capacity is
		// exhausted, append moves the slice to new storage.
		lval = &tokens[len(tokens)-1]
		p.scanner.Scan(lval)

		if preValID == BEGIN && lval.id == ATOMIC {
			curFuncBodyCnt++
		}
		if curFuncBodyCnt > 0 && lval.id == END {
			curFuncBodyCnt--
		}
		if lval.id == 0 || (curFuncBodyCnt == 0 && lval.id == ';') {
			endPos := p.scanner.Pos()
			if lval.id == ';' {
				// Don't include the ending semicolon, if there is one, in the raw SQL.
				endPos--
			}
			// Drop the terminating token (EOF or ';') from the result.
			tokens = tokens[:len(tokens)-1]
			return p.scanner.In()[startPos:endPos], tokens, (lval.id == 0)
		}
		lval.pos -= startPos
	}
}

// parseWithDepth splits sql into statements with scanOneStmt and parses each
// of them, returning at the first error. Statements whose AST comes back nil
// are left out of the result (presumably empty statements — TODO confirm).
//
// The scanner is initialized with sql and cleaned up when this returns. The
// result starts out backed by p.stmtBuf, so a single-statement result aliases
// the Parser's buffer and is overwritten by the next call on the same Parser.
// depth is forwarded to parse as depth+1. NOTE(review): parse does not
// appear to read depth — confirm whether it is still needed.
func (p *Parser) parseWithDepth(
	depth int, sql string, nakedIntType *types.T,
) (statements.Statements, error) {
	stmts := statements.Statements(p.stmtBuf[:0])
	p.scanner.Init(sql)
	defer p.scanner.Cleanup()
	for {
		// sql is shadowed here by the raw text of the current statement.
		sql, tokens, done := p.scanOneStmt()
		stmt, err := p.parse(depth+1, sql, tokens, nakedIntType)
		if err != nil {
			return nil, err
		}
		if stmt.AST != nil {
			stmts = append(stmts, stmt)
		}
		if done {
			break
		}
	}
	return stmts, nil
}

// parse parses a statement from the given scanned tokens.
170 func (p *Parser) parse( 171 depth int, sql string, tokens []sqlSymType, nakedIntType *types.T, 172 ) (statements.Statement[tree.Statement], error) { 173 p.lexer.init(sql, tokens, nakedIntType) 174 defer p.lexer.cleanup() 175 if p.parserImpl.Parse(&p.lexer) != 0 { 176 if p.lexer.lastError == nil { 177 // This should never happen -- there should be an error object 178 // every time Parse() returns nonzero. We're just playing safe 179 // here. 180 p.lexer.Error("syntax error") 181 } 182 err := p.lexer.lastError 183 184 // Compatibility with 19.1 telemetry: prefix the telemetry keys 185 // with the "syntax." prefix. 186 // TODO(knz): move the auto-prefixing of feature names to a 187 // higher level in the call stack. 188 tkeys := errors.GetTelemetryKeys(err) 189 if len(tkeys) > 0 { 190 for i := range tkeys { 191 tkeys[i] = "syntax." + tkeys[i] 192 } 193 err = errors.WithTelemetry(err, tkeys...) 194 } 195 196 return statements.Statement[tree.Statement]{}, err 197 } 198 199 return statements.Statement[tree.Statement]{ 200 AST: p.lexer.stmt, 201 SQL: sql, 202 Comments: p.scanner.Comments, 203 NumPlaceholders: p.lexer.numPlaceholders, 204 NumAnnotations: p.lexer.numAnnotations, 205 }, nil 206 } 207 208 // unaryNegation constructs an AST node for a negation. This attempts 209 // to preserve constant NumVals and embed the negative sign inside 210 // them instead of wrapping in an UnaryExpr. This in turn ensures 211 // that negative numbers get considered as a single constant 212 // for the purpose of formatting and scrubbing. 213 func unaryNegation(e tree.Expr) tree.Expr { 214 if cst, ok := e.(*tree.NumVal); ok { 215 cst.Negate() 216 return cst 217 } 218 219 // Common case. 220 return &tree.UnaryExpr{ 221 Operator: tree.MakeUnaryOperator(tree.UnaryMinus), 222 Expr: e, 223 } 224 } 225 226 // Parse parses a sql statement string and returns a list of Statements. 
227 func Parse(sql string) (statements.Statements, error) { 228 return ParseWithInt(sql, defaultNakedIntType) 229 } 230 231 // ParseWithInt parses a sql statement string and returns a list of 232 // Statements. The INT token will result in the specified TInt type. 233 func ParseWithInt(sql string, nakedIntType *types.T) (statements.Statements, error) { 234 var p Parser 235 return p.parseWithDepth(1, sql, nakedIntType) 236 } 237 238 // ParseOne parses a sql statement string, ensuring that it contains only a 239 // single statement, and returns that Statement. ParseOne will always 240 // interpret the INT and SERIAL types as 64-bit types, since this is 241 // used in various internal-execution paths where we might receive 242 // bits of SQL from other nodes. In general,earwe expect that all 243 // user-generated SQL has been run through the ParseWithInt() function. 244 func ParseOne(sql string) (statements.Statement[tree.Statement], error) { 245 return ParseOneWithInt(sql, defaultNakedIntType) 246 } 247 248 // ParseOneWithInt is similar to ParseOn but interprets the INT and SERIAL 249 // types as the provided integer type. 250 func ParseOneWithInt( 251 sql string, nakedIntType *types.T, 252 ) (statements.Statement[tree.Statement], error) { 253 var p Parser 254 return p.parseOneWithInt(sql, nakedIntType) 255 } 256 257 // ParseQualifiedTableName parses a possibly qualified table name. The 258 // table name must contain one or more name parts, using the full 259 // input SQL syntax: each name part containing special characters, or 260 // non-lowercase characters, must be enclosed in double quote. The 261 // name may not be an invalid table name (the caller is responsible 262 // for guaranteeing that only valid table names are provided as 263 // input). 
264 func ParseQualifiedTableName(sql string) (*tree.TableName, error) { 265 name, err := ParseTableName(sql) 266 if err != nil { 267 return nil, err 268 } 269 tn := name.ToTableName() 270 return &tn, nil 271 } 272 273 // ParseTableName parses a table name. The table name must contain one 274 // or more name parts, using the full input SQL syntax: each name 275 // part containing special characters, or non-lowercase characters, 276 // must be enclosed in double quote. The name may not be an invalid 277 // table name (the caller is responsible for guaranteeing that only 278 // valid table names are provided as input). 279 func ParseTableName(sql string) (*tree.UnresolvedObjectName, error) { 280 // We wrap the name we want to parse into a dummy statement since our parser 281 // can only parse full statements. 282 stmt, err := ParseOne(fmt.Sprintf("ALTER TABLE %s RENAME TO x", sql)) 283 if err != nil { 284 return nil, err 285 } 286 rename, ok := stmt.AST.(*tree.RenameTable) 287 if !ok { 288 return nil, errors.AssertionFailedf("expected an ALTER TABLE statement, but found %T", stmt) 289 } 290 return rename.Name, nil 291 } 292 293 // ParseTablePattern parses a table pattern. The table name must contain one 294 // or more name parts, using the full input SQL syntax: each name 295 // part containing special characters, or non-lowercase characters, 296 // must be enclosed in double quote. The name may not be an invalid 297 // table name (the caller is responsible for guaranteeing that only 298 // valid table names are provided as input). 299 // The last part may be '*' to denote a wildcard. 300 func ParseTablePattern(sql string) (tree.TablePattern, error) { 301 // We wrap the name we want to parse into a dummy statement since our parser 302 // can only parse full statements. 
303 stmt, err := ParseOne(fmt.Sprintf("GRANT SELECT ON TABLE %s TO admin", sql)) 304 if err != nil { 305 return nil, err 306 } 307 grant, ok := stmt.AST.(*tree.Grant) 308 if !ok { 309 return nil, errors.AssertionFailedf("expected a GRANT statement, but found %T", stmt) 310 } 311 if len(grant.Targets.Tables.TablePatterns) == 0 { 312 return nil, errors.AssertionFailedf("expected at least one pattern") 313 } 314 u := grant.Targets.Tables.TablePatterns[0] 315 un, ok := u.(*tree.UnresolvedName) 316 if !ok { 317 return nil, errors.AssertionFailedf("expected an unresolved name, but found %T", u) 318 } 319 return un.NormalizeTablePattern() 320 } 321 322 // parseExprsWithInt parses one or more sql expressions. 323 func parseExprsWithInt(exprs []string, nakedIntType *types.T) (tree.Exprs, error) { 324 stmt, err := ParseOneWithInt(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), nakedIntType) 325 if err != nil { 326 return nil, err 327 } 328 set, ok := stmt.AST.(*tree.SetVar) 329 if !ok { 330 return nil, errors.AssertionFailedf("expected a SET statement, but found %T", stmt) 331 } 332 return set.Values, nil 333 } 334 335 // ParseExprs parses a comma-delimited sequence of SQL scalar 336 // expressions. The caller is responsible for ensuring that the input 337 // is, in fact, a comma-delimited sequence of SQL scalar expressions — 338 // the results are undefined if the string contains invalid SQL 339 // syntax. 340 func ParseExprs(sql []string) (tree.Exprs, error) { 341 if len(sql) == 0 { 342 return tree.Exprs{}, nil 343 } 344 return parseExprsWithInt(sql, defaultNakedIntType) 345 } 346 347 // ParseExpr parses a SQL scalar expression. The caller is responsible 348 // for ensuring that the input is, in fact, a valid SQL scalar 349 // expression — the results are undefined if the string contains 350 // invalid SQL syntax. 
351 func ParseExpr(sql string) (tree.Expr, error) { 352 return ParseExprWithInt(sql, defaultNakedIntType) 353 } 354 355 // ParseExprWithInt parses a SQL scalar expression, using the given 356 // type when INT is used as type name in the SQL syntax. The caller is 357 // responsible for ensuring that the input is, in fact, a valid SQL 358 // scalar expression — the results are undefined if the string 359 // contains invalid SQL syntax. 360 func ParseExprWithInt(sql string, nakedIntType *types.T) (tree.Expr, error) { 361 exprs, err := parseExprsWithInt([]string{sql}, nakedIntType) 362 if err != nil { 363 return nil, err 364 } 365 if len(exprs) != 1 { 366 return nil, errors.AssertionFailedf("expected 1 expression, found %d", len(exprs)) 367 } 368 return exprs[0], nil 369 } 370 371 // GetTypeReferenceFromName turns a type name into a type 372 // reference. This supports only “simple” (single-identifier) 373 // references to built-in types, when the identifer has already been 374 // parsed away from the input SQL syntax. 375 func GetTypeReferenceFromName(typeName tree.Name) (tree.ResolvableTypeReference, error) { 376 expr, err := ParseExpr(fmt.Sprintf("1::%s", typeName.String())) 377 if err != nil { 378 return nil, err 379 } 380 381 cast, ok := expr.(*tree.CastExpr) 382 if !ok { 383 return nil, errors.AssertionFailedf("expected a tree.CastExpr, but found %T", expr) 384 } 385 386 return cast.Type, nil 387 } 388 389 // GetTypeFromValidSQLSyntax retrieves a type from its SQL syntax. The caller is 390 // responsible for guaranteeing that the type expression is valid 391 // SQL (or handling the resulting error). This includes verifying that complex 392 // identifiers are enclosed in double quotes, etc. 
393 func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) { 394 expr, err := ParseExpr(fmt.Sprintf("1::%s", sql)) 395 if err != nil { 396 return nil, err 397 } 398 return GetTypeFromCastOrCollate(expr) 399 } 400 401 // GetTypeFromCastOrCollate returns the type of the given tree.Expr. The method 402 // assumes that the expression is either tree.CastExpr or tree.CollateExpr 403 // (which wraps the tree.CastExpr). 404 func GetTypeFromCastOrCollate(expr tree.Expr) (tree.ResolvableTypeReference, error) { 405 // COLLATE clause has lower precedence than the cast, so if we have 406 // something like `1::STRING COLLATE en`, it'll be parsed as 407 // CollateExpr(CastExpr). 408 if collate, ok := expr.(*tree.CollateExpr); ok { 409 return types.MakeCollatedString(types.String, collate.Locale), nil 410 } 411 412 cast, ok := expr.(*tree.CastExpr) 413 if !ok { 414 return nil, errors.AssertionFailedf("expected a tree.CastExpr, but found %T", expr) 415 } 416 417 return cast.Type, nil 418 } 419 420 var errBitLengthNotPositive = pgerror.WithCandidateCode( 421 errors.New("length for type bit must be at least 1"), pgcode.InvalidParameterValue) 422 423 // newBitType creates a new BIT type with the given bit width. 424 func newBitType(width int32, varying bool) (*types.T, error) { 425 if width < 1 { 426 return nil, errBitLengthNotPositive 427 } 428 if varying { 429 return types.MakeVarBit(width), nil 430 } 431 return types.MakeBit(width), nil 432 } 433 434 var errFloatPrecAtLeast1 = pgerror.WithCandidateCode( 435 errors.New("precision for type float must be at least 1 bit"), pgcode.InvalidParameterValue) 436 var errFloatPrecMax54 = pgerror.WithCandidateCode( 437 errors.New("precision for type float must be less than 54 bits"), pgcode.InvalidParameterValue) 438 439 // newFloat creates a type for FLOAT with the given precision. 
440 func newFloat(prec int64) (*types.T, error) { 441 if prec < 1 { 442 return nil, errFloatPrecAtLeast1 443 } 444 if prec <= 24 { 445 return types.Float4, nil 446 } 447 if prec <= 54 { 448 return types.Float, nil 449 } 450 return nil, errFloatPrecMax54 451 } 452 453 // newDecimal creates a type for DECIMAL with the given precision and scale. 454 func newDecimal(prec, scale int32) (*types.T, error) { 455 if scale > prec { 456 err := pgerror.WithCandidateCode( 457 errors.Newf("scale (%d) must be between 0 and precision (%d)", scale, prec), 458 pgcode.InvalidParameterValue) 459 return nil, err 460 } 461 return types.MakeDecimal(prec, scale), nil 462 } 463 464 // arrayOf creates a type alias for an array of the given element type and fixed 465 // bounds. The bounds are currently ignored. 466 func arrayOf( 467 ref tree.ResolvableTypeReference, bounds []int32, 468 ) (tree.ResolvableTypeReference, error) { 469 // If the reference is a statically known type, then return an array type, 470 // rather than an array type reference. 471 if typ, ok := tree.GetStaticallyKnownType(ref); ok { 472 // Do not allow type unknown[]. This is consistent with Postgres' behavior. 473 if typ.Family() == types.UnknownFamily { 474 return nil, pgerror.Newf(pgcode.UndefinedObject, "type unknown[] does not exist") 475 } 476 if typ.Family() == types.VoidFamily { 477 return nil, pgerror.Newf(pgcode.UndefinedObject, "type void[] does not exist") 478 } 479 if err := types.CheckArrayElementType(typ); err != nil { 480 return nil, err 481 } 482 return types.MakeArray(typ), nil 483 } 484 return &tree.ArrayTypeReference{ElementType: ref}, nil 485 }