github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/normalizer.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "fmt" 21 "strconv" 22 23 "github.com/vedadiyan/sqlparser/pkg/sqltypes" 24 25 querypb "github.com/vedadiyan/sqlparser/pkg/query" 26 ) 27 28 // BindVars is a set of reserved bind variables from a SQL statement 29 type BindVars map[string]struct{} 30 31 // Normalize changes the statement to use bind values, and 32 // updates the bind vars to those values. The supplied prefix 33 // is used to generate the bind var names. The function ensures 34 // that there are no collisions with existing bind vars. 35 // Within Select constructs, bind vars are deduped. This allows 36 // us to identify vindex equality. Otherwise, every value is 37 // treated as distinct. 38 func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error { 39 nz := newNormalizer(reserved, bindVars) 40 _ = SafeRewrite(stmt, nz.walkStatementDown, nz.walkStatementUp) 41 return nz.err 42 } 43 44 type normalizer struct { 45 bindVars map[string]*querypb.BindVariable 46 reserved *ReservedVars 47 vals map[string]string 48 err error 49 inDerived bool 50 } 51 52 func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer { 53 return &normalizer{ 54 bindVars: bindVars, 55 reserved: reserved, 56 vals: make(map[string]string), 57 } 58 } 59 60 // walkStatementUp is one half of the top level walk function. 61 func (nz *normalizer) walkStatementUp(cursor *Cursor) bool { 62 if nz.err != nil { 63 return false 64 } 65 node, isLiteral := cursor.Node().(*Literal) 66 if !isLiteral { 67 return true 68 } 69 nz.convertLiteral(node, cursor) 70 return nz.err == nil // only continue if we haven't found any errors 71 } 72 73 // walkStatementDown is the top level walk function. 74 // If it encounters a Select, it switches to a mode 75 // where variables are deduped. 76 func (nz *normalizer) walkStatementDown(node, parent SQLNode) bool { 77 switch node := node.(type) { 78 // no need to normalize the statement types 79 case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead: 80 return false 81 case *Select: 82 _, isDerived := parent.(*DerivedTable) 83 var tmp bool 84 tmp, nz.inDerived = nz.inDerived, isDerived 85 _ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect) 86 // Don't continue 87 nz.inDerived = tmp 88 return false 89 case *ComparisonExpr: 90 nz.convertComparison(node) 91 case *UpdateExpr: 92 nz.convertUpdateExpr(node) 93 case *ColName, TableName: 94 // Common node types that never contain Literal or ListArgs but create a lot of object 95 // allocations. 96 return false 97 case *ConvertType: // we should not rewrite the type description 98 return false 99 } 100 return nz.err == nil // only continue if we haven't found any errors 101 } 102 103 // walkDownSelect normalizes the AST in Select mode. 104 func (nz *normalizer) walkDownSelect(node, parent SQLNode) bool { 105 switch node := node.(type) { 106 case *Select: 107 _, isDerived := parent.(*DerivedTable) 108 if !isDerived { 109 return true 110 } 111 var tmp bool 112 tmp, nz.inDerived = nz.inDerived, isDerived 113 // initiating a new AST walk here means that we might change something while walking down on the tree, 114 // but since we are only changing literals, we can be safe that we are not changing the SELECT struct, 115 // only something much further down, and that should be safe 116 _ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect) 117 // Don't continue 118 nz.inDerived = tmp 119 return false 120 case SelectExprs: 121 return !nz.inDerived 122 case *ComparisonExpr: 123 nz.convertComparison(node) 124 case *FramePoint: 125 // do not make a bind var for rows and range 126 return false 127 case *ColName, TableName: 128 // Common node types that never contain Literals or ListArgs but create a lot of object 129 // allocations. 130 return false 131 case *ConvertType: 132 // we should not rewrite the type description 133 return false 134 } 135 return nz.err == nil // only continue if we haven't found any errors 136 } 137 138 // walkUpSelect normalizes the Literals in Select mode. 139 func (nz *normalizer) walkUpSelect(cursor *Cursor) bool { 140 if nz.err != nil { 141 return false 142 } 143 node, isLiteral := cursor.Node().(*Literal) 144 if !isLiteral { 145 return true 146 } 147 parent := cursor.Parent() 148 switch parent.(type) { 149 case *Order, GroupBy: 150 return false 151 case *Limit: 152 nz.convertLiteral(node, cursor) 153 default: 154 nz.convertLiteralDedup(node, cursor) 155 } 156 return nz.err == nil // only continue if we haven't found any errors 157 } 158 159 func validateLiteral(node *Literal) (err error) { 160 switch node.Type { 161 case DateVal: 162 _, err = ParseDate(node.Val) 163 case TimeVal: 164 _, err = ParseTime(node.Val) 165 case TimestampVal: 166 _, err = ParseDateTime(node.Val) 167 } 168 return err 169 } 170 171 func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { 172 err := validateLiteral(node) 173 if err != nil { 174 nz.err = err 175 } 176 177 // If value is too long, don't dedup. 178 // Such values are most likely not for vindexes. 179 // We save a lot of CPU because we avoid building 180 // the key for them. 181 if len(node.Val) > 256 { 182 nz.convertLiteral(node, cursor) 183 return 184 } 185 186 // Make the bindvar 187 bval := SQLToBindvar(node) 188 if bval == nil { 189 return 190 } 191 192 // Check if there's a bindvar for that value already. 193 key := keyFor(bval, node) 194 bvname, ok := nz.vals[key] 195 if !ok { 196 // If there's no such bindvar, make a new one. 197 bvname = nz.reserved.nextUnusedVar() 198 nz.vals[key] = bvname 199 nz.bindVars[bvname] = bval 200 } 201 202 // Modify the AST node to a bindvar. 203 cursor.Replace(NewArgument(bvname)) 204 } 205 206 func keyFor(bval *querypb.BindVariable, lit *Literal) string { 207 if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar { 208 return lit.Val 209 } 210 211 // Prefixing strings with "'" ensures that a string 212 // and number that have the same representation don't 213 // collide. 214 return "'" + lit.Val 215 216 } 217 218 // convertLiteral converts an Literal without the dedup. 219 func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { 220 err := validateLiteral(node) 221 if err != nil { 222 nz.err = err 223 } 224 225 bval := SQLToBindvar(node) 226 if bval == nil { 227 return 228 } 229 230 bvname := nz.reserved.nextUnusedVar() 231 nz.bindVars[bvname] = bval 232 233 cursor.Replace(NewArgument(bvname)) 234 } 235 236 // convertComparison attempts to convert IN clauses to 237 // use the list bind var construct. If it fails, it returns 238 // with no change made. The walk function will then continue 239 // and iterate on converting each individual value into separate 240 // bind vars. 241 func (nz *normalizer) convertComparison(node *ComparisonExpr) { 242 switch node.Operator { 243 case InOp, NotInOp: 244 nz.rewriteInComparisons(node) 245 default: 246 nz.rewriteOtherComparisons(node) 247 } 248 } 249 250 func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) { 251 newR := nz.parameterize(node.Left, node.Right) 252 if newR != nil { 253 node.Right = newR 254 } 255 } 256 257 func (nz *normalizer) parameterize(left, right Expr) Expr { 258 col, ok := left.(*ColName) 259 if !ok { 260 return nil 261 } 262 lit, ok := right.(*Literal) 263 if !ok { 264 return nil 265 } 266 err := validateLiteral(lit) 267 if err != nil { 268 nz.err = err 269 return nil 270 } 271 272 bval := SQLToBindvar(lit) 273 if bval == nil { 274 return nil 275 } 276 key := keyFor(bval, lit) 277 bvname := nz.decideBindVarName(key, lit, col, bval) 278 return Argument(bvname) 279 } 280 281 func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string { 282 if len(lit.Val) <= 256 { 283 // first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name 284 bvname, ok := nz.vals[key] 285 if ok { 286 return bvname 287 } 288 } 289 290 // If there's no such bindvar, or we have a big value, make a new one. 291 // Big values are most likely not for vindexes. 292 // We save a lot of CPU because we avoid building 293 bvname := nz.reserved.ReserveColName(col) 294 nz.vals[key] = bvname 295 nz.bindVars[bvname] = bval 296 297 return bvname 298 } 299 300 func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) { 301 tupleVals, ok := node.Right.(ValTuple) 302 if !ok { 303 return 304 } 305 306 // The RHS is a tuple of values. 307 // Make a list bindvar. 308 bvals := &querypb.BindVariable{ 309 Type: querypb.Type_TUPLE, 310 } 311 for _, val := range tupleVals { 312 bval := SQLToBindvar(val) 313 if bval == nil { 314 return 315 } 316 bvals.Values = append(bvals.Values, &querypb.Value{ 317 Type: bval.Type, 318 Value: bval.Value, 319 }) 320 } 321 bvname := nz.reserved.nextUnusedVar() 322 nz.bindVars[bvname] = bvals 323 // Modify RHS to be a list bindvar. 324 node.Right = ListArg(bvname) 325 } 326 327 func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) { 328 newR := nz.parameterize(node.Name, node.Expr) 329 if newR != nil { 330 node.Expr = newR 331 } 332 } 333 334 func SQLToBindvar(node SQLNode) *querypb.BindVariable { 335 if node, ok := node.(*Literal); ok { 336 var v sqltypes.Value 337 var err error 338 switch node.Type { 339 case StrVal: 340 v, err = sqltypes.NewValue(sqltypes.VarChar, node.Bytes()) 341 case IntVal: 342 v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes()) 343 case FloatVal: 344 v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes()) 345 case DecimalVal: 346 v, err = sqltypes.NewValue(sqltypes.Decimal, node.Bytes()) 347 case HexNum: 348 v, err = sqltypes.NewValue(sqltypes.HexNum, node.Bytes()) 349 case HexVal: 350 // We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser 351 // We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL 352 var vbytes []byte 353 vbytes, err = node.encodeHexOrBitValToMySQLQueryFormat() 354 if err != nil { 355 return nil 356 } 357 v, err = sqltypes.NewValue(sqltypes.HexVal, vbytes) 358 case BitVal: 359 // Convert bit value to hex number in parameterized query format 360 var ui uint64 361 ui, err = strconv.ParseUint(string(node.Bytes()), 2, 64) 362 if err != nil { 363 return nil 364 } 365 v, err = sqltypes.NewValue(sqltypes.HexNum, []byte(fmt.Sprintf("0x%x", ui))) 366 case DateVal: 367 v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes()) 368 case TimeVal: 369 v, err = sqltypes.NewValue(sqltypes.Time, node.Bytes()) 370 case TimestampVal: 371 // This is actually a DATETIME MySQL type. The timestamp literal 372 // syntax is part of the SQL standard and MySQL DATETIME matches 373 // the type best. 374 v, err = sqltypes.NewValue(sqltypes.Datetime, node.Bytes()) 375 default: 376 return nil 377 } 378 if err != nil { 379 return nil 380 } 381 return sqltypes.ValueBindVariable(v) 382 } 383 return nil 384 } 385 386 // GetBindvars returns a map of the bind vars referenced in the statement. 387 func GetBindvars(stmt Statement) map[string]struct{} { 388 bindvars := make(map[string]struct{}) 389 _ = Walk(func(node SQLNode) (kontinue bool, err error) { 390 switch node := node.(type) { 391 case *ColName, TableName: 392 // Common node types that never contain expressions but create a lot of object 393 // allocations. 394 return false, nil 395 case Argument: 396 bindvars[string(node)] = struct{}{} 397 case ListArg: 398 bindvars[string(node)] = struct{}{} 399 } 400 return true, nil 401 }, stmt) 402 return bindvars 403 }