github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/tracked_buffer.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"strings"
    22  )
    23  
    24  // NodeFormatter defines the signature of a custom node formatter
    25  // function that can be given to TrackedBuffer for code generation.
    26  type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
    27  
    28  // TrackedBuffer is used to rebuild a query from the ast.
    29  // bindLocations keeps track of locations in the buffer that
    30  // use bind variables for efficient future substitutions.
    31  // nodeFormatter is the formatting function the buffer will
    32  // use to format a node. By default(nil), it's FormatNode.
    33  // But you can supply a different formatting function if you
    34  // want to generate a query that's different from the default.
    35  type TrackedBuffer struct {
    36  	*strings.Builder
    37  	bindLocations []bindLocation
    38  	nodeFormatter NodeFormatter
    39  	literal       func(string) (int, error)
    40  	escape        bool
    41  	fast          bool
    42  }
    43  
    44  // NewTrackedBuffer creates a new TrackedBuffer.
    45  func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
    46  	buf := &TrackedBuffer{
    47  		Builder:       new(strings.Builder),
    48  		nodeFormatter: nodeFormatter,
    49  	}
    50  	buf.literal = buf.WriteString
    51  	buf.fast = nodeFormatter == nil
    52  	return buf
    53  }
    54  
    55  func (buf *TrackedBuffer) writeStringUpperCase(lit string) (int, error) {
    56  	// Upcasing is performed for ASCII only, following MySQL's behavior
    57  	buf.Grow(len(lit))
    58  	for i := 0; i < len(lit); i++ {
    59  		c := lit[i]
    60  		if 'a' <= c && c <= 'z' {
    61  			c -= 'a' - 'A'
    62  		}
    63  		buf.WriteByte(c)
    64  	}
    65  	return len(lit), nil
    66  }
    67  
    68  // SetUpperCase sets whether all SQL statements formatted by this TrackedBuffer will be normalized into
    69  // uppercase. By default, formatted statements are normalized into lowercase.
    70  // Enabling this option will prevent the optimized fastFormat routines from running.
    71  func (buf *TrackedBuffer) SetUpperCase(enable bool) {
    72  	buf.fast = false
    73  	if enable {
    74  		buf.literal = buf.writeStringUpperCase
    75  	} else {
    76  		buf.literal = buf.WriteString
    77  	}
    78  }
    79  
    80  // SetEscapeAllIdentifiers sets whether ALL identifiers in the serialized SQL query should be quoted
    81  // and escaped. By default, identifiers are only escaped if they match the name of a SQL keyword or they
    82  // contain characters that must be escaped.
    83  // Enabling this option will prevent the optimized fastFormat routines from running.
    84  func (buf *TrackedBuffer) SetEscapeAllIdentifiers(enable bool) {
    85  	buf.fast = false
    86  	buf.escape = enable
    87  }
    88  
    89  // WriteNode function, initiates the writing of a single SQLNode tree by passing
    90  // through to Myprintf with a default format string
    91  func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
    92  	buf.Myprintf("%v", node)
    93  	return buf
    94  }
    95  
    96  // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
    97  // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
    98  // which case it adds tracking info for future substitutions.
    99  // It adds parens as needed to follow precedence rules when printing expressions.
   100  // To handle parens correctly for left associative binary operators,
   101  // use %l and %r to tell the TrackedBuffer which value is on the LHS and RHS
   102  //
   103  // The name must be something other than the usual Printf() to avoid "go vet"
   104  // warnings due to our custom format specifiers.
   105  // *** THIS METHOD SHOULD NOT BE USED FROM ast.go. USE astPrintf INSTEAD ***
   106  func (buf *TrackedBuffer) Myprintf(format string, values ...any) {
   107  	buf.astPrintf(nil, format, values...)
   108  }
   109  
   110  func (buf *TrackedBuffer) printExpr(currentExpr Expr, expr Expr, left bool) {
   111  	if precedenceFor(currentExpr) == Syntactic {
   112  		expr.formatFast(buf)
   113  	} else {
   114  		needParens := needParens(currentExpr, expr, left)
   115  		if needParens {
   116  			buf.WriteByte('(')
   117  		}
   118  		expr.formatFast(buf)
   119  		if needParens {
   120  			buf.WriteByte(')')
   121  		}
   122  	}
   123  }
   124  
   125  // astPrintf is for internal use by the ast structs
   126  func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values ...any) {
   127  	currentExpr, checkParens := currentNode.(Expr)
   128  	if checkParens {
   129  		// expressions that have Precedence Syntactic will never need parens
   130  		checkParens = precedenceFor(currentExpr) != Syntactic
   131  	}
   132  
   133  	end := len(format)
   134  	fieldnum := 0
   135  	for i := 0; i < end; {
   136  		lasti := i
   137  		for i < end && format[i] != '%' {
   138  			i++
   139  		}
   140  		if i > lasti {
   141  			_, _ = buf.literal(format[lasti:i])
   142  		}
   143  		if i >= end {
   144  			break
   145  		}
   146  		i++ // '%'
   147  
   148  		caseSensitive := false
   149  		if format[i] == '#' {
   150  			caseSensitive = true
   151  			i++
   152  		}
   153  
   154  		switch format[i] {
   155  		case 'c':
   156  			switch v := values[fieldnum].(type) {
   157  			case byte:
   158  				buf.WriteByte(v)
   159  			case rune:
   160  				buf.WriteRune(v)
   161  			default:
   162  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
   163  			}
   164  		case 's':
   165  			switch v := values[fieldnum].(type) {
   166  			case string:
   167  				if caseSensitive {
   168  					buf.WriteString(v)
   169  				} else {
   170  					_, _ = buf.literal(v)
   171  				}
   172  			default:
   173  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
   174  			}
   175  		case 'l', 'r', 'v':
   176  			left := format[i] != 'r'
   177  			value := values[fieldnum]
   178  			expr := getExpressionForParensEval(checkParens, value)
   179  
   180  			if expr == nil {
   181  				buf.formatter(value.(SQLNode))
   182  			} else {
   183  				needParens := needParens(currentExpr, expr, left)
   184  				if needParens {
   185  					buf.WriteByte('(')
   186  				}
   187  				buf.formatter(expr)
   188  				if needParens {
   189  					buf.WriteByte(')')
   190  				}
   191  			}
   192  		case 'd':
   193  			buf.WriteString(fmt.Sprintf("%d", values[fieldnum]))
   194  		case 'a':
   195  			buf.WriteArg("", values[fieldnum].(string))
   196  		default:
   197  			panic("unexpected")
   198  		}
   199  		fieldnum++
   200  		i++
   201  	}
   202  }
   203  
   204  func getExpressionForParensEval(checkParens bool, value any) Expr {
   205  	if checkParens {
   206  		expr, isExpr := value.(Expr)
   207  		if isExpr {
   208  			return expr
   209  		}
   210  	}
   211  	return nil
   212  }
   213  
   214  func (buf *TrackedBuffer) formatter(node SQLNode) {
   215  	switch {
   216  	case buf.fast:
   217  		node.formatFast(buf)
   218  	case buf.nodeFormatter != nil:
   219  		buf.nodeFormatter(buf, node)
   220  	default:
   221  		node.Format(buf)
   222  	}
   223  }
   224  
   225  // needParens says if we need a parenthesis
   226  // op is the operator we are printing
   227  // val is the value we are checking if we need parens around or not
   228  // left let's us know if the value is on the lhs or rhs of the operator
   229  func needParens(op, val Expr, left bool) bool {
   230  	// Values are atomic and never need parens
   231  	if IsValue(val) {
   232  		return false
   233  	}
   234  
   235  	if areBothISExpr(op, val) {
   236  		return true
   237  	}
   238  
   239  	opBinding := precedenceFor(op)
   240  	valBinding := precedenceFor(val)
   241  
   242  	if opBinding == Syntactic || valBinding == Syntactic {
   243  		return false
   244  	}
   245  
   246  	if left {
   247  		// for left associative operators, if the value is to the left of the operator,
   248  		// we only need parens if the order is higher for the value expression
   249  		return valBinding > opBinding
   250  	}
   251  
   252  	return valBinding >= opBinding
   253  }
   254  
   255  func areBothISExpr(op Expr, val Expr) bool {
   256  	_, isOpIS := op.(*IsExpr)
   257  	if isOpIS {
   258  		_, isValIS := val.(*IsExpr)
   259  		if isValIS {
   260  			// when using IS on an IS op, we need special handling
   261  			return true
   262  		}
   263  	}
   264  	return false
   265  }
   266  
   267  // WriteArg writes a value argument into the buffer along with
   268  // tracking information for future substitutions.
   269  func (buf *TrackedBuffer) WriteArg(prefix, arg string) {
   270  	buf.bindLocations = append(buf.bindLocations, bindLocation{
   271  		offset: buf.Len(),
   272  		length: len(prefix) + len(arg),
   273  	})
   274  	buf.WriteString(prefix)
   275  	buf.WriteString(arg)
   276  }
   277  
   278  // ParsedQuery returns a ParsedQuery that contains bind
   279  // locations for easy substitution.
   280  func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
   281  	return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
   282  }
   283  
   284  // HasBindVars returns true if the parsed query uses bind vars.
   285  func (buf *TrackedBuffer) HasBindVars() bool {
   286  	return len(buf.bindLocations) != 0
   287  }
   288  
   289  // BuildParsedQuery builds a ParsedQuery from the input.
   290  func BuildParsedQuery(in string, vars ...any) *ParsedQuery {
   291  	buf := NewTrackedBuffer(nil)
   292  	buf.Myprintf(in, vars...)
   293  	return buf.ParsedQuery()
   294  }
   295  
   296  // String returns a string representation of an SQLNode.
   297  func String(node SQLNode) string {
   298  	if node == nil {
   299  		return "<nil>"
   300  	}
   301  
   302  	buf := NewTrackedBuffer(nil)
   303  	node.formatFast(buf)
   304  	return buf.String()
   305  }
   306  
   307  // CanonicalString returns a canonical string representation of an SQLNode where all identifiers
   308  // are always escaped and all SQL syntax is in uppercase. This matches the canonical output from MySQL.
   309  func CanonicalString(node SQLNode) string {
   310  	if node == nil {
   311  		return "" // do not return '<nil>', which is Go syntax.
   312  	}
   313  
   314  	buf := NewTrackedBuffer(nil)
   315  	buf.SetUpperCase(true)
   316  	buf.SetEscapeAllIdentifiers(true)
   317  	node.Format(buf)
   318  	return buf.String()
   319  }