vitess.io/vitess@v0.16.2/go/vt/sqlparser/analyzer.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  // analyzer.go contains utility analysis functions.
    20  
    21  import (
    22  	"fmt"
    23  	"strings"
    24  	"unicode"
    25  )
    26  
// StatementType encodes the type of a SQL statement. Values are produced
// either from a parsed AST (ASTToStatementType) or from a textual scan of
// the query (Preview).
type StatementType int

// These constants are used to identify the SQL statement type.
// Values are assigned by iota, so inserting, removing or reordering an entry
// renumbers every constant that follows it.
// Changing this list will require reviewing all calls to Preview.
const (
	// StmtSelect is a SELECT statement; ASTToStatementType also maps *Union here.
	StmtSelect StatementType = iota
	// StmtStream is a STREAM statement.
	StmtStream
	// StmtInsert is an INSERT statement.
	StmtInsert
	// StmtReplace is a REPLACE statement; only Preview reports it,
	// ASTToStatementType never returns it.
	StmtReplace
	// StmtUpdate is an UPDATE statement.
	StmtUpdate
	// StmtDelete is a DELETE statement.
	StmtDelete
	// StmtDDL covers CREATE, ALTER, RENAME, DROP and TRUNCATE, and the
	// DDLStatement, DBDDLStatement and *AlterVschema AST nodes.
	StmtDDL
	// StmtBegin is a bare BEGIN or START TRANSACTION.
	StmtBegin
	// StmtCommit is a bare COMMIT.
	StmtCommit
	// StmtRollback is a bare ROLLBACK (compare StmtSRollback).
	StmtRollback
	// StmtSet is a SET statement.
	StmtSet
	// StmtShow is a SHOW statement.
	StmtShow
	// StmtUse is a USE statement.
	StmtUse
	// StmtOther covers ANALYZE, REPAIR and OPTIMIZE, and the *OtherRead,
	// *OtherAdmin and *Load AST nodes.
	StmtOther
	// StmtUnknown is a statement that could not be classified.
	StmtUnknown
	// StmtComment is a query that begins with a "/*!" comment; only Preview
	// reports it.
	StmtComment
	// StmtPriv is a GRANT or REVOKE statement.
	StmtPriv
	// StmtExplain covers EXPLAIN, DESCRIBE and DESC, plus *VExplainStmt.
	StmtExplain
	// StmtSavepoint is a SAVEPOINT statement.
	StmtSavepoint
	// StmtSRollback is a rollback to a savepoint: any statement starting with
	// ROLLBACK that is more than the bare keyword.
	StmtSRollback
	// StmtRelease is a RELEASE statement.
	StmtRelease
	// StmtVStream is a VSTREAM statement.
	StmtVStream
	// StmtLockTables is a LOCK (TABLES) statement.
	StmtLockTables
	// StmtUnlockTables is an UNLOCK (TABLES) statement.
	StmtUnlockTables
	// StmtFlush is a FLUSH statement.
	StmtFlush
	// StmtCallProc is a procedure call (*CallProc).
	StmtCallProc
	// StmtRevert is a REVERT statement (*RevertMigration).
	StmtRevert
	// StmtShowMigrationLogs is the *ShowMigrationLogs AST node.
	StmtShowMigrationLogs
	// StmtCommentOnly is the *CommentOnly AST node.
	StmtCommentOnly
)
    63  
    64  // ASTToStatementType returns a StatementType from an AST stmt
    65  func ASTToStatementType(stmt Statement) StatementType {
    66  	switch stmt.(type) {
    67  	case *Select, *Union:
    68  		return StmtSelect
    69  	case *Insert:
    70  		return StmtInsert
    71  	case *Update:
    72  		return StmtUpdate
    73  	case *Delete:
    74  		return StmtDelete
    75  	case *Set:
    76  		return StmtSet
    77  	case *Show:
    78  		return StmtShow
    79  	case DDLStatement, DBDDLStatement, *AlterVschema:
    80  		return StmtDDL
    81  	case *RevertMigration:
    82  		return StmtRevert
    83  	case *ShowMigrationLogs:
    84  		return StmtShowMigrationLogs
    85  	case *Use:
    86  		return StmtUse
    87  	case *OtherRead, *OtherAdmin, *Load:
    88  		return StmtOther
    89  	case Explain, *VExplainStmt:
    90  		return StmtExplain
    91  	case *Begin:
    92  		return StmtBegin
    93  	case *Commit:
    94  		return StmtCommit
    95  	case *Rollback:
    96  		return StmtRollback
    97  	case *Savepoint:
    98  		return StmtSavepoint
    99  	case *SRollback:
   100  		return StmtSRollback
   101  	case *Release:
   102  		return StmtRelease
   103  	case *LockTables:
   104  		return StmtLockTables
   105  	case *UnlockTables:
   106  		return StmtUnlockTables
   107  	case *Flush:
   108  		return StmtFlush
   109  	case *CallProc:
   110  		return StmtCallProc
   111  	case *Stream:
   112  		return StmtStream
   113  	case *VStream:
   114  		return StmtVStream
   115  	case *CommentOnly:
   116  		return StmtCommentOnly
   117  	default:
   118  		return StmtUnknown
   119  	}
   120  }
   121  
   122  // CanNormalize takes Statement and returns if the statement can be normalized.
   123  func CanNormalize(stmt Statement) bool {
   124  	switch stmt.(type) {
   125  	case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter
   126  		return true
   127  	}
   128  	return false
   129  }
   130  
   131  // CachePlan takes Statement and returns true if the query plan should be cached
   132  func CachePlan(stmt Statement) bool {
   133  	var comments *ParsedComments
   134  	switch stmt := stmt.(type) {
   135  	case *Select:
   136  		comments = stmt.Comments
   137  	case *Insert:
   138  		comments = stmt.Comments
   139  	case *Update:
   140  		comments = stmt.Comments
   141  	case *Delete:
   142  		comments = stmt.Comments
   143  	case *Union, *Stream:
   144  		return true
   145  	default:
   146  		return false
   147  	}
   148  	return !comments.Directives().IsSet(DirectiveSkipQueryPlanCache)
   149  }
   150  
   151  // MustRewriteAST takes Statement and returns true if RewriteAST must run on it for correct execution irrespective of user flags.
   152  func MustRewriteAST(stmt Statement, hasSelectLimit bool) bool {
   153  	switch node := stmt.(type) {
   154  	case *Set:
   155  		return true
   156  	case *Show:
   157  		switch node.Internal.(type) {
   158  		case *ShowBasic:
   159  			return true
   160  		}
   161  		return false
   162  	case SelectStatement:
   163  		return hasSelectLimit
   164  	}
   165  	return false
   166  }
   167  
   168  // Preview analyzes the beginning of the query using a simpler and faster
   169  // textual comparison to identify the statement type.
   170  func Preview(sql string) StatementType {
   171  	trimmed := StripLeadingComments(sql)
   172  
   173  	if strings.Index(trimmed, "/*!") == 0 {
   174  		return StmtComment
   175  	}
   176  
   177  	isNotLetter := func(r rune) bool { return !unicode.IsLetter(r) }
   178  	firstWord := strings.TrimLeftFunc(trimmed, isNotLetter)
   179  
   180  	if end := strings.IndexFunc(firstWord, unicode.IsSpace); end != -1 {
   181  		firstWord = firstWord[:end]
   182  	}
   183  	// Comparison is done in order of priority.
   184  	loweredFirstWord := strings.ToLower(firstWord)
   185  	switch loweredFirstWord {
   186  	case "select":
   187  		return StmtSelect
   188  	case "stream":
   189  		return StmtStream
   190  	case "vstream":
   191  		return StmtVStream
   192  	case "revert":
   193  		return StmtRevert
   194  	case "insert":
   195  		return StmtInsert
   196  	case "replace":
   197  		return StmtReplace
   198  	case "update":
   199  		return StmtUpdate
   200  	case "delete":
   201  		return StmtDelete
   202  	case "savepoint":
   203  		return StmtSavepoint
   204  	case "lock":
   205  		return StmtLockTables
   206  	case "unlock":
   207  		return StmtUnlockTables
   208  	}
   209  	// For the following statements it is not sufficient to rely
   210  	// on loweredFirstWord. This is because they are not statements
   211  	// in the grammar and we are relying on Preview to parse them.
   212  	// For instance, we don't want: "BEGIN JUNK" to be parsed
   213  	// as StmtBegin.
   214  	trimmedNoComments, _ := SplitMarginComments(trimmed)
   215  	switch strings.ToLower(trimmedNoComments) {
   216  	case "begin", "start transaction":
   217  		return StmtBegin
   218  	case "commit":
   219  		return StmtCommit
   220  	case "rollback":
   221  		return StmtRollback
   222  	}
   223  	switch loweredFirstWord {
   224  	case "create", "alter", "rename", "drop", "truncate":
   225  		return StmtDDL
   226  	case "flush":
   227  		return StmtFlush
   228  	case "set":
   229  		return StmtSet
   230  	case "show":
   231  		return StmtShow
   232  	case "use":
   233  		return StmtUse
   234  	case "describe", "desc", "explain":
   235  		return StmtExplain
   236  	case "analyze", "repair", "optimize":
   237  		return StmtOther
   238  	case "grant", "revoke":
   239  		return StmtPriv
   240  	case "release":
   241  		return StmtRelease
   242  	case "rollback":
   243  		return StmtSRollback
   244  	}
   245  	return StmtUnknown
   246  }
   247  
   248  func (s StatementType) String() string {
   249  	switch s {
   250  	case StmtSelect:
   251  		return "SELECT"
   252  	case StmtStream:
   253  		return "STREAM"
   254  	case StmtVStream:
   255  		return "VSTREAM"
   256  	case StmtRevert:
   257  		return "REVERT"
   258  	case StmtInsert:
   259  		return "INSERT"
   260  	case StmtReplace:
   261  		return "REPLACE"
   262  	case StmtUpdate:
   263  		return "UPDATE"
   264  	case StmtDelete:
   265  		return "DELETE"
   266  	case StmtDDL:
   267  		return "DDL"
   268  	case StmtBegin:
   269  		return "BEGIN"
   270  	case StmtCommit:
   271  		return "COMMIT"
   272  	case StmtRollback:
   273  		return "ROLLBACK"
   274  	case StmtSet:
   275  		return "SET"
   276  	case StmtShow:
   277  		return "SHOW"
   278  	case StmtUse:
   279  		return "USE"
   280  	case StmtOther:
   281  		return "OTHER"
   282  	case StmtPriv:
   283  		return "PRIV"
   284  	case StmtExplain:
   285  		return "EXPLAIN"
   286  	case StmtSavepoint:
   287  		return "SAVEPOINT"
   288  	case StmtSRollback:
   289  		return "SAVEPOINT_ROLLBACK"
   290  	case StmtRelease:
   291  		return "RELEASE"
   292  	case StmtLockTables:
   293  		return "LOCK_TABLES"
   294  	case StmtUnlockTables:
   295  		return "UNLOCK_TABLES"
   296  	case StmtFlush:
   297  		return "FLUSH"
   298  	case StmtCallProc:
   299  		return "CALL_PROC"
   300  	case StmtCommentOnly:
   301  		return "COMMENT_ONLY"
   302  	default:
   303  		return "UNKNOWN"
   304  	}
   305  }
   306  
   307  // IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
   308  func IsDML(sql string) bool {
   309  	switch Preview(sql) {
   310  	case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
   311  		return true
   312  	}
   313  	return false
   314  }
   315  
   316  // IsDMLStatement returns true if the query is an INSERT, UPDATE or DELETE statement.
   317  func IsDMLStatement(stmt Statement) bool {
   318  	switch stmt.(type) {
   319  	case *Insert, *Update, *Delete:
   320  		return true
   321  	}
   322  
   323  	return false
   324  }
   325  
   326  // TableFromStatement returns the qualified table name for the query.
   327  // This works only for select statements.
   328  func TableFromStatement(sql string) (TableName, error) {
   329  	stmt, err := Parse(sql)
   330  	if err != nil {
   331  		return TableName{}, err
   332  	}
   333  	sel, ok := stmt.(*Select)
   334  	if !ok {
   335  		return TableName{}, fmt.Errorf("unrecognized statement: %s", sql)
   336  	}
   337  	if len(sel.From) != 1 {
   338  		return TableName{}, fmt.Errorf("table expression is complex")
   339  	}
   340  	aliased, ok := sel.From[0].(*AliasedTableExpr)
   341  	if !ok {
   342  		return TableName{}, fmt.Errorf("table expression is complex")
   343  	}
   344  	tableName, ok := aliased.Expr.(TableName)
   345  	if !ok {
   346  		return TableName{}, fmt.Errorf("table expression is complex")
   347  	}
   348  	return tableName, nil
   349  }
   350  
   351  // GetTableName returns the table name from the SimpleTableExpr
   352  // only if it's a simple expression. Otherwise, it returns "".
   353  func GetTableName(node SimpleTableExpr) IdentifierCS {
   354  	if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
   355  		return n.Name
   356  	}
   357  	// sub-select or '.' expression
   358  	return NewIdentifierCS("")
   359  }
   360  
   361  // IsColName returns true if the Expr is a *ColName.
   362  func IsColName(node Expr) bool {
   363  	_, ok := node.(*ColName)
   364  	return ok
   365  }
   366  
   367  // IsValue returns true if the Expr is a string, integral or value arg.
   368  // NULL is not considered to be a value.
   369  func IsValue(node Expr) bool {
   370  	switch v := node.(type) {
   371  	case Argument:
   372  		return true
   373  	case *Literal:
   374  		switch v.Type {
   375  		case StrVal, HexVal, IntVal:
   376  			return true
   377  		}
   378  	}
   379  	return false
   380  }
   381  
   382  // IsNull returns true if the Expr is SQL NULL
   383  func IsNull(node Expr) bool {
   384  	switch node.(type) {
   385  	case *NullVal:
   386  		return true
   387  	}
   388  	return false
   389  }
   390  
   391  // IsSimpleTuple returns true if the Expr is a ValTuple that
   392  // contains simple values or if it's a list arg.
   393  func IsSimpleTuple(node Expr) bool {
   394  	switch vals := node.(type) {
   395  	case ValTuple:
   396  		for _, n := range vals {
   397  			if !IsValue(n) {
   398  				return false
   399  			}
   400  		}
   401  		return true
   402  	case ListArg:
   403  		return true
   404  	}
   405  	// It's a subquery
   406  	return false
   407  }
   408  
   409  // IsLockingFunc returns true for all functions that are used to work with mysql advisory locks
   410  func IsLockingFunc(node Expr) bool {
   411  	switch node.(type) {
   412  	case *LockingFunc:
   413  		return true
   414  	}
   415  	return false
   416  }