github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/parser/parse.go (about)

     1  // Copyright 2012, Google Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in licenses/BSD-vitess.txt.
     4  
     5  // Portions of this file are additionally subject to the following
     6  // license and copyright.
     7  //
     8  // Copyright 2015 The Cockroach Authors.
     9  //
    10  // Use of this software is governed by the Business Source License
    11  // included in the file licenses/BSL.txt.
    12  //
    13  // As of the Change Date specified in that file, in accordance with
    14  // the Business Source License, use of this software will be governed
    15  // by the Apache License, Version 2.0, included in the file
    16  // licenses/APL.txt.
    17  
    18  // This code was derived from https://github.com/youtube/vitess.
    19  
    20  package parser
    21  
    22  import (
    23  	"fmt"
    24  	"go/constant"
    25  	"strings"
    26  
    27  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser/statements"
    28  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgcode"
    29  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/pgwire/pgerror"
    30  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/scanner"
    31  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
    32  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
    33  	"github.com/cockroachdb/errors"
    34  )
    35  
    36  func init() {
    37  	scanner.NewNumValFn = func(a constant.Value, s string, b bool) interface{} { return tree.NewNumVal(a, s, b) }
    38  	scanner.NewPlaceholderFn = func(s string) (interface{}, error) { return tree.NewPlaceholder(s) }
    39  }
    40  
// Parser wraps a scanner, parser and other utilities present in the parser
// package.
//
// The zero value is ready to use (the package-level Parse functions declare a
// bare `var p Parser`). A Parser keeps scratch buffers that are reused across
// calls, so it is not meant to be shared between concurrent callers.
type Parser struct {
	// scanner tokenizes the whole input string; scanOneStmt pulls tokens from
	// it one statement at a time.
	scanner    scanner.SQLScanner
	// lexer feeds one statement's pre-scanned tokens to parserImpl (see parse).
	lexer      lexer
	parserImpl sqlParserImpl
	// tokBuf is the initial backing storage for the token slice built by
	// scanOneStmt; statements with more tokens than this spill to a fresh
	// allocation via append.
	tokBuf     [8]sqlSymType
	// stmtBuf is the initial backing storage for the Statements slice returned
	// by parseWithDepth, so a single-statement result needs no slice
	// allocation. NOTE(review): the returned slice may therefore alias this
	// array, and a later parse with the same Parser presumably overwrites it —
	// confirm callers do not retain results across reuse.
	stmtBuf    [1]statements.Statement[tree.Statement]
}
    50  
    51  // INT8 is the historical interpretation of INT. This should be left
    52  // alone in the future, since there are many sql fragments stored
    53  // in various descriptors. Any user input that was created after
    54  // INT := INT4 will simply use INT4 in any resulting code.
    55  var defaultNakedIntType = types.Int
    56  
    57  // NakedIntTypeFromDefaultIntSize given the size in bits or bytes (preferred)
    58  // of how a "naked" INT type should be parsed returns the corresponding integer
    59  // type.
    60  func NakedIntTypeFromDefaultIntSize(defaultIntSize int32) *types.T {
    61  	switch defaultIntSize {
    62  	case 4, 32:
    63  		return types.Int4
    64  	default:
    65  		return types.Int
    66  	}
    67  }
    68  
    69  // Parse parses the sql and returns a list of statements.
    70  func (p *Parser) Parse(sql string) (statements.Statements, error) {
    71  	return p.parseWithDepth(1, sql, defaultNakedIntType)
    72  }
    73  
    74  // ParseWithInt parses a sql statement string and returns a list of
    75  // Statements. The INT token will result in the specified TInt type.
    76  func (p *Parser) ParseWithInt(sql string, nakedIntType *types.T) (statements.Statements, error) {
    77  	return p.parseWithDepth(1, sql, nakedIntType)
    78  }
    79  
    80  func (p *Parser) parseOneWithInt(
    81  	sql string, nakedIntType *types.T,
    82  ) (statements.Statement[tree.Statement], error) {
    83  	stmts, err := p.parseWithDepth(1, sql, nakedIntType)
    84  	if err != nil {
    85  		return statements.Statement[tree.Statement]{}, err
    86  	}
    87  	if len(stmts) != 1 {
    88  		return statements.Statement[tree.Statement]{}, errors.AssertionFailedf("expected 1 statement, but found %d", len(stmts))
    89  	}
    90  	return stmts[0], nil
    91  }
    92  
