vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/expr.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package planbuilder 18 19 import ( 20 "bytes" 21 "fmt" 22 23 "vitess.io/vitess/go/vt/vterrors" 24 25 "vitess.io/vitess/go/vt/sqlparser" 26 "vitess.io/vitess/go/vt/vtgate/engine" 27 "vitess.io/vitess/go/vt/vtgate/vindexes" 28 ) 29 30 type subqueryInfo struct { 31 ast *sqlparser.Subquery 32 plan logicalPlan 33 origin logicalPlan 34 } 35 36 // findOrigin identifies the right-most origin referenced by expr. In situations where 37 // the expression references columns from multiple origins, the expression will be 38 // pushed to the right-most origin, and the executor will use the results of 39 // the previous origins to feed the necessary values to the primitives on the right. 40 // 41 // If the expression contains a subquery, the right-most origin identification 42 // also follows the same rules of a normal expression. This is achieved by 43 // looking at the Externs field of its symbol table that contains the list of 44 // external references. 45 // 46 // Once the target origin is identified, we have to verify that the subquery's 47 // route can be merged with it. If it cannot, we fail the query. This is because 48 // we don't have the ability to wire up subqueries through expression evaluation 49 // primitives. Consequently, if the plan for a subquery comes out as a Join, 50 // we can immediately error out. 51 // 52 // Since findOrigin can itself be called from within a subquery, it has to assume 53 // that some of the external references may actually be pointing to an outer 54 // query. The isLocal response from the symtab is used to make sure that we 55 // only analyze symbols that point to the current symtab. 56 // 57 // If an expression has no references to the current query, then the left-most 58 // origin is chosen as the default. 59 func (pb *primitiveBuilder) findOrigin(expr sqlparser.Expr, reservedVars *sqlparser.ReservedVars) (pullouts []*pulloutSubquery, origin logicalPlan, pushExpr sqlparser.Expr, err error) { 60 // highestOrigin tracks the highest origin referenced by the expression. 61 // Default is the first. 62 highestOrigin := first(pb.plan) 63 64 // subqueries tracks the list of subqueries encountered. 65 var subqueries []subqueryInfo 66 67 // constructsMap tracks the sub-construct in which a subquery 68 // occurred. The construct type decides on how the query gets 69 // pulled out. 70 constructsMap := make(map[*sqlparser.Subquery]sqlparser.Expr) 71 72 err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 73 switch node := node.(type) { 74 case *sqlparser.ColName: 75 newOrigin, isLocal, err := pb.st.Find(node) 76 if err != nil { 77 return false, err 78 } 79 if isLocal && newOrigin.Order() > highestOrigin.Order() { 80 highestOrigin = newOrigin 81 } 82 case *sqlparser.ComparisonExpr: 83 if node.Operator == sqlparser.InOp || node.Operator == sqlparser.NotInOp { 84 if sq, ok := node.Right.(*sqlparser.Subquery); ok { 85 constructsMap[sq] = node 86 } 87 } 88 case *sqlparser.ExistsExpr: 89 constructsMap[node.Subquery] = node 90 case *sqlparser.Subquery: 91 spb := newPrimitiveBuilder(pb.vschema, pb.jt) 92 switch stmt := node.Select.(type) { 93 case *sqlparser.Select: 94 if err := spb.processSelect(stmt, reservedVars, pb.st, ""); err != nil { 95 return false, err 96 } 97 case *sqlparser.Union: 98 if err := spb.processUnion(stmt, reservedVars, pb.st); err != nil { 99 return false, err 100 } 101 default: 102 return false, vterrors.VT13001(fmt.Sprintf("unexpected SELECT type: %T", node)) 103 } 104 sqi := subqueryInfo{ 105 ast: node, 106 plan: spb.plan, 107 } 108 for _, extern := range spb.st.Externs { 109 // No error expected. These are resolved externs. 110 newOrigin, isLocal, _ := pb.st.Find(extern) 111 if !isLocal { 112 continue 113 } 114 if highestOrigin.Order() < newOrigin.Order() { 115 highestOrigin = newOrigin 116 } 117 if sqi.origin == nil { 118 sqi.origin = newOrigin 119 } else if sqi.origin.Order() < newOrigin.Order() { 120 sqi.origin = newOrigin 121 } 122 } 123 subqueries = append(subqueries, sqi) 124 return false, nil 125 } 126 return true, nil 127 }, expr) 128 if err != nil { 129 return nil, nil, nil, err 130 } 131 132 highestRoute, _ := highestOrigin.(*route) 133 for _, sqi := range subqueries { 134 subroute, _ := sqi.plan.(*route) 135 if highestRoute != nil && subroute != nil && highestRoute.MergeSubquery(pb, subroute) { 136 continue 137 } 138 if sqi.origin != nil { 139 return nil, nil, nil, vterrors.VT12001("cross-shard correlated subquery") 140 } 141 142 sqName, hasValues := pb.jt.GenerateSubqueryVars() 143 construct, ok := constructsMap[sqi.ast] 144 if !ok { 145 // (subquery) -> :_sq 146 expr = sqlparser.ReplaceExpr(expr, sqi.ast, sqlparser.NewArgument(sqName)) 147 pullouts = append(pullouts, newPulloutSubquery(engine.PulloutValue, sqName, hasValues, sqi.plan)) 148 continue 149 } 150 switch construct := construct.(type) { 151 case *sqlparser.ComparisonExpr: 152 if construct.Operator == sqlparser.InOp { 153 // a in (subquery) -> (:__sq_has_values = 1 and (a in ::__sq)) 154 right := &sqlparser.ComparisonExpr{ 155 Operator: construct.Operator, 156 Left: construct.Left, 157 Right: sqlparser.ListArg(sqName), 158 } 159 left := &sqlparser.ComparisonExpr{ 160 Left: sqlparser.NewArgument(hasValues), 161 Operator: sqlparser.EqualOp, 162 Right: sqlparser.NewIntLiteral("1"), 163 } 164 newExpr := &sqlparser.AndExpr{ 165 Left: left, 166 Right: right, 167 } 168 expr = sqlparser.ReplaceExpr(expr, construct, newExpr) 169 pullouts = append(pullouts, newPulloutSubquery(engine.PulloutIn, sqName, hasValues, sqi.plan)) 170 } else { 171 // a not in (subquery) -> (:__sq_has_values = 0 or (a not in ::__sq)) 172 left := &sqlparser.ComparisonExpr{ 173 Left: sqlparser.NewArgument(hasValues), 174 Operator: sqlparser.EqualOp, 175 Right: sqlparser.NewIntLiteral("0"), 176 } 177 right := &sqlparser.ComparisonExpr{ 178 Operator: construct.Operator, 179 Left: construct.Left, 180 Right: sqlparser.ListArg(sqName), 181 } 182 newExpr := &sqlparser.OrExpr{ 183 Left: left, 184 Right: right, 185 } 186 expr = sqlparser.ReplaceExpr(expr, construct, newExpr) 187 pullouts = append(pullouts, newPulloutSubquery(engine.PulloutNotIn, sqName, hasValues, sqi.plan)) 188 } 189 case *sqlparser.ExistsExpr: 190 // exists (subquery) -> :__sq_has_values 191 expr = sqlparser.ReplaceExpr(expr, construct, sqlparser.NewArgument(hasValues)) 192 pullouts = append(pullouts, newPulloutSubquery(engine.PulloutExists, sqName, hasValues, sqi.plan)) 193 } 194 } 195 return pullouts, highestOrigin, expr, nil 196 } 197 198 var dummyErr = vterrors.VT13001("dummy") 199 200 func hasSubquery(node sqlparser.SQLNode) bool { 201 has := false 202 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 203 switch node.(type) { 204 case *sqlparser.DerivedTable, *sqlparser.Subquery: 205 has = true 206 return false, dummyErr 207 } 208 return true, nil 209 }, node) 210 return has 211 } 212 213 func (pb *primitiveBuilder) finalizeUnshardedDMLSubqueries(reservedVars *sqlparser.ReservedVars, nodes ...sqlparser.SQLNode) (bool, []*vindexes.Table) { 214 var keyspace string 215 var tables []*vindexes.Table 216 if rb, ok := pb.plan.(*route); ok { 217 keyspace = rb.eroute.Keyspace.Name 218 } else { 219 // This code is unreachable because the caller checks. 220 return false, nil 221 } 222 223 for _, node := range nodes { 224 samePlan := true 225 inSubQuery := false 226 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 227 switch nodeType := node.(type) { 228 case *sqlparser.Subquery, *sqlparser.Insert: 229 inSubQuery = true 230 return true, nil 231 case *sqlparser.Select: 232 if !inSubQuery { 233 return true, nil 234 } 235 spb := newPrimitiveBuilder(pb.vschema, pb.jt) 236 if err := spb.processSelect(nodeType, reservedVars, pb.st, ""); err != nil { 237 samePlan = false 238 return false, err 239 } 240 innerRoute, ok := spb.plan.(*route) 241 if !ok { 242 samePlan = false 243 return false, dummyErr 244 } 245 if innerRoute.eroute.Keyspace.Name != keyspace { 246 samePlan = false 247 return false, dummyErr 248 } 249 for _, sub := range innerRoute.substitutions { 250 *sub.oldExpr = *sub.newExpr 251 } 252 spbTables, err := spb.st.AllVschemaTableNames() 253 if err != nil { 254 return false, err 255 } 256 tables = append(tables, spbTables...) 257 case *sqlparser.Union: 258 if !inSubQuery { 259 return true, nil 260 } 261 spb := newPrimitiveBuilder(pb.vschema, pb.jt) 262 if err := spb.processUnion(nodeType, reservedVars, pb.st); err != nil { 263 samePlan = false 264 return false, err 265 } 266 innerRoute, ok := spb.plan.(*route) 267 if !ok { 268 samePlan = false 269 return false, dummyErr 270 } 271 if innerRoute.eroute.Keyspace.Name != keyspace { 272 samePlan = false 273 return false, dummyErr 274 } 275 } 276 277 return true, nil 278 }, node) 279 if !samePlan { 280 return false, nil 281 } 282 } 283 return true, tables 284 } 285 286 func valEqual(a, b sqlparser.Expr) bool { 287 switch a := a.(type) { 288 case *sqlparser.ColName: 289 if b, ok := b.(*sqlparser.ColName); ok { 290 return a.Metadata == b.Metadata 291 } 292 case sqlparser.Argument: 293 b, ok := b.(sqlparser.Argument) 294 if !ok { 295 return false 296 } 297 return a == b 298 case *sqlparser.Literal: 299 b, ok := b.(*sqlparser.Literal) 300 if !ok { 301 return false 302 } 303 switch a.Type { 304 case sqlparser.StrVal: 305 switch b.Type { 306 case sqlparser.StrVal: 307 return a.Val == b.Val 308 case sqlparser.HexVal: 309 return hexEqual(b, a) 310 } 311 case sqlparser.HexVal: 312 return hexEqual(a, b) 313 case sqlparser.IntVal: 314 if b.Type == (sqlparser.IntVal) { 315 return a.Val == b.Val 316 } 317 } 318 } 319 return false 320 } 321 322 func hexEqual(a, b *sqlparser.Literal) bool { 323 v, err := a.HexDecode() 324 if err != nil { 325 return false 326 } 327 switch b.Type { 328 case sqlparser.StrVal: 329 return bytes.Equal(v, b.Bytes()) 330 case sqlparser.HexVal: 331 v2, err := b.HexDecode() 332 if err != nil { 333 return false 334 } 335 return bytes.Equal(v, v2) 336 } 337 return false 338 }