
     1  /*
     2  Copyright 2019 The Vitess Authors.
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    17  package sqlparser
    19  import (
    20  	"strconv"
    21  	"strings"
    22  	"unicode"
    24  	querypb ""
    25  )
// Names of the directives that can be supplied in a query comment.
const (
	// DirectiveMultiShardAutocommit is the query comment directive to allow
	// single round trip autocommit with a multi-shard statement.
	DirectiveMultiShardAutocommit = "MULTI_SHARD_AUTOCOMMIT"
	// DirectiveSkipQueryPlanCache skips the query plan cache when set.
	DirectiveSkipQueryPlanCache = "SKIP_QUERY_PLAN_CACHE"
	// DirectiveQueryTimeout sets a query timeout in vtgate. Only supported
	// for SELECT statements.
	// NOTE(review): the _MS suffix suggests the value is in milliseconds — confirm.
	DirectiveQueryTimeout = "QUERY_TIMEOUT_MS"
	// DirectiveScatterErrorsAsWarnings enables partial success for scatter
	// select queries.
	DirectiveScatterErrorsAsWarnings = "SCATTER_ERRORS_AS_WARNINGS"
	// DirectiveIgnoreMaxPayloadSize skips payload size validation when set.
	DirectiveIgnoreMaxPayloadSize = "IGNORE_MAX_PAYLOAD_SIZE"
	// DirectiveIgnoreMaxMemoryRows skips memory row validation when set.
	DirectiveIgnoreMaxMemoryRows = "IGNORE_MAX_MEMORY_ROWS"
	// DirectiveAllowScatter lets scatter plans pass through even when they
	// are turned off by `no-scatter`.
	DirectiveAllowScatter = "ALLOW_SCATTER"
	// DirectiveAllowHashJoin lets the planner use a hash join if possible.
	DirectiveAllowHashJoin = "ALLOW_HASH_JOIN"
	// DirectiveQueryPlanner lets the user specify, per query, which planner
	// should be used.
	DirectiveQueryPlanner = "PLANNER"
	// DirectiveVExplainRunDMLQueries tells vexplain queries/all that it is
	// okay to also run the query.
	DirectiveVExplainRunDMLQueries = "EXECUTE_DML_QUERIES"
	// DirectiveConsolidator enables the query consolidator.
	DirectiveConsolidator = "CONSOLIDATOR"
)
    53  func isNonSpace(r rune) bool {
    54  	return !unicode.IsSpace(r)
    55  }
    57  // leadingCommentEnd returns the first index after all leading comments, or
    58  // 0 if there are no leading comments.
    59  func leadingCommentEnd(text string) (end int) {
    60  	hasComment := false
    61  	pos := 0
    62  	for pos < len(text) {
    63  		// Eat up any whitespace. Trailing whitespace will be considered part of
    64  		// the leading comments.
    65  		nextVisibleOffset := strings.IndexFunc(text[pos:], isNonSpace)
    66  		if nextVisibleOffset < 0 {
    67  			break
    68  		}
    69  		pos += nextVisibleOffset
    70  		remainingText := text[pos:]
    72  		// Found visible characters. Look for '/*' at the beginning
    73  		// and '*/' somewhere after that.
    74  		if len(remainingText) < 4 || remainingText[:2] != "/*" || remainingText[2] == '!' {
    75  			break
    76  		}
    77  		commentLength := 4 + strings.Index(remainingText[2:], "*/")
    78  		if commentLength < 4 {
    79  			// Missing end comment :/
    80  			break
    81  		}
    83  		hasComment = true
    84  		pos += commentLength
    85  	}
    87  	if hasComment {
    88  		return pos
    89  	}
    90  	return 0
    91  }
    93  // trailingCommentStart returns the first index of trailing comments.
    94  // If there are no trailing comments, returns the length of the input string.
    95  func trailingCommentStart(text string) (start int) {
    96  	hasComment := false
    97  	reducedLen := len(text)
    98  	for reducedLen > 0 {
    99  		// Eat up any whitespace. Leading whitespace will be considered part of
   100  		// the trailing comments.
   101  		nextReducedLen := strings.LastIndexFunc(text[:reducedLen], isNonSpace) + 1
   102  		if nextReducedLen == 0 {
   103  			break
   104  		}
   105  		reducedLen = nextReducedLen
   106  		if reducedLen < 4 || text[reducedLen-2:reducedLen] != "*/" {
   107  			break
   108  		}
   110  		// Find the beginning of the comment
   111  		startCommentPos := strings.LastIndex(text[:reducedLen-2], "/*")
   112  		if startCommentPos < 0 || text[startCommentPos+2] == '!' {
   113  			// Badly formatted sql, or a special /*! comment
   114  			break
   115  		}
   117  		hasComment = true
   118  		reducedLen = startCommentPos
   119  	}
   121  	if hasComment {
   122  		return reducedLen
   123  	}
   124  	return len(text)
   125  }
