github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/unnest_exists_subqueries.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 "github.com/dolthub/go-mysql-server/sql/expression" 22 "github.com/dolthub/go-mysql-server/sql/plan" 23 "github.com/dolthub/go-mysql-server/sql/transform" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 type aliasDisambiguator struct { 28 n sql.Node 29 scope *plan.Scope 30 aliases *TableAliases 31 disambiguationIndex int 32 } 33 34 func (ad *aliasDisambiguator) GetAliases() (TableAliases, error) { 35 if ad.aliases == nil { 36 aliases, err := getTableAliases(ad.n, ad.scope) 37 if err != nil { 38 return TableAliases{}, err 39 } 40 ad.aliases = &aliases 41 } 42 return *ad.aliases, nil 43 } 44 45 func (ad *aliasDisambiguator) Disambiguate(alias string) (string, error) { 46 nodeAliases, err := ad.GetAliases() 47 if err != nil { 48 return "", err 49 } 50 51 // all renamed aliases will be of the form <alias>_<disambiguationIndex++> 52 for { 53 ad.disambiguationIndex++ 54 aliasName := fmt.Sprintf("%s_%d", alias, ad.disambiguationIndex) 55 if _, ok, err := nodeAliases.resolveName(aliasName); !ok { 56 if err != nil { 57 return "", err 58 } 59 return aliasName, nil 60 } 61 } 62 } 63 64 func newAliasDisambiguator(n sql.Node, scope *plan.Scope) *aliasDisambiguator { 65 return &aliasDisambiguator{n: n, scope: scope} 66 } 67 68 // unnestExistsSubqueries merges a WHERE EXISTS subquery scope with its outer 69 // scope when the subquery filters on columns from the outer scope. 70 // 71 // For example: 72 // select * from a where exists (select 1 from b where a.x = b.x) 73 // => 74 // select * from a semi join b on a.x = b.x 75 func unnestExistsSubqueries( 76 ctx *sql.Context, 77 a *Analyzer, 78 n sql.Node, 79 scope *plan.Scope, 80 sel RuleSelector, 81 ) (sql.Node, transform.TreeIdentity, error) { 82 aliasDisambig := newAliasDisambiguator(n, scope) 83 return unnestSelectExistsHelper(ctx, scope, a, n, aliasDisambig) 84 } 85 86 func unnestSelectExistsHelper(ctx *sql.Context, scope *plan.Scope, a *Analyzer, n sql.Node, aliasDisambig *aliasDisambiguator) (sql.Node, transform.TreeIdentity, error) { 87 return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 88 f, ok := n.(*plan.Filter) 89 if !ok { 90 return n, transform.SameTree, nil 91 } 92 return unnestExistSubqueries(ctx, scope, a, f, aliasDisambig) 93 }) 94 } 95 96 // simplifyPartialJoinParents discards nodes that will not affect an existence check. 97 func simplifyPartialJoinParents(n sql.Node) (sql.Node, bool) { 98 ret := n 99 for { 100 switch n := ret.(type) { 101 case *plan.Having: 102 return nil, false 103 case *plan.Project, *plan.GroupBy, *plan.Limit, *plan.Sort, *plan.Distinct, *plan.TopN: 104 ret = n.Children()[0] 105 default: 106 return ret, true 107 } 108 } 109 } 110 111 // unnestExistSubqueries scans a filter for [NOT] WHERE EXISTS, and then attempts to 112 // extract the subquery, correlated filters, a modified outer scope (net subquery and filters), 113 // and the new target joinType 114 func unnestExistSubqueries(ctx *sql.Context, scope *plan.Scope, a *Analyzer, filter *plan.Filter, aliasDisambig *aliasDisambiguator) (sql.Node, transform.TreeIdentity, error) { 115 ret := filter.Child 116 var retFilters []sql.Expression 117 same := transform.SameTree 118 for _, f := range expression.SplitConjunction(filter.Expression) { 119 var s *hoistSubquery 120 var err error 121 122 // match subquery expression 123 joinType := plan.JoinTypeSemi 124 var sq *plan.Subquery 125 switch e := f.(type) { 126 case *plan.ExistsSubquery: 127 sq = e.Query 128 case *expression.Not: 129 if esq, ok := e.Child.(*plan.ExistsSubquery); ok { 130 sq = esq.Query 131 joinType = plan.JoinTypeAnti 132 } 133 default: 134 } 135 if sq == nil { 136 retFilters = append(retFilters, f) 137 continue 138 } 139 140 // try to decorrelate 141 s, err = decorrelateOuterCols(sq.Query, aliasDisambig, sq.Correlated()) 142 if err != nil { 143 return nil, transform.SameTree, err 144 } 145 146 if s == nil { 147 retFilters = append(retFilters, f) 148 continue 149 } 150 151 // recurse 152 if s.inner != nil { 153 s.inner, _, err = unnestSelectExistsHelper(ctx, scope.NewScopeFromSubqueryExpression(filter, sq.Correlated()), a, s.inner, aliasDisambig) 154 if err != nil { 155 return nil, transform.SameTree, err 156 } 157 } 158 159 if sqa, ok := s.inner.(*plan.SubqueryAlias); ok { 160 if !sqa.CanCacheResults() { 161 return filter, transform.SameTree, nil 162 } 163 } 164 165 // if we reached here, |s| contains the state we need to 166 // decorrelate the subquery expression into a new node 167 same = transform.NewTree 168 var comment string 169 if c, ok := ret.(sql.CommentedNode); ok { 170 comment = c.Comment() 171 } 172 173 if s.emptyScope { 174 switch joinType { 175 case plan.JoinTypeAnti: 176 // ret will be all rows 177 case plan.JoinTypeSemi: 178 ret = plan.NewEmptyTableWithSchema(ret.Schema()) 179 default: 180 return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type") 181 } 182 continue 183 } 184 185 if len(s.joinFilters) == 0 { 186 switch joinType { 187 case plan.JoinTypeAnti: 188 cond := expression.NewLiteral(true, types.Boolean) 189 ret = plan.NewAntiJoin(ret, s.inner, cond).WithComment(comment) 190 191 case plan.JoinTypeSemi: 192 ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment) 193 default: 194 return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type") 195 } 196 continue 197 } 198 199 outerFilters := s.joinFilters 200 if referencesOuterScope(outerFilters, scope) { 201 retFilters = append(retFilters, f) 202 continue 203 } 204 205 switch joinType { 206 case plan.JoinTypeAnti: 207 ret = plan.NewAntiJoin(ret, s.inner, expression.JoinAnd(outerFilters...)).WithComment(comment) 208 case plan.JoinTypeSemi: 209 ret = plan.NewSemiJoin(ret, s.inner, expression.JoinAnd(outerFilters...)).WithComment(comment) 210 default: 211 return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type") 212 } 213 } 214 215 if same { 216 return filter, transform.SameTree, nil 217 } 218 if len(retFilters) > 0 { 219 ret = plan.NewFilter(expression.JoinAnd(retFilters...), ret) 220 } 221 return ret, transform.NewTree, nil 222 } 223 224 // referencesOuterScope returns true if a filter in the set is from an outer scope 225 func referencesOuterScope(filters []sql.Expression, scope *plan.Scope) bool { 226 if scope == nil { 227 return false 228 } 229 for _, e := range filters { 230 if transform.InspectExpr(e, func(e sql.Expression) bool { 231 gf, ok := e.(*expression.GetField) 232 return ok && scope.Correlated().Contains(gf.Id()) 233 }) { 234 return true 235 } 236 } 237 return false 238 } 239 240 type hoistSubquery struct { 241 inner sql.Node 242 joinFilters []sql.Expression 243 emptyScope bool 244 } 245 246 type fakeNameable struct { 247 name string 248 } 249 250 var _ sql.Nameable = (*fakeNameable)(nil) 251 252 func (f fakeNameable) Name() string { return f.name } 253 254 // decorrelateOuterCols returns an optionally modified subquery and extracted filters referencing an outer scope. 255 // If the subquery has aliases that conflict with outside aliases, the internal aliases will be renamed to avoid 256 // name collisions. 257 func decorrelateOuterCols(sqChild sql.Node, aliasDisambig *aliasDisambiguator, corr sql.ColSet) (*hoistSubquery, error) { 258 var joinFilters []sql.Expression 259 var filtersToKeep []sql.Expression 260 var emptyScope bool 261 var cantDecorrelate bool 262 n, _, _ := transform.Node(sqChild, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 263 if emptyScope { 264 return n, transform.SameTree, nil 265 } 266 switch f := n.(type) { 267 case *plan.Offset: 268 cantDecorrelate = true 269 return n, transform.SameTree, nil 270 case *plan.EmptyTable: 271 emptyScope = true 272 return n, transform.SameTree, nil 273 case *plan.Filter: 274 filters := expression.SplitConjunction(f.Expression) 275 for _, f := range filters { 276 outerRef := transform.InspectExpr(f, func(e sql.Expression) bool { 277 if gf, ok := e.(*expression.GetField); ok && corr.Contains(gf.Id()) { 278 return true 279 } 280 if sq, ok := e.(*plan.Subquery); ok { 281 if !sq.Correlated().Intersection(corr).Empty() { 282 return true 283 } 284 } 285 return false 286 }) 287 288 // based on the GetField analysis, decide where to put the filter 289 if outerRef { 290 joinFilters = append(joinFilters, f) 291 } else { 292 filtersToKeep = append(filtersToKeep, f) 293 } 294 } 295 296 // avoid updating the tree if we don't move any filters 297 if len(filtersToKeep) == len(filters) { 298 filtersToKeep = nil 299 return f, transform.SameTree, nil 300 } 301 302 return f.Child, transform.NewTree, nil 303 default: 304 return n, transform.SameTree, nil 305 } 306 }) 307 308 if emptyScope { 309 return &hoistSubquery{ 310 emptyScope: true, 311 }, nil 312 } 313 314 if cantDecorrelate { 315 return nil, nil 316 } 317 318 nodeAliases, err := getTableAliases(n, nil) 319 if err != nil { 320 return nil, err 321 } 322 323 outsideAliases, err := aliasDisambig.GetAliases() 324 if err != nil { 325 return nil, err 326 } 327 conflicts, nonConflicted := outsideAliases.findConflicts(nodeAliases) 328 for _, goodAlias := range nonConflicted { 329 target, ok, err := nodeAliases.resolveName(goodAlias) 330 if err != nil { 331 return nil, err 332 } 333 if !ok { 334 return nil, fmt.Errorf("node alias %s is not in nodeAliases", goodAlias) 335 } 336 err = outsideAliases.addUnqualified(goodAlias, target) 337 if err != nil { 338 return nil, err 339 } 340 } 341 342 if len(conflicts) > 0 { 343 for _, conflict := range conflicts { 344 345 // conflict, need to rename 346 newAlias, err := aliasDisambig.Disambiguate(conflict) 347 if err != nil { 348 return nil, err 349 } 350 same := transform.SameTree 351 n, same, err = renameAliases(n, conflict, newAlias) 352 if err != nil { 353 return nil, err 354 } 355 356 if same { 357 return nil, fmt.Errorf("tree is unchanged after attempted rename") 358 } 359 360 // rename the aliases in the expressions 361 joinFilters, err = renameAliasesInExpressions(joinFilters, conflict, newAlias) 362 if err != nil { 363 return nil, err 364 } 365 366 filtersToKeep, err = renameAliasesInExpressions(filtersToKeep, conflict, newAlias) 367 if err != nil { 368 return nil, err 369 } 370 371 // alias was renamed, need to get the renamed target before adding to the outside aliases collection 372 nodeAliases, err = getTableAliases(n, nil) 373 if err != nil { 374 return nil, err 375 } 376 377 // retrieve the new target 378 target, ok, err := nodeAliases.resolveName(newAlias) 379 if err != nil { 380 return nil, err 381 } 382 if !ok { 383 return nil, fmt.Errorf("node alias %s is not in nodeAliases", newAlias) 384 } 385 386 // add the new target to the outside aliases collection 387 err = outsideAliases.addUnqualified(newAlias, target) 388 if err != nil { 389 return nil, err 390 } 391 } 392 } 393 394 n, ok := simplifyPartialJoinParents(n) 395 if !ok { 396 return nil, nil 397 } 398 if len(filtersToKeep) > 0 { 399 n = plan.NewFilter(expression.JoinAnd(filtersToKeep...), n) 400 } 401 402 if len(joinFilters) == 0 { 403 n = plan.NewLimit(expression.NewLiteral(1, types.Int64), n) 404 } 405 406 return &hoistSubquery{ 407 inner: n, 408 joinFilters: joinFilters, 409 }, nil 410 }