github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/unnest_exists_subqueries.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  	"github.com/dolthub/go-mysql-server/sql/expression"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  	"github.com/dolthub/go-mysql-server/sql/transform"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
// aliasDisambiguator hands out table alias names that do not collide with the
// aliases already present in a plan node. The alias set for the node is
// computed lazily on first use and cached.
type aliasDisambiguator struct {
	// n is the node whose table aliases are collected by GetAliases.
	n sql.Node
	// scope is passed along with n to getTableAliases when collecting aliases.
	scope *plan.Scope
	// aliases caches the result of getTableAliases; nil until GetAliases has
	// succeeded once.
	aliases *TableAliases
	// disambiguationIndex is the numeric suffix counter used by Disambiguate;
	// it is incremented before each candidate name is generated.
	disambiguationIndex int
}
    33  
    34  func (ad *aliasDisambiguator) GetAliases() (TableAliases, error) {
    35  	if ad.aliases == nil {
    36  		aliases, err := getTableAliases(ad.n, ad.scope)
    37  		if err != nil {
    38  			return TableAliases{}, err
    39  		}
    40  		ad.aliases = &aliases
    41  	}
    42  	return *ad.aliases, nil
    43  }
    44  
    45  func (ad *aliasDisambiguator) Disambiguate(alias string) (string, error) {
    46  	nodeAliases, err := ad.GetAliases()
    47  	if err != nil {
    48  		return "", err
    49  	}
    50  
    51  	// all renamed aliases will be of the form <alias>_<disambiguationIndex++>
    52  	for {
    53  		ad.disambiguationIndex++
    54  		aliasName := fmt.Sprintf("%s_%d", alias, ad.disambiguationIndex)
    55  		if _, ok, err := nodeAliases.resolveName(aliasName); !ok {
    56  			if err != nil {
    57  				return "", err
    58  			}
    59  			return aliasName, nil
    60  		}
    61  	}
    62  }
    63  
    64  func newAliasDisambiguator(n sql.Node, scope *plan.Scope) *aliasDisambiguator {
    65  	return &aliasDisambiguator{n: n, scope: scope}
    66  }
    67  
    68  // unnestExistsSubqueries merges a WHERE EXISTS subquery scope with its outer
    69  // scope when the subquery filters on columns from the outer scope.
    70  //
    71  // For example:
    72  // select * from a where exists (select 1 from b where a.x = b.x)
    73  // =>
    74  // select * from a semi join b on a.x = b.x
    75  func unnestExistsSubqueries(
    76  	ctx *sql.Context,
    77  	a *Analyzer,
    78  	n sql.Node,
    79  	scope *plan.Scope,
    80  	sel RuleSelector,
    81  ) (sql.Node, transform.TreeIdentity, error) {
    82  	aliasDisambig := newAliasDisambiguator(n, scope)
    83  	return unnestSelectExistsHelper(ctx, scope, a, n, aliasDisambig)
    84  }
    85  
    86  func unnestSelectExistsHelper(ctx *sql.Context, scope *plan.Scope, a *Analyzer, n sql.Node, aliasDisambig *aliasDisambiguator) (sql.Node, transform.TreeIdentity, error) {
    87  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    88  		f, ok := n.(*plan.Filter)
    89  		if !ok {
    90  			return n, transform.SameTree, nil
    91  		}
    92  		return unnestExistSubqueries(ctx, scope, a, f, aliasDisambig)
    93  	})
    94  }
    95  
    96  // simplifyPartialJoinParents discards nodes that will not affect an existence check.
    97  func simplifyPartialJoinParents(n sql.Node) (sql.Node, bool) {
    98  	ret := n
    99  	for {
   100  		switch n := ret.(type) {
   101  		case *plan.Having:
   102  			return nil, false
   103  		case *plan.Project, *plan.GroupBy, *plan.Limit, *plan.Sort, *plan.Distinct, *plan.TopN:
   104  			ret = n.Children()[0]
   105  		default:
   106  			return ret, true
   107  		}
   108  	}
   109  }
   110  
   111  // unnestExistSubqueries scans a filter for [NOT] WHERE EXISTS, and then attempts to
   112  // extract the subquery, correlated filters, a modified outer scope (net subquery and filters),
   113  // and the new target joinType
   114  func unnestExistSubqueries(ctx *sql.Context, scope *plan.Scope, a *Analyzer, filter *plan.Filter, aliasDisambig *aliasDisambiguator) (sql.Node, transform.TreeIdentity, error) {
   115  	ret := filter.Child
   116  	var retFilters []sql.Expression
   117  	same := transform.SameTree
   118  	for _, f := range expression.SplitConjunction(filter.Expression) {
   119  		var s *hoistSubquery
   120  		var err error
   121  
   122  		// match subquery expression
   123  		joinType := plan.JoinTypeSemi
   124  		var sq *plan.Subquery
   125  		switch e := f.(type) {
   126  		case *plan.ExistsSubquery:
   127  			sq = e.Query
   128  		case *expression.Not:
   129  			if esq, ok := e.Child.(*plan.ExistsSubquery); ok {
   130  				sq = esq.Query
   131  				joinType = plan.JoinTypeAnti
   132  			}
   133  		default:
   134  		}
   135  		if sq == nil {
   136  			retFilters = append(retFilters, f)
   137  			continue
   138  		}
   139  
   140  		// try to decorrelate
   141  		s, err = decorrelateOuterCols(sq.Query, aliasDisambig, sq.Correlated())
   142  		if err != nil {
   143  			return nil, transform.SameTree, err
   144  		}
   145  
   146  		if s == nil {
   147  			retFilters = append(retFilters, f)
   148  			continue
   149  		}
   150  
   151  		// recurse
   152  		if s.inner != nil {
   153  			s.inner, _, err = unnestSelectExistsHelper(ctx, scope.NewScopeFromSubqueryExpression(filter, sq.Correlated()), a, s.inner, aliasDisambig)
   154  			if err != nil {
   155  				return nil, transform.SameTree, err
   156  			}
   157  		}
   158  
   159  		if sqa, ok := s.inner.(*plan.SubqueryAlias); ok {
   160  			if !sqa.CanCacheResults() {
   161  				return filter, transform.SameTree, nil
   162  			}
   163  		}
   164  
   165  		// if we reached here, |s| contains the state we need to
   166  		// decorrelate the subquery expression into a new node
   167  		same = transform.NewTree
   168  		var comment string
   169  		if c, ok := ret.(sql.CommentedNode); ok {
   170  			comment = c.Comment()
   171  		}
   172  
   173  		if s.emptyScope {
   174  			switch joinType {
   175  			case plan.JoinTypeAnti:
   176  				// ret will be all rows
   177  			case plan.JoinTypeSemi:
   178  				ret = plan.NewEmptyTableWithSchema(ret.Schema())
   179  			default:
   180  				return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type")
   181  			}
   182  			continue
   183  		}
   184  
   185  		if len(s.joinFilters) == 0 {
   186  			switch joinType {
   187  			case plan.JoinTypeAnti:
   188  				cond := expression.NewLiteral(true, types.Boolean)
   189  				ret = plan.NewAntiJoin(ret, s.inner, cond).WithComment(comment)
   190  
   191  			case plan.JoinTypeSemi:
   192  				ret = plan.NewCrossJoin(ret, s.inner).WithComment(comment)
   193  			default:
   194  				return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type")
   195  			}
   196  			continue
   197  		}
   198  
   199  		outerFilters := s.joinFilters
   200  		if referencesOuterScope(outerFilters, scope) {
   201  			retFilters = append(retFilters, f)
   202  			continue
   203  		}
   204  
   205  		switch joinType {
   206  		case plan.JoinTypeAnti:
   207  			ret = plan.NewAntiJoin(ret, s.inner, expression.JoinAnd(outerFilters...)).WithComment(comment)
   208  		case plan.JoinTypeSemi:
   209  			ret = plan.NewSemiJoin(ret, s.inner, expression.JoinAnd(outerFilters...)).WithComment(comment)
   210  		default:
   211  			return filter, transform.SameTree, fmt.Errorf("hoistSelectExists failed on unexpected join type")
   212  		}
   213  	}
   214  
   215  	if same {
   216  		return filter, transform.SameTree, nil
   217  	}
   218  	if len(retFilters) > 0 {
   219  		ret = plan.NewFilter(expression.JoinAnd(retFilters...), ret)
   220  	}
   221  	return ret, transform.NewTree, nil
   222  }
   223  
   224  // referencesOuterScope returns true if a filter in the set is from an outer scope
   225  func referencesOuterScope(filters []sql.Expression, scope *plan.Scope) bool {
   226  	if scope == nil {
   227  		return false
   228  	}
   229  	for _, e := range filters {
   230  		if transform.InspectExpr(e, func(e sql.Expression) bool {
   231  			gf, ok := e.(*expression.GetField)
   232  			return ok && scope.Correlated().Contains(gf.Id())
   233  		}) {
   234  			return true
   235  		}
   236  	}
   237  	return false
   238  }
   239  