// MarginComments holds the leading and trailing comments that surround a query.
type MarginComments struct {
	// Leading is the comment text that precedes the query.
	Leading  string
	// Trailing is the comment text that follows the query.
	Trailing string
}
   133  // SplitMarginComments pulls out any leading or trailing comments from a raw sql query.
   134  // This function also trims leading (if there's a comment) and trailing whitespace.
   135  func SplitMarginComments(sql string) (query string, comments MarginComments) {
   136  	trailingStart := trailingCommentStart(sql)
   137  	leadingEnd := leadingCommentEnd(sql[:trailingStart])
   138  	comments = MarginComments{
   139  		Leading:  strings.TrimLeftFunc(sql[:leadingEnd], unicode.IsSpace),
   140  		Trailing: strings.TrimRightFunc(sql[trailingStart:], unicode.IsSpace),
   141  	}
   142  	return strings.TrimFunc(sql[leadingEnd:trailingStart], func(c rune) bool {
   143  		return unicode.IsSpace(c) || c == ';'
   144  	}), comments
   145  }
   147  // StripLeadingComments trims the SQL string and removes any leading comments
   148  func StripLeadingComments(sql string) string {
   149  	sql = strings.TrimFunc(sql, unicode.IsSpace)
   151  	for hasCommentPrefix(sql) {
   152  		switch sql[0] {
   153  		case '/':
   154  			// Multi line comment
   155  			index := strings.Index(sql, "*/")
   156  			if index <= 1 {
   157  				return sql
   158  			}
   159  			// don't strip /*! ... */ or /*!50700 ... */
   160  			if len(sql) > 2 && sql[2] == '!' {
   161  				return sql
   162  			}
   163  			sql = sql[index+2:]
   164  		case '-':
   165  			// Single line comment
   166  			index := strings.Index(sql, "\n")
   167  			if index == -1 {
   168  				return ""
   169  			}
   170  			sql = sql[index+1:]
   171  		}
   173  		sql = strings.TrimFunc(sql, unicode.IsSpace)
   174  	}
   176  	return sql
   177  }
   179  func hasCommentPrefix(sql string) bool {
   180  	return len(sql) > 1 && ((sql[0] == '/' && sql[1] == '*') || (sql[0] == '-' && sql[1] == '-'))
   181  }
   183  // ExtractMysqlComment extracts the version and SQL from a comment-only query
   184  // such as /*!50708 sql here */
   185  func ExtractMysqlComment(sql string) (string, string) {
   186  	sql = sql[3 : len(sql)-2]
   188  	digitCount := 0
   189  	endOfVersionIndex := strings.IndexFunc(sql, func(c rune) bool {
   190  		digitCount++
   191  		return !unicode.IsDigit(c) || digitCount == 6
   192  	})
   193  	if endOfVersionIndex < 0 {
   194  		return "", ""
   195  	}
   196  	if endOfVersionIndex < 5 {
   197  		endOfVersionIndex = 0
   198  	}
   199  	version := sql[0:endOfVersionIndex]
   200  	innerSQL := strings.TrimFunc(sql[endOfVersionIndex:], unicode.IsSpace)
   202  	return version, innerSQL
   203  }
