github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/tracked_buffer.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"strings"
    22  )
    23  
    24  // NodeFormatter defines the signature of a custom node formatter
    25  // function that can be given to TrackedBuffer for code generation.
    26  type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
    27  
// TrackedBuffer is used to rebuild a query from the AST.
//
// Besides the query text itself (kept in the embedded strings.Builder), it
// records where bind variables were written so that the ParsedQuery built
// from it can substitute values at those locations efficiently later.
type TrackedBuffer struct {
	// Builder holds the query text written so far.
	*strings.Builder
	// bindLocations records the offset and length of every argument written
	// through WriteArg, in write order.
	bindLocations []bindLocation
	// nodeFormatter is the function used to render nodes. When nil (the
	// default), each node's own formatFast method is used instead. Supply a
	// different function to generate a query that differs from the default.
	// Note that printExpr always calls formatFast and ignores this field.
	nodeFormatter NodeFormatter
}
    40  
    41  // NewTrackedBuffer creates a new TrackedBuffer.
    42  func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
    43  	return &TrackedBuffer{
    44  		Builder:       new(strings.Builder),
    45  		nodeFormatter: nodeFormatter,
    46  	}
    47  }
    48  
    49  // WriteNode function, initiates the writing of a single SQLNode tree by passing
    50  // through to Myprintf with a default format string
    51  func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
    52  	buf.Myprintf("%v", node)
    53  	return buf
    54  }
    55  
    56  // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
    57  // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
    58  // which case it adds tracking info for future substitutions.
    59  // It adds parens as needed to follow precedence rules when printing expressions.
    60  // To handle parens correctly for left associative binary operators,
    61  // use %l and %r to tell the TrackedBuffer which value is on the LHS and RHS
    62  //
    63  // The name must be something other than the usual Printf() to avoid "go vet"
    64  // warnings due to our custom format specifiers.
    65  // *** THIS METHOD SHOULD NOT BE USED FROM ast.go. USE astPrintf INSTEAD ***
    66  func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
    67  	buf.astPrintf(nil, format, values...)
    68  }
    69  
    70  func (buf *TrackedBuffer) printExpr(currentExpr Expr, expr Expr, left bool) {
    71  	if precedenceFor(currentExpr) == Syntactic {
    72  		expr.formatFast(buf)
    73  	} else {
    74  		needParens := needParens(currentExpr, expr, left)
    75  		if needParens {
    76  			buf.WriteByte('(')
    77  		}
    78  		expr.formatFast(buf)
    79  		if needParens {
    80  			buf.WriteByte(')')
    81  		}
    82  	}
    83  }
    84  
    85  // astPrintf is for internal use by the ast structs
    86  func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values ...interface{}) {
    87  	currentExpr, checkParens := currentNode.(Expr)
    88  	if checkParens {
    89  		// expressions that have Precedence Syntactic will never need parens
    90  		checkParens = precedenceFor(currentExpr) != Syntactic
    91  	}
    92  
    93  	end := len(format)
    94  	fieldnum := 0
    95  	for i := 0; i < end; {
    96  		lasti := i
    97  		for i < end && format[i] != '%' {
    98  			i++
    99  		}
   100  		if i > lasti {
   101  			buf.WriteString(format[lasti:i])
   102  		}
   103  		if i >= end {
   104  			break
   105  		}
   106  		i++ // '%'
   107  		token := format[i]
   108  		switch token {
   109  		case 'c':
   110  			switch v := values[fieldnum].(type) {
   111  			case byte:
   112  				buf.WriteByte(v)
   113  			case rune:
   114  				buf.WriteRune(v)
   115  			default:
   116  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
   117  			}
   118  		case 's':
   119  			switch v := values[fieldnum].(type) {
   120  			case []byte:
   121  				buf.Write(v)
   122  			case string:
   123  				buf.WriteString(v)
   124  			default:
   125  				panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
   126  			}
   127  		case 'l', 'r', 'v':
   128  			left := token != 'r'
   129  			value := values[fieldnum]
   130  			expr := getExpressionForParensEval(checkParens, value)
   131  
   132  			if expr == nil {
   133  				buf.formatter(value.(SQLNode))
   134  			} else {
   135  				needParens := needParens(currentExpr, expr, left)
   136  				if needParens {
   137  					buf.WriteByte('(')
   138  				}
   139  				buf.formatter(expr)
   140  				if needParens {
   141  					buf.WriteByte(')')
   142  				}
   143  			}
   144  		case 'a':
   145  			buf.WriteArg("", values[fieldnum].(string))
   146  		default:
   147  			panic("unexpected")
   148  		}
   149  		fieldnum++
   150  		i++
   151  	}
   152  }
   153  
   154  func getExpressionForParensEval(checkParens bool, value interface{}) Expr {
   155  	if checkParens {
   156  		expr, isExpr := value.(Expr)
   157  		if isExpr {
   158  			return expr
   159  		}
   160  	}
   161  	return nil
   162  }
   163  
   164  func (buf *TrackedBuffer) formatter(node SQLNode) {
   165  	if buf.nodeFormatter == nil {
   166  		node.formatFast(buf)
   167  	} else {
   168  		buf.nodeFormatter(buf, node)
   169  	}
   170  }
   171  
   172  // needParens says if we need a parenthesis
   173  // op is the operator we are printing
   174  // val is the value we are checking if we need parens around or not
   175  // left let's us know if the value is on the lhs or rhs of the operator
   176  func needParens(op, val Expr, left bool) bool {
   177  	// Values are atomic and never need parens
   178  	if IsValue(val) {
   179  		return false
   180  	}
   181  
   182  	if areBothISExpr(op, val) {
   183  		return true
   184  	}
   185  
   186  	opBinding := precedenceFor(op)
   187  	valBinding := precedenceFor(val)
   188  
   189  	if opBinding == Syntactic || valBinding == Syntactic {
   190  		return false
   191  	}
   192  
   193  	if left {
   194  		// for left associative operators, if the value is to the left of the operator,
   195  		// we only need parens if the order is higher for the value expression
   196  		return valBinding > opBinding
   197  	}
   198  
   199  	return valBinding >= opBinding
   200  }
   201  
   202  func areBothISExpr(op Expr, val Expr) bool {
   203  	_, isOpIS := op.(*IsExpr)
   204  	if isOpIS {
   205  		_, isValIS := val.(*IsExpr)
   206  		if isValIS {
   207  			// when using IS on an IS op, we need special handling
   208  			return true
   209  		}
   210  	}
   211  	return false
   212  }
   213  
   214  // WriteArg writes a value argument into the buffer along with
   215  // tracking information for future substitutions.
   216  func (buf *TrackedBuffer) WriteArg(prefix, arg string) {
   217  	buf.bindLocations = append(buf.bindLocations, bindLocation{
   218  		offset: buf.Len(),
   219  		length: len(prefix) + len(arg),
   220  	})
   221  	buf.WriteString(prefix)
   222  	buf.WriteString(arg)
   223  }
   224  
   225  // ParsedQuery returns a ParsedQuery that contains bind
   226  // locations for easy substitution.
   227  func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
   228  	return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
   229  }
   230  
   231  // HasBindVars returns true if the parsed query uses bind vars.
   232  func (buf *TrackedBuffer) HasBindVars() bool {
   233  	return len(buf.bindLocations) != 0
   234  }
   235  
   236  // BuildParsedQuery builds a ParsedQuery from the input.
   237  func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
   238  	buf := NewTrackedBuffer(nil)
   239  	buf.Myprintf(in, vars...)
   240  	return buf.ParsedQuery()
   241  }