github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/comments.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"strconv"
    22  	"strings"
    23  	"unicode"
    24  )
    25  
    26  const (
    27  	// DirectiveMultiShardAutocommit is the query comment directive to allow
    28  	// single round trip autocommit with a multi-shard statement.
    29  	DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT"
    30  	// DirectiveSkipQueryPlanCache skips query plan cache when set.
    31  	DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE"
    32  	// DirectiveQueryTimeout sets a query timeout in vtgate. Only supported for SELECTS.
    33  	DirectiveQueryTimeout = "QUERY_TIMEOUT_MS"
    34  	// DirectiveScatterErrorsAsWarnings enables partial success scatter select queries
    35  	DirectiveScatterErrorsAsWarnings = "SCATTER_ERRORS_AS_WARNINGS"
    36  	// DirectiveIgnoreMaxPayloadSize skips payload size validation when set.
    37  	DirectiveIgnoreMaxPayloadSize = "IGNORE_MAX_PAYLOAD_SIZE"
    38  	// DirectiveIgnoreMaxMemoryRows skips memory row validation when set.
    39  	DirectiveIgnoreMaxMemoryRows = "IGNORE_MAX_MEMORY_ROWS"
    40  	// DirectiveAllowScatter lets scatter plans pass through even when they are turned off by `no-scatter`.
    41  	DirectiveAllowScatter = "ALLOW_SCATTER"
    42  	// DirectiveAllowHashJoin lets the planner use hash join if possible
    43  	DirectiveAllowHashJoin = "ALLOW_HASH_JOIN"
    44  	// DirectiveQueryPlanner lets the user specify per query which planner should be used
    45  	DirectiveQueryPlanner = "PLANNER"
    46  )
    47  
    48  func isNonSpace(r rune) bool {
    49  	return !unicode.IsSpace(r)
    50  }
    51  
    52  // leadingCommentEnd returns the first index after all leading comments, or
    53  // 0 if there are no leading comments.
    54  func leadingCommentEnd(text string) (end int) {
    55  	hasComment := false
    56  	pos := 0
    57  	for pos < len(text) {
    58  		// Eat up any whitespace. Trailing whitespace will be considered part of
    59  		// the leading comments.
    60  		nextVisibleOffset := strings.IndexFunc(text[pos:], isNonSpace)
    61  		if nextVisibleOffset < 0 {
    62  			break
    63  		}
    64  		pos += nextVisibleOffset
    65  		remainingText := text[pos:]
    66  
    67  		// Found visible characters. Look for '/*' at the beginning
    68  		// and '*/' somewhere after that.
    69  		if len(remainingText) < 4 || remainingText[:2] != "/*" || remainingText[2] == '!' {
    70  			break
    71  		}
    72  		commentLength := 4 + strings.Index(remainingText[2:], "*/")
    73  		if commentLength < 4 {
    74  			// Missing end comment :/
    75  			break
    76  		}
    77  
    78  		hasComment = true
    79  		pos += commentLength
    80  	}
    81  
    82  	if hasComment {
    83  		return pos
    84  	}
    85  	return 0
    86  }
    87  
    88  // trailingCommentStart returns the first index of trailing comments.
    89  // If there are no trailing comments, returns the length of the input string.
    90  func trailingCommentStart(text string) (start int) {
    91  	hasComment := false
    92  	reducedLen := len(text)
    93  	for reducedLen > 0 {
    94  		// Eat up any whitespace. Leading whitespace will be considered part of
    95  		// the trailing comments.
    96  		nextReducedLen := strings.LastIndexFunc(text[:reducedLen], isNonSpace) + 1
    97  		if nextReducedLen == 0 {
    98  			break
    99  		}
   100  		reducedLen = nextReducedLen
   101  		if reducedLen < 4 || text[reducedLen-2:reducedLen] != "*/" {
   102  			break
   103  		}
   104  
   105  		// Find the beginning of the comment
   106  		startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*")
   107  		if startCommentPos < 0 || text[startCommentPos+2] == '!' {
   108  			// Badly formatted sql, or a special /*! comment
   109  			break
   110  		}
   111  
   112  		hasComment = true
   113  		reducedLen = startCommentPos
   114  	}
   115  
   116  	if hasComment {
   117  		return reducedLen
   118  	}
   119  	return len(text)
   120  }
   121  
   122  // MarginComments holds the leading and trailing comments that surround a query.
   123  type MarginComments struct {
   124  	Leading  string
   125  	Trailing string
   126  }
   127  
   128  // SplitMarginComments pulls out any leading or trailing comments from a raw sql query.
   129  // This function also trims leading (if there's a comment) and trailing whitespace.
   130  func SplitMarginComments(sql string) (query string, comments MarginComments) {
   131  	trailingStart := trailingCommentStart(sql)
   132  	leadingEnd := leadingCommentEnd(sql[:trailingStart])
   133  	comments = MarginComments{
   134  		Leading:  strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace),
   135  		Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace),
   136  	}
   137  	return strings.TrimFunc(sql[leadingEnd:trailingStart], func(c rune) bool {
   138  		return unicode.IsSpace(c) || c == ';'
   139  	}), comments
   140  }
   141  
   142  // StripLeadingComments trims the SQL string and removes any leading comments
   143  func StripLeadingComments(sql string) string {
   144  	sql = strings.TrimFunc(sql, unicode.IsSpace)
   145  
   146  	for hasCommentPrefix(sql) {
   147  		switch sql[0] {
   148  		case '/':
   149  			// Multi line comment
   150  			index := strings.Index(sql, "*/")
   151  			if index <= 1 {
   152  				return sql
   153  			}
   154  			// don't strip /*! ... */ or /*!50700 ... */
   155  			if len(sql) > 2 && sql[2] == '!' {
   156  				return sql
   157  			}
   158  			sql = sql[index+2:]
   159  		case '-':
   160  			// Single line comment
   161  			index := strings.Index(sql, "\n")
   162  			if index == -1 {
   163  				return ""
   164  			}
   165  			sql = sql[index+1:]
   166  		}
   167  
   168  		sql = strings.TrimFunc(sql, unicode.IsSpace)
   169  	}
   170  
   171  	return sql
   172  }
   173  
   174  func hasCommentPrefix(sql string) bool {
   175  	return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-'))
   176  }
   177  
   178  // ExtractMysqlComment extracts the version and SQL from a comment-only query
   179  // such as /*!50708 sql here */
   180  func ExtractMysqlComment(sql string) (string, string) {
   181  	sql = sql[3 : len(sql)-2]
   182  
   183  	digitCount := 0
   184  	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
   185  		digitCount++
   186  		return !unicode.IsDigit(c) || digitCount == 6
   187  	})
   188  	if endOfVersionIndex < 0 {
   189  		return "", ""
   190  	}
   191  	if endOfVersionIndex < 5 {
   192  		endOfVersionIndex = 0
   193  	}
   194  	version := sql[0:endOfVersionIndex]
   195  	innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
   196  
   197  	return version, innerSQL
   198  }
   199  
   200  const commentDirectivePreamble = "/*vt+"
   201  
   202  // CommentDirectives is the parsed representation for execution directives
   203  // conveyed in query comments
   204  type CommentDirectives map[string]interface{}
   205  
   206  // ExtractCommentDirectives parses the comment list for any execution directives
   207  // of the form:
   208  //
   209  //     /*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */
   210  //
   211  // It returns the map of the directive values or nil if there aren't any.
   212  func ExtractCommentDirectives(comments Comments) CommentDirectives {
   213  	if comments == nil {
   214  		return nil
   215  	}
   216  
   217  	var vals map[string]interface{}
   218  
   219  	for _, commentStr := range comments {
   220  		if commentStr[0:5] != commentDirectivePreamble {
   221  			continue
   222  		}
   223  
   224  		if vals == nil {
   225  			vals = make(map[string]interface{})
   226  		}
   227  
   228  		// Split on whitespace and ignore the first and last directive
   229  		// since they contain the comment start/end
   230  		directives := strings.Fields(commentStr)
   231  		for i := 1; i < len(directives)-1; i++ {
   232  			directive := directives[i]
   233  			sep := strings.IndexByte(directive, '=')
   234  
   235  			// No value is equivalent to a true boolean
   236  			if sep == -1 {
   237  				vals[directive] = true
   238  				continue
   239  			}
   240  
   241  			strVal := directive[sep+1:]
   242  			directive = directive[:sep]
   243  
   244  			intVal, err := strconv.Atoi(strVal)
   245  			if err == nil {
   246  				vals[directive] = intVal
   247  				continue
   248  			}
   249  
   250  			boolVal, err := strconv.ParseBool(strVal)
   251  			if err == nil {
   252  				vals[directive] = boolVal
   253  				continue
   254  			}
   255  
   256  			vals[directive] = strVal
   257  		}
   258  	}
   259  	return vals
   260  }
   261  
   262  // IsSet checks the directive map for the named directive and returns
   263  // true if the directive is set and has a true/false or 0/1 value
   264  func (d CommentDirectives) IsSet(key string) bool {
   265  	if d == nil {
   266  		return false
   267  	}
   268  
   269  	val, ok := d[key]
   270  	if !ok {
   271  		return false
   272  	}
   273  
   274  	boolVal, ok := val.(bool)
   275  	if ok {
   276  		return boolVal
   277  	}
   278  
   279  	intVal, ok := val.(int)
   280  	if ok {
   281  		return intVal == 1
   282  	}
   283  	return false
   284  }
   285  
   286  // GetString gets a directive value as string, with default value if not found
   287  func (d CommentDirectives) GetString(key string, defaultVal string) string {
   288  	val, ok := d[key]
   289  	if !ok {
   290  		return defaultVal
   291  	}
   292  	stringVal := fmt.Sprintf("%v", val)
   293  	if unquoted, err := strconv.Unquote(stringVal); err == nil {
   294  		stringVal = unquoted
   295  	}
   296  	return stringVal
   297  }
   298  
   299  // MultiShardAutocommitDirective returns true if multishard autocommit directive is set to true in query.
   300  func MultiShardAutocommitDirective(stmt Statement) bool {
   301  	switch stmt := stmt.(type) {
   302  	case *Insert:
   303  		directives := ExtractCommentDirectives(stmt.Comments)
   304  		if directives.IsSet(DirectiveMultiShardAutocommit) {
   305  			return true
   306  		}
   307  	case *Update:
   308  		directives := ExtractCommentDirectives(stmt.Comments)
   309  		if directives.IsSet(DirectiveMultiShardAutocommit) {
   310  			return true
   311  		}
   312  	case *Delete:
   313  		directives := ExtractCommentDirectives(stmt.Comments)
   314  		if directives.IsSet(DirectiveMultiShardAutocommit) {
   315  			return true
   316  		}
   317  	default:
   318  		return false
   319  	}
   320  	return false
   321  }
   322  
   323  // SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query.
   324  func SkipQueryPlanCacheDirective(stmt Statement) bool {
   325  	switch stmt := stmt.(type) {
   326  	case *Select:
   327  		directives := ExtractCommentDirectives(stmt.Comments)
   328  		if directives.IsSet(DirectiveSkipQueryPlanCache) {
   329  			return true
   330  		}
   331  	case *Insert:
   332  		directives := ExtractCommentDirectives(stmt.Comments)
   333  		if directives.IsSet(DirectiveSkipQueryPlanCache) {
   334  			return true
   335  		}
   336  	case *Update:
   337  		directives := ExtractCommentDirectives(stmt.Comments)
   338  		if directives.IsSet(DirectiveSkipQueryPlanCache) {
   339  			return true
   340  		}
   341  	case *Delete:
   342  		directives := ExtractCommentDirectives(stmt.Comments)
   343  		if directives.IsSet(DirectiveSkipQueryPlanCache) {
   344  			return true
   345  		}
   346  	default:
   347  		return false
   348  	}
   349  	return false
   350  }
   351  
   352  // IgnoreMaxPayloadSizeDirective returns true if the max payload size override
   353  // directive is set to true.
   354  func IgnoreMaxPayloadSizeDirective(stmt Statement) bool {
   355  	switch stmt := stmt.(type) {
   356  	// For transactional statements, they should always be passed down and
   357  	// should not come into max payload size requirement.
   358  	case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release:
   359  		return true
   360  	case *Select:
   361  		directives := ExtractCommentDirectives(stmt.Comments)
   362  		return directives.IsSet(DirectiveIgnoreMaxPayloadSize)
   363  	case *Insert:
   364  		directives := ExtractCommentDirectives(stmt.Comments)
   365  		return directives.IsSet(DirectiveIgnoreMaxPayloadSize)
   366  	case *Update:
   367  		directives := ExtractCommentDirectives(stmt.Comments)
   368  		return directives.IsSet(DirectiveIgnoreMaxPayloadSize)
   369  	case *Delete:
   370  		directives := ExtractCommentDirectives(stmt.Comments)
   371  		return directives.IsSet(DirectiveIgnoreMaxPayloadSize)
   372  	default:
   373  		return false
   374  	}
   375  }
   376  
   377  // IgnoreMaxMaxMemoryRowsDirective returns true if the max memory rows override
   378  // directive is set to true.
   379  func IgnoreMaxMaxMemoryRowsDirective(stmt Statement) bool {
   380  	switch stmt := stmt.(type) {
   381  	case *Select:
   382  		directives := ExtractCommentDirectives(stmt.Comments)
   383  		return directives.IsSet(DirectiveIgnoreMaxMemoryRows)
   384  	case *Insert:
   385  		directives := ExtractCommentDirectives(stmt.Comments)
   386  		return directives.IsSet(DirectiveIgnoreMaxMemoryRows)
   387  	case *Update:
   388  		directives := ExtractCommentDirectives(stmt.Comments)
   389  		return directives.IsSet(DirectiveIgnoreMaxMemoryRows)
   390  	case *Delete:
   391  		directives := ExtractCommentDirectives(stmt.Comments)
   392  		return directives.IsSet(DirectiveIgnoreMaxMemoryRows)
   393  	default:
   394  		return false
   395  	}
   396  }
   397  
   398  // AllowScatterDirective returns true if the allow scatter override is set to true
   399  func AllowScatterDirective(stmt Statement) bool {
   400  	var directives CommentDirectives
   401  	switch stmt := stmt.(type) {
   402  	case *Select:
   403  		directives = ExtractCommentDirectives(stmt.Comments)
   404  	case *Insert:
   405  		directives = ExtractCommentDirectives(stmt.Comments)
   406  	case *Update:
   407  		directives = ExtractCommentDirectives(stmt.Comments)
   408  	case *Delete:
   409  		directives = ExtractCommentDirectives(stmt.Comments)
   410  	default:
   411  		return false
   412  	}
   413  	return directives.IsSet(DirectiveAllowScatter)
   414  }