// commentDirectivePreamble is the text a comment has to start with in order
// to be parsed for execution directives.
const commentDirectivePreamble = "/*vt+"
// CommentDirectives is the parsed representation for execution directives
// conveyed in query comments
type CommentDirectives struct {
	// m maps a directive name to its value. Names are stored in lower case
	// so that lookups can be case insensitive.
	m map[string]string
}
   213  // Directives parses the comment list for any execution directives
   214  // of the form:
   215  //
   216  //	/*vt+ OPTION_ONE=1 OPTION_TWO OPTION_THREE=abcd */
   217  //
   218  // It returns the map of the directive values or nil if there aren't any.
   219  func (c *ParsedComments) Directives() *CommentDirectives {
   220  	if c == nil {
   221  		return nil
   222  	}
   223  	if c._directives == nil {
   224  		c._directives = &CommentDirectives{m: make(map[string]string)}
   226  		for _, commentStr := range c.comments {
   227  			if commentStr[0:5] != commentDirectivePreamble {
   228  				continue
   229  			}
   231  			// Split on whitespace and ignore the first and last directive
   232  			// since they contain the comment start/end
   233  			directives := strings.Fields(commentStr)
   234  			for i := 1; i < len(directives)-1; i++ {
   235  				directive, val, ok := strings.Cut(directives[i], "=")
   236  				if !ok {
   237  					val = "true"
   238  				}
   239  				c._directives.m[strings.ToLower(directive)] = val
   240  			}
   241  		}
   242  	}
   243  	return c._directives
   244  }
   246  func (c *ParsedComments) Length() int {
   247  	if c == nil {
   248  		return 0
   249  	}
   250  	return len(c.comments)
   251  }
   253  func (c *ParsedComments) Prepend(comment string) Comments {
   254  	if c == nil {
   255  		return Comments{comment}
   256  	}
   257  	comments := make(Comments, 0, len(c.comments)+1)
   258  	comments = append(comments, comment)
   259  	comments = append(comments, c.comments...)
   260  	return comments
   261  }
   263  // IsSet checks the directive map for the named directive and returns
   264  // true if the directive is set and has a true/false or 0/1 value
   265  func (d *CommentDirectives) IsSet(key string) bool {
   266  	if d == nil {
   267  		return false
   268  	}
   269  	val, found := d.m[strings.ToLower(key)]
   270  	if !found {
   271  		return false
   272  	}
   273  	// ParseBool handles "0", "1", "true", "false" and all similars
   274  	set, _ := strconv.ParseBool(val)
   275  	return set
   276  }
   278  // GetString gets a directive value as string, with default value if not found
   279  func (d *CommentDirectives) GetString(key string, defaultVal string) (string, bool) {
   280  	if d == nil {
   281  		return "", false
   282  	}
   283  	val, ok := d.m[strings.ToLower(key)]
   284  	if !ok {
   285  		return defaultVal, false
   286  	}
   287  	if unquoted, err := strconv.Unquote(val); err == nil {
   288  		return unquoted, true
   289  	}
   290  	return val, true
   291  }
   293  // MultiShardAutocommitDirective returns true if multishard autocommit directive is set to true in query.
   294  func MultiShardAutocommitDirective(stmt Statement) bool {
   295  	var comments *ParsedComments
   296  	switch stmt := stmt.(type) {
   297  	case *Insert:
   298  		comments = stmt.Comments
   299  	case *Update:
   300  		comments = stmt.Comments
   301  	case *Delete:
   302  		comments = stmt.Comments
   303  	}
   304  	return comments != nil && comments.Directives().IsSet(DirectiveMultiShardAutocommit)
   305  }
   307  // SkipQueryPlanCacheDirective returns true if skip query plan cache directive is set to true in query.
   308  func SkipQueryPlanCacheDirective(stmt Statement) bool {
   309  	var comments *ParsedComments
   310  	switch stmt := stmt.(type) {
   311  	case *Select:
   312  		comments = stmt.Comments
   313  	case *Insert:
   314  		comments = stmt.Comments
   315  	case *Update:
   316  		comments = stmt.Comments
   317  	case *Delete:
   318  		comments = stmt.Comments
   319  	}
   320  	return comments != nil && comments.Directives().IsSet(DirectiveSkipQueryPlanCache)
   321  }
   323  // IgnoreMaxPayloadSizeDirective returns true if the max payload size override
   324  // directive is set to true.
   325  func IgnoreMaxPayloadSizeDirective(stmt Statement) bool {
   326  	var comments *ParsedComments
   327  	switch stmt := stmt.(type) {
   328  	// For transactional statements, they should always be passed down and
   329  	// should not come into max payload size requirement.
   330  	case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release:
   331  		return true
   332  	case *Select:
   333  		comments = stmt.Comments
   334  	case *Insert:
   335  		comments = stmt.Comments
   336  	case *Update:
   337  		comments = stmt.Comments
   338  	case *Delete:
   339  		comments = stmt.Comments
   340  	}
   341  	return comments != nil && comments.Directives().IsSet(DirectiveIgnoreMaxPayloadSize)
   342  }
   344  // IgnoreMaxMaxMemoryRowsDirective returns true if the max memory rows override
   345  // directive is set to true.
   346  func IgnoreMaxMaxMemoryRowsDirective(stmt Statement) bool {
   347  	var comments *ParsedComments
   348  	switch stmt := stmt.(type) {
   349  	case *Select:
   350  		comments = stmt.Comments
   351  	case *Insert:
   352  		comments = stmt.Comments
   353  	case *Update:
   354  		comments = stmt.Comments
   355  	case *Delete:
   356  		comments = stmt.Comments
   357  	}
   358  	return comments != nil && comments.Directives().IsSet(DirectiveIgnoreMaxMemoryRows)
   359  }
   361  // AllowScatterDirective returns true if the allow scatter override is set to true
   362  func AllowScatterDirective(stmt Statement) bool {
   363  	var comments *ParsedComments
   364  	switch stmt := stmt.(type) {
   365  	case *Select:
   366  		comments = stmt.Comments
   367  	case *Insert:
   368  		comments = stmt.Comments
   369  	case *Update:
   370  		comments = stmt.Comments
   371  	case *Delete:
   372  		comments = stmt.Comments
   373  	}
   374  	return comments != nil && comments.Directives().IsSet(DirectiveAllowScatter)
   375  }
   377  // Consolidator returns the consolidator option.
   378  func Consolidator(stmt Statement) querypb.ExecuteOptions_Consolidator {
   379  	var comments *ParsedComments
   380  	switch stmt := stmt.(type) {
   381  	case *Select:
   382  		comments = stmt.Comments
   383  	default:
   384  		return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED
   385  	}
   386  	if comments == nil {
   387  		return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED
   388  	}
   389  	directives := comments.Directives()
   390  	strv, isSet := directives.GetString(DirectiveConsolidator, "")
   391  	if !isSet {
   392  		return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED
   393  	}
   394  	if i32v, ok := querypb.ExecuteOptions_Consolidator_value["CONSOLIDATOR_"+strings.ToUpper(strv)]; ok {
   395  		return querypb.ExecuteOptions_Consolidator(i32v)
   396  	}
   397  	return querypb.ExecuteOptions_CONSOLIDATOR_UNSPECIFIED
   398  }