// hoistSubquery is the result of decorrelating an EXISTS subquery: the pieces
// needed to turn the subquery expression into a join node.
type hoistSubquery struct {
	// inner is the rewritten subquery plan used as the right side of the join.
	// It is left nil when emptyScope is set.
	inner sql.Node
	// joinFilters are the filter conjuncts extracted from the subquery that
	// reference the outer scope; they become the join condition.
	joinFilters []sql.Expression
	// emptyScope is set when the subquery plan contains an EmptyTable node.
	emptyScope bool
}
   245  
// fakeNameable is a minimal sql.Nameable that reports a fixed name.
type fakeNameable struct {
	name string
}

// compile-time check that *fakeNameable satisfies sql.Nameable
var _ sql.Nameable = (*fakeNameable)(nil)

// Name implements sql.Nameable by returning the stored name.
func (f fakeNameable) Name() string { return f.name }
   253  
   254  // decorrelateOuterCols returns an optionally modified subquery and extracted filters referencing an outer scope.
   255  // If the subquery has aliases that conflict with outside aliases, the internal aliases will be renamed to avoid
   256  // name collisions.
   257  func decorrelateOuterCols(sqChild sql.Node, aliasDisambig *aliasDisambiguator, corr sql.ColSet) (*hoistSubquery, error) {
   258  	var joinFilters []sql.Expression
   259  	var filtersToKeep []sql.Expression
   260  	var emptyScope bool
   261  	var cantDecorrelate bool
   262  	n, _, _ := transform.Node(sqChild, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
   263  		if emptyScope {
   264  			return n, transform.SameTree, nil
   265  		}
   266  		switch f := n.(type) {
   267  		case *plan.Offset:
   268  			cantDecorrelate = true
   269  			return n, transform.SameTree, nil
   270  		case *plan.EmptyTable:
   271  			emptyScope = true
   272  			return n, transform.SameTree, nil
   273  		case *plan.Filter:
   274  			filters := expression.SplitConjunction(f.Expression)
   275  			for _, f := range filters {
   276  				outerRef := transform.InspectExpr(f, func(e sql.Expression) bool {
   277  					if gf, ok := e.(*expression.GetField); ok && corr.Contains(gf.Id()) {
   278  						return true
   279  					}
   280  					if sq, ok := e.(*plan.Subquery); ok {
   281  						if !sq.Correlated().Intersection(corr).Empty() {
   282  							return true
   283  						}
   284  					}
   285  					return false
   286  				})
   287  
   288  				// based on the GetField analysis, decide where to put the filter
   289  				if outerRef {
   290  					joinFilters = append(joinFilters, f)
   291  				} else {
   292  					filtersToKeep = append(filtersToKeep, f)
   293  				}
   294  			}
   295  
   296  			// avoid updating the tree if we don't move any filters
   297  			if len(filtersToKeep) == len(filters) {
   298  				filtersToKeep = nil
   299  				return f, transform.SameTree, nil
   300  			}
   301  
   302  			return f.Child, transform.NewTree, nil
   303  		default:
   304  			return n, transform.SameTree, nil
   305  		}
   306  	})
   307  
   308  	if emptyScope {
   309  		return &hoistSubquery{
   310  			emptyScope: true,
   311  		}, nil
   312  	}
   313  
   314  	if cantDecorrelate {
   315  		return nil, nil
   316  	}
   317  
   318  	nodeAliases, err := getTableAliases(n, nil)
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  
   323  	outsideAliases, err := aliasDisambig.GetAliases()
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  	conflicts, nonConflicted := outsideAliases.findConflicts(nodeAliases)
   328  	for _, goodAlias := range nonConflicted {
   329  		target, ok, err := nodeAliases.resolveName(goodAlias)
   330  		if err != nil {
   331  			return nil, err
   332  		}
   333  		if !ok {
   334  			return nil, fmt.Errorf("node alias %s is not in nodeAliases", goodAlias)
   335  		}
   336  		err = outsideAliases.addUnqualified(goodAlias, target)
   337  		if err != nil {
   338  			return nil, err
   339  		}
   340  	}
   341  
   342  	if len(conflicts) > 0 {
   343  		for _, conflict := range conflicts {
   344  
   345  			// conflict, need to rename
   346  			newAlias, err := aliasDisambig.Disambiguate(conflict)
   347  			if err != nil {
   348  				return nil, err
   349  			}
   350  			same := transform.SameTree
   351  			n, same, err = renameAliases(n, conflict, newAlias)
   352  			if err != nil {
   353  				return nil, err
   354  			}
   355  
   356  			if same {
   357  				return nil, fmt.Errorf("tree is unchanged after attempted rename")
   358  			}
   359  
   360  			// rename the aliases in the expressions
   361  			joinFilters, err = renameAliasesInExpressions(joinFilters, conflict, newAlias)
   362  			if err != nil {
   363  				return nil, err
   364  			}
   365  
   366  			filtersToKeep, err = renameAliasesInExpressions(filtersToKeep, conflict, newAlias)
   367  			if err != nil {
   368  				return nil, err
   369  			}
   370  
   371  			// alias was renamed, need to get the renamed target before adding to the outside aliases collection
   372  			nodeAliases, err = getTableAliases(n, nil)
   373  			if err != nil {
   374  				return nil, err
   375  			}
   376  
   377  			// retrieve the new target
   378  			target, ok, err := nodeAliases.resolveName(newAlias)
   379  			if err != nil {
   380  				return nil, err
   381  			}
   382  			if !ok {
   383  				return nil, fmt.Errorf("node alias %s is not in nodeAliases", newAlias)
   384  			}
   385  
   386  			// add the new target to the outside aliases collection
   387  			err = outsideAliases.addUnqualified(newAlias, target)
   388  			if err != nil {
   389  				return nil, err
   390  			}
   391  		}
   392  	}
   393  
   394  	n, ok := simplifyPartialJoinParents(n)
   395  	if !ok {
   396  		return nil, nil
   397  	}
   398  	if len(filtersToKeep) > 0 {
   399  		n = plan.NewFilter(expression.JoinAnd(filtersToKeep...), n)
   400  	}
   401  
   402  	if len(joinFilters) == 0 {
   403  		n = plan.NewLimit(expression.NewLiteral(1, types.Int64), n)
   404  	}
   405  
   406  	return &hoistSubquery{
   407  		inner:       n,
   408  		joinFilters: joinFilters,
   409  	}, nil
   410  }