github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/normalizer.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "github.com/team-ide/go-dialect/vitess/sqltypes" 21 22 querypb "github.com/team-ide/go-dialect/vitess/query" 23 ) 24 25 // BindVars is a set of reserved bind variables from a SQL statement 26 type BindVars map[string]struct{} 27 28 // Normalize changes the statement to use bind values, and 29 // updates the bind vars to those values. The supplied prefix 30 // is used to generate the bind var names. The function ensures 31 // that there are no collisions with existing bind vars. 32 // Within Select constructs, bind vars are deduped. This allows 33 // us to identify vindex equality. Otherwise, every value is 34 // treated as distinct. 35 func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error { 36 nz := newNormalizer(reserved, bindVars) 37 _ = Rewrite(stmt, nz.WalkStatement, nil) 38 return nz.err 39 } 40 41 type normalizer struct { 42 bindVars map[string]*querypb.BindVariable 43 reserved *ReservedVars 44 vals map[string]string 45 err error 46 } 47 48 func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer { 49 return &normalizer{ 50 bindVars: bindVars, 51 reserved: reserved, 52 vals: make(map[string]string), 53 } 54 } 55 56 // WalkStatement is the top level walk function. 57 // If it encounters a Select, it switches to a mode 58 // where variables are deduped. 59 func (nz *normalizer) WalkStatement(cursor *Cursor) bool { 60 switch node := cursor.Node().(type) { 61 // no need to normalize the statement types 62 case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, *SetTransaction, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead: 63 return false 64 case *Select: 65 _ = Rewrite(node, nz.WalkSelect, nil) 66 // Don't continue 67 return false 68 case *Literal: 69 nz.convertLiteral(node, cursor) 70 case *ComparisonExpr: 71 nz.convertComparison(node) 72 case *ColName, TableName: 73 // Common node types that never contain Literal or ListArgs but create a lot of object 74 // allocations. 75 return false 76 case *ConvertType: // we should not rewrite the type description 77 return false 78 } 79 return nz.err == nil // only continue if we haven't found any errors 80 } 81 82 // WalkSelect normalizes the AST in Select mode. 83 func (nz *normalizer) WalkSelect(cursor *Cursor) bool { 84 switch node := cursor.Node().(type) { 85 case *Literal: 86 nz.convertLiteralDedup(node, cursor) 87 case *ComparisonExpr: 88 nz.convertComparison(node) 89 case *ColName, TableName: 90 // Common node types that never contain Literals or ListArgs but create a lot of object 91 // allocations. 92 return false 93 case OrderBy, GroupBy: 94 // do not make a bind var for order by column_position 95 return false 96 case *ConvertType: 97 // we should not rewrite the type description 98 return false 99 } 100 return nz.err == nil // only continue if we haven't found any errors 101 } 102 103 func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { 104 // If value is too long, don't dedup. 105 // Such values are most likely not for vindexes. 106 // We save a lot of CPU because we avoid building 107 // the key for them. 108 if len(node.Val) > 256 { 109 nz.convertLiteral(node, cursor) 110 return 111 } 112 113 // Make the bindvar 114 bval := nz.sqlToBindvar(node) 115 if bval == nil { 116 return 117 } 118 119 // Check if there's a bindvar for that value already. 120 var key string 121 if bval.Type == sqltypes.VarBinary || bval.Type == sqltypes.VarChar { 122 // Prefixing strings with "'" ensures that a string 123 // and number that have the same representation don't 124 // collide. 125 key = "'" + node.Val 126 } else { 127 key = node.Val 128 } 129 bvname, ok := nz.vals[key] 130 if !ok { 131 // If there's no such bindvar, make a new one. 132 bvname = nz.reserved.nextUnusedVar() 133 nz.vals[key] = bvname 134 nz.bindVars[bvname] = bval 135 } 136 137 // Modify the AST node to a bindvar. 138 cursor.Replace(NewArgument(bvname)) 139 } 140 141 // convertLiteral converts an Literal without the dedup. 142 func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { 143 bval := nz.sqlToBindvar(node) 144 if bval == nil { 145 return 146 } 147 148 bvname := nz.reserved.nextUnusedVar() 149 nz.bindVars[bvname] = bval 150 151 cursor.Replace(NewArgument(bvname)) 152 } 153 154 // convertComparison attempts to convert IN clauses to 155 // use the list bind var construct. If it fails, it returns 156 // with no change made. The walk function will then continue 157 // and iterate on converting each individual value into separate 158 // bind vars. 159 func (nz *normalizer) convertComparison(node *ComparisonExpr) { 160 if node.Operator != InOp && node.Operator != NotInOp { 161 return 162 } 163 tupleVals, ok := node.Right.(ValTuple) 164 if !ok { 165 return 166 } 167 // The RHS is a tuple of values. 168 // Make a list bindvar. 169 bvals := &querypb.BindVariable{ 170 Type: querypb.Type_TUPLE, 171 } 172 for _, val := range tupleVals { 173 bval := nz.sqlToBindvar(val) 174 if bval == nil { 175 return 176 } 177 bvals.Values = append(bvals.Values, &querypb.Value{ 178 Type: bval.Type, 179 Value: bval.Value, 180 }) 181 } 182 bvname := nz.reserved.nextUnusedVar() 183 nz.bindVars[bvname] = bvals 184 // Modify RHS to be a list bindvar. 185 node.Right = ListArg(bvname) 186 } 187 188 func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable { 189 if node, ok := node.(*Literal); ok { 190 var v sqltypes.Value 191 var err error 192 switch node.Type { 193 case StrVal: 194 v, err = sqltypes.NewValue(sqltypes.VarChar, node.Bytes()) 195 case IntVal: 196 v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes()) 197 case FloatVal: 198 v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes()) 199 case DecimalVal: 200 v, err = sqltypes.NewValue(sqltypes.Decimal, node.Bytes()) 201 case HexNum: 202 v, err = sqltypes.NewValue(sqltypes.HexNum, node.Bytes()) 203 case HexVal: 204 // We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser 205 // We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL 206 var vbytes []byte 207 vbytes, err = node.encodeHexValToMySQLQueryFormat() 208 if err != nil { 209 return nil 210 } 211 v, err = sqltypes.NewValue(sqltypes.HexVal, vbytes) 212 default: 213 return nil 214 } 215 if err != nil { 216 return nil 217 } 218 return sqltypes.ValueBindVariable(v) 219 } 220 return nil 221 } 222 223 // GetBindvars returns a map of the bind vars referenced in the statement. 224 func GetBindvars(stmt Statement) map[string]struct{} { 225 bindvars := make(map[string]struct{}) 226 _ = Walk(func(node SQLNode) (kontinue bool, err error) { 227 switch node := node.(type) { 228 case *ColName, TableName: 229 // Common node types that never contain expressions but create a lot of object 230 // allocations. 231 return false, nil 232 case Argument: 233 bindvars[string(node)] = struct{}{} 234 case ListArg: 235 bindvars[string(node)] = struct{}{} 236 } 237 return true, nil 238 }, stmt) 239 return bindvars 240 }