github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/tracked_buffer.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "fmt" 21 "strings" 22 ) 23 24 // NodeFormatter defines the signature of a custom node formatter 25 // function that can be given to TrackedBuffer for code generation. 26 type NodeFormatter func(buf *TrackedBuffer, node SQLNode) 27 28 // TrackedBuffer is used to rebuild a query from the ast. 29 // bindLocations keeps track of locations in the buffer that 30 // use bind variables for efficient future substitutions. 31 // nodeFormatter is the formatting function the buffer will 32 // use to format a node. By default(nil), it's FormatNode. 33 // But you can supply a different formatting function if you 34 // want to generate a query that's different from the default. 35 type TrackedBuffer struct { 36 *strings.Builder 37 bindLocations []bindLocation 38 nodeFormatter NodeFormatter 39 literal func(string) (int, error) 40 escape bool 41 fast bool 42 } 43 44 // NewTrackedBuffer creates a new TrackedBuffer. 45 func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer { 46 buf := &TrackedBuffer{ 47 Builder: new(strings.Builder), 48 nodeFormatter: nodeFormatter, 49 } 50 buf.literal = buf.WriteString 51 buf.fast = nodeFormatter == nil 52 return buf 53 } 54 55 func (buf *TrackedBuffer) writeStringUpperCase(lit string) (int, error) { 56 // Upcasing is performed for ASCII only, following MySQL's behavior 57 buf.Grow(len(lit)) 58 for i := 0; i < len(lit); i++ { 59 c := lit[i] 60 if 'a' <= c && c <= 'z' { 61 c -= 'a' - 'A' 62 } 63 buf.WriteByte(c) 64 } 65 return len(lit), nil 66 } 67 68 // SetUpperCase sets whether all SQL statements formatted by this TrackedBuffer will be normalized into 69 // uppercase. By default, formatted statements are normalized into lowercase. 70 // Enabling this option will prevent the optimized fastFormat routines from running. 71 func (buf *TrackedBuffer) SetUpperCase(enable bool) { 72 buf.fast = false 73 if enable { 74 buf.literal = buf.writeStringUpperCase 75 } else { 76 buf.literal = buf.WriteString 77 } 78 } 79 80 // SetEscapeAllIdentifiers sets whether ALL identifiers in the serialized SQL query should be quoted 81 // and escaped. By default, identifiers are only escaped if they match the name of a SQL keyword or they 82 // contain characters that must be escaped. 83 // Enabling this option will prevent the optimized fastFormat routines from running. 84 func (buf *TrackedBuffer) SetEscapeAllIdentifiers(enable bool) { 85 buf.fast = false 86 buf.escape = enable 87 } 88 89 // WriteNode function, initiates the writing of a single SQLNode tree by passing 90 // through to Myprintf with a default format string 91 func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer { 92 buf.Myprintf("%v", node) 93 return buf 94 } 95 96 // Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v), 97 // Node.Value(%s) and string(%s). It also allows a %a for a value argument, in 98 // which case it adds tracking info for future substitutions. 99 // It adds parens as needed to follow precedence rules when printing expressions. 100 // To handle parens correctly for left associative binary operators, 101 // use %l and %r to tell the TrackedBuffer which value is on the LHS and RHS 102 // 103 // The name must be something other than the usual Printf() to avoid "go vet" 104 // warnings due to our custom format specifiers. 105 // *** THIS METHOD SHOULD NOT BE USED FROM ast.go. USE astPrintf INSTEAD *** 106 func (buf *TrackedBuffer) Myprintf(format string, values ...any) { 107 buf.astPrintf(nil, format, values...) 108 } 109 110 func (buf *TrackedBuffer) printExpr(currentExpr Expr, expr Expr, left bool) { 111 if precedenceFor(currentExpr) == Syntactic { 112 expr.formatFast(buf) 113 } else { 114 needParens := needParens(currentExpr, expr, left) 115 if needParens { 116 buf.WriteByte('(') 117 } 118 expr.formatFast(buf) 119 if needParens { 120 buf.WriteByte(')') 121 } 122 } 123 } 124 125 // astPrintf is for internal use by the ast structs 126 func (buf *TrackedBuffer) astPrintf(currentNode SQLNode, format string, values ...any) { 127 currentExpr, checkParens := currentNode.(Expr) 128 if checkParens { 129 // expressions that have Precedence Syntactic will never need parens 130 checkParens = precedenceFor(currentExpr) != Syntactic 131 } 132 133 end := len(format) 134 fieldnum := 0 135 for i := 0; i < end; { 136 lasti := i 137 for i < end && format[i] != '%' { 138 i++ 139 } 140 if i > lasti { 141 _, _ = buf.literal(format[lasti:i]) 142 } 143 if i >= end { 144 break 145 } 146 i++ // '%' 147 148 caseSensitive := false 149 if format[i] == '#' { 150 caseSensitive = true 151 i++ 152 } 153 154 switch format[i] { 155 case 'c': 156 switch v := values[fieldnum].(type) { 157 case byte: 158 buf.WriteByte(v) 159 case rune: 160 buf.WriteRune(v) 161 default: 162 panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 163 } 164 case 's': 165 switch v := values[fieldnum].(type) { 166 case string: 167 if caseSensitive { 168 buf.WriteString(v) 169 } else { 170 _, _ = buf.literal(v) 171 } 172 default: 173 panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v)) 174 } 175 case 'l', 'r', 'v': 176 left := format[i] != 'r' 177 value := values[fieldnum] 178 expr := getExpressionForParensEval(checkParens, value) 179 180 if expr == nil { 181 buf.formatter(value.(SQLNode)) 182 } else { 183 needParens := needParens(currentExpr, expr, left) 184 if needParens { 185 buf.WriteByte('(') 186 } 187 buf.formatter(expr) 188 if needParens { 189 buf.WriteByte(')') 190 } 191 } 192 case 'd': 193 buf.WriteString(fmt.Sprintf("%d", values[fieldnum])) 194 case 'a': 195 buf.WriteArg("", values[fieldnum].(string)) 196 default: 197 panic("unexpected") 198 } 199 fieldnum++ 200 i++ 201 } 202 } 203 204 func getExpressionForParensEval(checkParens bool, value any) Expr { 205 if checkParens { 206 expr, isExpr := value.(Expr) 207 if isExpr { 208 return expr 209 } 210 } 211 return nil 212 } 213 214 func (buf *TrackedBuffer) formatter(node SQLNode) { 215 switch { 216 case buf.fast: 217 node.formatFast(buf) 218 case buf.nodeFormatter != nil: 219 buf.nodeFormatter(buf, node) 220 default: 221 node.Format(buf) 222 } 223 } 224 225 // needParens says if we need a parenthesis 226 // op is the operator we are printing 227 // val is the value we are checking if we need parens around or not 228 // left let's us know if the value is on the lhs or rhs of the operator 229 func needParens(op, val Expr, left bool) bool { 230 // Values are atomic and never need parens 231 if IsValue(val) { 232 return false 233 } 234 235 if areBothISExpr(op, val) { 236 return true 237 } 238 239 opBinding := precedenceFor(op) 240 valBinding := precedenceFor(val) 241 242 if opBinding == Syntactic || valBinding == Syntactic { 243 return false 244 } 245 246 if left { 247 // for left associative operators, if the value is to the left of the operator, 248 // we only need parens if the order is higher for the value expression 249 return valBinding > opBinding 250 } 251 252 return valBinding >= opBinding 253 } 254 255 func areBothISExpr(op Expr, val Expr) bool { 256 _, isOpIS := op.(*IsExpr) 257 if isOpIS { 258 _, isValIS := val.(*IsExpr) 259 if isValIS { 260 // when using IS on an IS op, we need special handling 261 return true 262 } 263 } 264 return false 265 } 266 267 // WriteArg writes a value argument into the buffer along with 268 // tracking information for future substitutions. 269 func (buf *TrackedBuffer) WriteArg(prefix, arg string) { 270 buf.bindLocations = append(buf.bindLocations, bindLocation{ 271 offset: buf.Len(), 272 length: len(prefix) + len(arg), 273 }) 274 buf.WriteString(prefix) 275 buf.WriteString(arg) 276 } 277 278 // ParsedQuery returns a ParsedQuery that contains bind 279 // locations for easy substitution. 280 func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery { 281 return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations} 282 } 283 284 // HasBindVars returns true if the parsed query uses bind vars. 285 func (buf *TrackedBuffer) HasBindVars() bool { 286 return len(buf.bindLocations) != 0 287 } 288 289 // BuildParsedQuery builds a ParsedQuery from the input. 290 func BuildParsedQuery(in string, vars ...any) *ParsedQuery { 291 buf := NewTrackedBuffer(nil) 292 buf.Myprintf(in, vars...) 293 return buf.ParsedQuery() 294 } 295 296 // String returns a string representation of an SQLNode. 297 func String(node SQLNode) string { 298 if node == nil { 299 return "<nil>" 300 } 301 302 buf := NewTrackedBuffer(nil) 303 node.formatFast(buf) 304 return buf.String() 305 } 306 307 // CanonicalString returns a canonical string representation of an SQLNode where all identifiers 308 // are always escaped and all SQL syntax is in uppercase. This matches the canonical output from MySQL. 309 func CanonicalString(node SQLNode) string { 310 if node == nil { 311 return "" // do not return '<nil>', which is Go syntax. 312 } 313 314 buf := NewTrackedBuffer(nil) 315 buf.SetUpperCase(true) 316 buf.SetEscapeAllIdentifiers(true) 317 node.Format(buf) 318 return buf.String() 319 }