// scanOneStmt pulls the next statement's tokens from p.scanner.
//
// Leading ';' tokens (empty statements) are skipped. Scanning then continues
// until a ';' outside any `BEGIN ATOMIC ... END` body, end of input (a token
// id of 0), or an ERROR token.
//
// Returns:
//   - sql: the raw text of the statement, without a terminating ';'. On an
//     ERROR token it is the remainder of the input from the statement start.
//   - tokens: the statement's tokens, with positions rebased so they index
//     into sql. The terminating ';' / end-of-input token is not included; on
//     ERROR, the ERROR token is the last element. The slice initially uses
//     p.tokBuf as backing storage.
//   - done: true when no further statements should be scanned (end of input
//     or ERROR). If only ';' tokens remain, it returns ("", nil, true).
func (p *Parser) scanOneStmt() (sql string, tokens []sqlSymType, done bool) {
	tokens = p.tokBuf[:0]
	tokens = append(tokens, sqlSymType{})
	// lval aliases tokens[0]; both refer to p.tokBuf[0].
	lval := &p.tokBuf[0]

	// Scan the first token.
	for {
		p.scanner.Scan(lval)
		if lval.id == 0 {
			return "", nil, true
		}
		if lval.id != ';' {
			break
		}
	}

	startPos := lval.pos
	// We make the resulting token positions match the returned string.
	lval.pos = 0
	var preValID int32
	// This is used to track the degree of nested `BEGIN ATOMIC ... END` function
	// body context. When greater than zero, it means that we're scanning through
	// the function body of a `CREATE FUNCTION` statement. ';' character is only
	// a separator of sql statements within the body instead of a finishing line
	// of the `CREATE FUNCTION` statement.
	curFuncBodyCnt := 0
	for {
		if lval.id == ERROR {
			return p.scanner.In()[startPos:], tokens, true
		}
		preValID = lval.id
		// Reserve a slot for the next token. append may reallocate once tokBuf
		// is full, so lval is re-derived from the (possibly new) slice.
		tokens = append(tokens, sqlSymType{})
		lval = &tokens[len(tokens)-1]
		p.scanner.Scan(lval)

		if preValID == BEGIN && lval.id == ATOMIC {
			curFuncBodyCnt++
		}
		// NOTE(review): any END token decrements the counter while inside a
		// body, so presumably an inner `CASE ... END` would close the body
		// early — confirm against the grammar.
		if curFuncBodyCnt > 0 && lval.id == END {
			curFuncBodyCnt--
		}
		if lval.id == 0 || (curFuncBodyCnt == 0 && lval.id == ';') {
			endPos := p.scanner.Pos()
			if lval.id == ';' {
				// Don't include the ending semicolon, if there is one, in the raw SQL.
				endPos--
			}
			// Drop the terminating token (';' or end of input) from the result.
			tokens = tokens[:len(tokens)-1]
			return p.scanner.In()[startPos:endPos], tokens, (lval.id == 0)
		}
		// Rebase the token position so it indexes into the returned sql.
		lval.pos -= startPos
	}
}
   146  
   147  func (p *Parser) parseWithDepth(
   148  	depth int, sql string, nakedIntType *types.T,
   149  ) (statements.Statements, error) {
   150  	stmts := statements.Statements(p.stmtBuf[:0])
   151  	p.scanner.Init(sql)
   152  	defer p.scanner.Cleanup()
   153  	for {
   154  		sql, tokens, done := p.scanOneStmt()
   155  		stmt, err := p.parse(depth+1, sql, tokens, nakedIntType)
   156  		if err != nil {
   157  			return nil, err
   158  		}
   159  		if stmt.AST != nil {
   160  			stmts = append(stmts, stmt)
   161  		}
   162  		if done {
   163  			break
   164  		}
   165  	}
   166  	return stmts, nil
   167  }
   168  
   169  // parse parses a statement from the given scanned tokens.
   170  func (p *Parser) parse(
   171  	depth int, sql string, tokens []sqlSymType, nakedIntType *types.T,
   172  ) (statements.Statement[tree.Statement], error) {
   173  	p.lexer.init(sql, tokens, nakedIntType)
   174  	defer p.lexer.cleanup()
   175  	if p.parserImpl.Parse(&p.lexer) != 0 {
   176  		if p.lexer.lastError == nil {
   177  			// This should never happen -- there should be an error object
   178  			// every time Parse() returns nonzero. We're just playing safe
   179  			// here.
   180  			p.lexer.Error("syntax error")
   181  		}
   182  		err := p.lexer.lastError
   183  
   184  		// Compatibility with 19.1 telemetry: prefix the telemetry keys
   185  		// with the "syntax." prefix.
   186  		// TODO(knz): move the auto-prefixing of feature names to a
   187  		// higher level in the call stack.
   188  		tkeys := errors.GetTelemetryKeys(err)
   189  		if len(tkeys) > 0 {
   190  			for i := range tkeys {
   191  				tkeys[i] = "syntax." + tkeys[i]
   192  			}
   193  			err = errors.WithTelemetry(err, tkeys...)
   194  		}
   195  
   196  		return statements.Statement[tree.Statement]{}, err
   197  	}
   198  
   199  	return statements.Statement[tree.Statement]{
   200  		AST:             p.lexer.stmt,
   201  		SQL:             sql,
   202  		Comments:        p.scanner.Comments,
   203  		NumPlaceholders: p.lexer.numPlaceholders,
   204  		NumAnnotations:  p.lexer.numAnnotations,
   205  	}, nil
   206  }
   207  
   208  // unaryNegation constructs an AST node for a negation. This attempts
   209  // to preserve constant NumVals and embed the negative sign inside
   210  // them instead of wrapping in an UnaryExpr. This in turn ensures
   211  // that negative numbers get considered as a single constant
   212  // for the purpose of formatting and scrubbing.
   213  func unaryNegation(e tree.Expr) tree.Expr {
   214  	if cst, ok := e.(*tree.NumVal); ok {
   215  		cst.Negate()
   216  		return cst
   217  	}
   218  
   219  	// Common case.
   220  	return &tree.UnaryExpr{
   221  		Operator: tree.MakeUnaryOperator(tree.UnaryMinus),
   222  		Expr:     e,
   223  	}
   224  }
   225  
   226  // Parse parses a sql statement string and returns a list of Statements.
   227  func Parse(sql string) (statements.Statements, error) {
   228  	return ParseWithInt(sql, defaultNakedIntType)
   229  }
   230  
   231  // ParseWithInt parses a sql statement string and returns a list of
   232  // Statements. The INT token will result in the specified TInt type.
   233  func ParseWithInt(sql string, nakedIntType *types.T) (statements.Statements, error) {
   234  	var p Parser
   235  	return p.parseWithDepth(1, sql, nakedIntType)
   236  }
   237  
   238  // ParseOne parses a sql statement string, ensuring that it contains only a
   239  // single statement, and returns that Statement. ParseOne will always
   240  // interpret the INT and SERIAL types as 64-bit types, since this is
   241  // used in various internal-execution paths where we might receive
   242  // bits of SQL from other nodes. In general,earwe expect that all
   243  // user-generated SQL has been run through the ParseWithInt() function.
   244  func ParseOne(sql string) (statements.Statement[tree.Statement], error) {
   245  	return ParseOneWithInt(sql, defaultNakedIntType)
   246  }
   247  
   248  // ParseOneWithInt is similar to ParseOn but interprets the INT and SERIAL
   249  // types as the provided integer type.
   250  func ParseOneWithInt(
   251  	sql string, nakedIntType *types.T,
   252  ) (statements.Statement[tree.Statement], error) {
   253  	var p Parser
   254  	return p.parseOneWithInt(sql, nakedIntType)
   255  }
   256  
   257  // ParseQualifiedTableName parses a possibly qualified table name. The
   258  // table name must contain one or more name parts, using the full
   259  // input SQL syntax: each name part containing special characters, or
   260  // non-lowercase characters, must be enclosed in double quote. The
   261  // name may not be an invalid table name (the caller is responsible
   262  // for guaranteeing that only valid table names are provided as
   263  // input).
   264  func ParseQualifiedTableName(sql string) (*tree.TableName, error) {
   265  	name, err := ParseTableName(sql)
   266  	if err != nil {
   267  		return nil, err
   268  	}
   269  	tn := name.ToTableName()
   270  	return &tn, nil
   271  }
   272  
   273  // ParseTableName parses a table name. The table name must contain one
   274  // or more name parts, using the full input SQL syntax: each name
   275  // part containing special characters, or non-lowercase characters,
   276  // must be enclosed in double quote. The name may not be an invalid
   277  // table name (the caller is responsible for guaranteeing that only
   278  // valid table names are provided as input).
   279  func ParseTableName(sql string) (*tree.UnresolvedObjectName, error) {
   280  	// We wrap the name we want to parse into a dummy statement since our parser
   281  	// can only parse full statements.
   282  	stmt, err := ParseOne(fmt.Sprintf("ALTER TABLE %s RENAME TO x", sql))
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	rename, ok := stmt.AST.(*tree.RenameTable)
   287  	if !ok {
   288  		return nil, errors.AssertionFailedf("expected an ALTER TABLE statement, but found %T", stmt)
   289  	}
   290  	return rename.Name, nil
   291  }
   292  
   293  // ParseTablePattern parses a table pattern. The table name must contain one
   294  // or more name parts, using the full input SQL syntax: each name
   295  // part containing special characters, or non-lowercase characters,
   296  // must be enclosed in double quote. The name may not be an invalid
   297  // table name (the caller is responsible for guaranteeing that only
   298  // valid table names are provided as input).
   299  // The last part may be '*' to denote a wildcard.
   300  func ParseTablePattern(sql string) (tree.TablePattern, error) {
   301  	// We wrap the name we want to parse into a dummy statement since our parser
   302  	// can only parse full statements.
   303  	stmt, err := ParseOne(fmt.Sprintf("GRANT SELECT ON TABLE %s TO admin", sql))
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  	grant, ok := stmt.AST.(*tree.Grant)
   308  	if !ok {
   309  		return nil, errors.AssertionFailedf("expected a GRANT statement, but found %T", stmt)
   310  	}
   311  	if len(grant.Targets.Tables.TablePatterns) == 0 {
   312  		return nil, errors.AssertionFailedf("expected at least one pattern")
   313  	}
   314  	u := grant.Targets.Tables.TablePatterns[0]
   315  	un, ok := u.(*tree.UnresolvedName)
   316  	if !ok {
   317  		return nil, errors.AssertionFailedf("expected an unresolved name, but found %T", u)
   318  	}
   319  	return un.NormalizeTablePattern()
   320  }
   321  
   322  // parseExprsWithInt parses one or more sql expressions.
   323  func parseExprsWithInt(exprs []string, nakedIntType *types.T) (tree.Exprs, error) {
   324  	stmt, err := ParseOneWithInt(fmt.Sprintf("SET ROW (%s)", strings.Join(exprs, ",")), nakedIntType)
   325  	if err != nil {
   326  		return nil, err
   327  	}
   328  	set, ok := stmt.AST.(*tree.SetVar)
   329  	if !ok {
   330  		return nil, errors.AssertionFailedf("expected a SET statement, but found %T", stmt)
   331  	}
   332  	return set.Values, nil
   333  }
   334  
   335  // ParseExprs parses a comma-delimited sequence of SQL scalar
   336  // expressions. The caller is responsible for ensuring that the input
   337  // is, in fact, a comma-delimited sequence of SQL scalar expressions —
   338  // the results are undefined if the string contains invalid SQL
   339  // syntax.
   340  func ParseExprs(sql []string) (tree.Exprs, error) {
   341  	if len(sql) == 0 {
   342  		return tree.Exprs{}, nil
   343  	}
   344  	return parseExprsWithInt(sql, defaultNakedIntType)
   345  }
   346  
   347  // ParseExpr parses a SQL scalar expression. The caller is responsible
   348  // for ensuring that the input is, in fact, a valid SQL scalar
   349  // expression — the results are undefined if the string contains
   350  // invalid SQL syntax.
   351  func ParseExpr(sql string) (tree.Expr, error) {
   352  	return ParseExprWithInt(sql, defaultNakedIntType)
   353  }
   354  
   355  // ParseExprWithInt parses a SQL scalar expression, using the given
   356  // type when INT is used as type name in the SQL syntax. The caller is
   357  // responsible for ensuring that the input is, in fact, a valid SQL
   358  // scalar expression — the results are undefined if the string
   359  // contains invalid SQL syntax.
   360  func ParseExprWithInt(sql string, nakedIntType *types.T) (tree.Expr, error) {
   361  	exprs, err := parseExprsWithInt([]string{sql}, nakedIntType)
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  	if len(exprs) != 1 {
   366  		return nil, errors.AssertionFailedf("expected 1 expression, found %d", len(exprs))
   367  	}
   368  	return exprs[0], nil
   369  }
   370  
   371  // GetTypeReferenceFromName turns a type name into a type
   372  // reference. This supports only “simple” (single-identifier)
   373  // references to built-in types, when the identifer has already been
   374  // parsed away from the input SQL syntax.
   375  func GetTypeReferenceFromName(typeName tree.Name) (tree.ResolvableTypeReference, error) {
   376  	expr, err := ParseExpr(fmt.Sprintf("1::%s", typeName.String()))
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	cast, ok := expr.(*tree.CastExpr)
   382  	if !ok {
   383  		return nil, errors.AssertionFailedf("expected a tree.CastExpr, but found %T", expr)
   384  	}
   385  
   386  	return cast.Type, nil
   387  }
   388  
   389  // GetTypeFromValidSQLSyntax retrieves a type from its SQL syntax. The caller is
   390  // responsible for guaranteeing that the type expression is valid
   391  // SQL (or handling the resulting error). This includes verifying that complex
   392  // identifiers are enclosed in double quotes, etc.
   393  func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) {
   394  	expr, err := ParseExpr(fmt.Sprintf("1::%s", sql))
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  	return GetTypeFromCastOrCollate(expr)
   399  }
   400  
   401  // GetTypeFromCastOrCollate returns the type of the given tree.Expr. The method
   402  // assumes that the expression is either tree.CastExpr or tree.CollateExpr
   403  // (which wraps the tree.CastExpr).
   404  func GetTypeFromCastOrCollate(expr tree.Expr) (tree.ResolvableTypeReference, error) {
   405  	// COLLATE clause has lower precedence than the cast, so if we have
   406  	// something like `1::STRING COLLATE en`, it'll be parsed as
   407  	// CollateExpr(CastExpr).
   408  	if collate, ok := expr.(*tree.CollateExpr); ok {
   409  		return types.MakeCollatedString(types.String, collate.Locale), nil
   410  	}
   411  
   412  	cast, ok := expr.(*tree.CastExpr)
   413  	if !ok {
   414  		return nil, errors.AssertionFailedf("expected a tree.CastExpr, but found %T", expr)
   415  	}
   416  
   417  	return cast.Type, nil
   418  }
   419  
   420  var errBitLengthNotPositive = pgerror.WithCandidateCode(
   421  	errors.New("length for type bit must be at least 1"), pgcode.InvalidParameterValue)
   422  
   423  // newBitType creates a new BIT type with the given bit width.
   424  func newBitType(width int32, varying bool) (*types.T, error) {
   425  	if width < 1 {
   426  		return nil, errBitLengthNotPositive
   427  	}
   428  	if varying {
   429  		return types.MakeVarBit(width), nil
   430  	}
   431  	return types.MakeBit(width), nil
   432  }
   433  
   434  var errFloatPrecAtLeast1 = pgerror.WithCandidateCode(
   435  	errors.New("precision for type float must be at least 1 bit"), pgcode.InvalidParameterValue)
   436  var errFloatPrecMax54 = pgerror.WithCandidateCode(
   437  	errors.New("precision for type float must be less than 54 bits"), pgcode.InvalidParameterValue)
   438  
   439  // newFloat creates a type for FLOAT with the given precision.
   440  func newFloat(prec int64) (*types.T, error) {
   441  	if prec < 1 {
   442  		return nil, errFloatPrecAtLeast1
   443  	}
   444  	if prec <= 24 {
   445  		return types.Float4, nil
   446  	}
   447  	if prec <= 54 {
   448  		return types.Float, nil
   449  	}
   450  	return nil, errFloatPrecMax54
   451  }
   452  
   453  // newDecimal creates a type for DECIMAL with the given precision and scale.
   454  func newDecimal(prec, scale int32) (*types.T, error) {
   455  	if scale > prec {
   456  		err := pgerror.WithCandidateCode(
   457  			errors.Newf("scale (%d) must be between 0 and precision (%d)", scale, prec),
   458  			pgcode.InvalidParameterValue)
   459  		return nil, err
   460  	}
   461  	return types.MakeDecimal(prec, scale), nil
   462  }
   463  
   464  // arrayOf creates a type alias for an array of the given element type and fixed
   465  // bounds. The bounds are currently ignored.
   466  func arrayOf(
   467  	ref tree.ResolvableTypeReference, bounds []int32,
   468  ) (tree.ResolvableTypeReference, error) {
   469  	// If the reference is a statically known type, then return an array type,
   470  	// rather than an array type reference.
   471  	if typ, ok := tree.GetStaticallyKnownType(ref); ok {
   472  		// Do not allow type unknown[]. This is consistent with Postgres' behavior.
   473  		if typ.Family() == types.UnknownFamily {
   474  			return nil, pgerror.Newf(pgcode.UndefinedObject, "type unknown[] does not exist")
   475  		}
   476  		if typ.Family() == types.VoidFamily {
   477  			return nil, pgerror.Newf(pgcode.UndefinedObject, "type void[] does not exist")
   478  		}
   479  		if err := types.CheckArrayElementType(typ); err != nil {
   480  			return nil, err
   481  		}
   482  		return types.MakeArray(typ), nil
   483  	}
   484  	return &tree.ArrayTypeReference{ElementType: ref}, nil
   485  }