github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/unnest_insubqueries.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  	"github.com/dolthub/go-mysql-server/sql/transform"
    25  )
    26  
// applyJoin records one Filter conjunct that unnestInSubqueries has
// identified as a candidate for rewriting into a join against a subquery.
type applyJoin struct {
	// l is the left operand of the candidate comparison; it becomes the
	// left child of the synthesized join condition.
	l        sql.Expression
	// r is the cacheable subquery on the right-hand side of the comparison.
	r        *plan.Subquery
	// op is plan.JoinTypeSemi, or plan.JoinTypeAnti when the conjunct was
	// wrapped in NOT.
	op       plan.JoinType
	// filter is a template for the join condition. Its children are replaced
	// with (l, <subquery projection>) via WithChildren: an empty Equals for
	// IN subqueries, or the original Comparer for scalar comparisons.
	filter   sql.Expression
	// original is the candidate expression as recorded by unnestInSubqueries.
	// It is re-added to the Filter's conjuncts when the subquery's projection
	// cannot be determined.
	original sql.Expression
	// max1 is set for scalar Comparer candidates (presumably: the subquery
	// must produce at most one row). NOTE(review): it is not read anywhere
	// in this file.
	max1     bool
}
    35  
    36  // unnestInSubqueries converts expression.Comparer with an *plan.InSubquery
    37  // RHS into joins. The match conditions include: 1) subquery is cacheable,
    38  // 2) the top-level subquery projection is a get field with a sql.ColumnId
    39  // and sql.TableId (to support join reordering).
    40  // TODO decorrelate lhs too
    41  // TODO non-null-rejecting with dual table
    42  func unnestInSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    43  	switch n.(type) {
    44  	case *plan.DeleteFrom, *plan.InsertInto:
    45  		return n, transform.SameTree, nil
    46  	}
    47  
    48  	var unnested bool
    49  	var aliases map[string]int
    50  
    51  	ret := n
    52  	var err error
    53  	same := transform.NewTree
    54  	for !same {
    55  		// simplifySubqExpr can merge two scopes, requiring us to either
    56  		// recurse on the merged scope or perform a fixed-point iteration.
    57  		ret, same, err = transform.Node(ret, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    58  			var filters []sql.Expression
    59  			var child sql.Node
    60  			switch n := n.(type) {
    61  			case *plan.Filter:
    62  				child = n.Child
    63  				filters = expression.SplitConjunction(n.Expression)
    64  			default:
    65  			}
    66  
    67  			if sel == nil {
    68  				return n, transform.SameTree, nil
    69  			}
    70  
    71  			var matches []applyJoin
    72  			var newFilters []sql.Expression
    73  
    74  			// separate decorrelation candidates
    75  			for _, e := range filters {
    76  				if !plan.IsNullRejecting(e) {
    77  					// TODO: rewrite dual table to permit in-scope joins,
    78  					// which aren't possible when values are projected
    79  					// above join filter
    80  					rt := getResolvedTable(n)
    81  					if rt == nil || plan.IsDualTable(rt.Table) {
    82  						newFilters = append(newFilters, e)
    83  						continue
    84  					}
    85  				}
    86  
    87  				candE := e
    88  				op := plan.JoinTypeSemi
    89  				if n, ok := e.(*expression.Not); ok {
    90  					candE = n.Child
    91  					op = plan.JoinTypeAnti
    92  				}
    93  
    94  				var sq *plan.Subquery
    95  				var l sql.Expression
    96  				var joinF sql.Expression
    97  				var max1 bool
    98  				switch e := candE.(type) {
    99  				case *plan.InSubquery:
   100  					sq, _ = e.RightChild.(*plan.Subquery)
   101  					l = e.LeftChild
   102  
   103  					joinF = expression.NewEquals(nil, nil)
   104  				case expression.Comparer:
   105  					sq, _ = e.Right().(*plan.Subquery)
   106  					l = e.Left()
   107  					joinF = e
   108  					max1 = true
   109  				default:
   110  				}
   111  				if sq != nil && sq.CanCacheResults() {
   112  					matches = append(matches, applyJoin{l: l, r: sq, op: op, filter: joinF, max1: max1, original: candE})
   113  				} else {
   114  					newFilters = append(newFilters, e)
   115  				}
   116  			}
   117  			if len(matches) == 0 {
   118  				return n, transform.SameTree, nil
   119  			}
   120  
   121  			ret := child
   122  			for _, m := range matches {
   123  				// A successful candidate is built with:
   124  				// (1) Semi or anti join between the outer scope and (2) conditioned on (3).
   125  				// (2) Simplified or unnested subquery (table alias).
   126  				// (3) Join condition synthesized from the original correlated expression
   127  				//     normalized to match changes to (2).
   128  				subq := m.r
   129  
   130  				if aliases == nil {
   131  					aliases = make(map[string]int)
   132  					ta, err := getTableAliases(n, scope)
   133  					if err != nil {
   134  						return n, transform.SameTree, err
   135  					}
   136  					for k, _ := range ta.aliases {
   137  						aliases[k] = 0
   138  					}
   139  				}
   140  
   141  				var newSubq sql.Node
   142  				newSubq, aliases, err = disambiguateTables(aliases, subq.Query)
   143  				if err != nil {
   144  					return ret, transform.SameTree, nil
   145  				}
   146  
   147  				rightF, ok, err := getHighestProjection(newSubq)
   148  				if err != nil {
   149  					return n, transform.SameTree, err
   150  				}
   151  				if !ok {
   152  					newFilters = append(newFilters, m.original)
   153  					continue
   154  				}
   155  
   156  				filter, err := m.filter.WithChildren(m.l, rightF)
   157  				if err != nil {
   158  					return n, transform.SameTree, err
   159  				}
   160  				var comment string
   161  				if c, ok := ret.(sql.CommentedNode); ok {
   162  					comment = c.Comment()
   163  				}
   164  				unnested = true
   165  				newJoin := plan.NewJoin(ret, newSubq, m.op, filter)
   166  				ret = newJoin.WithComment(comment)
   167  			}
   168  
   169  			if len(newFilters) == 0 {
   170  				return ret, transform.NewTree, nil
   171  			}
   172  			if len(newFilters) == len(filters) {
   173  				return n, transform.SameTree, nil
   174  			}
   175  			return plan.NewFilter(expression.JoinAnd(newFilters...), ret), transform.NewTree, nil
   176  		})
   177  		if err != nil {
   178  			return n, transform.SameTree, err
   179  		}
   180  	}
   181  	return ret, transform.TreeIdentity(!unnested), nil
   182  }
   183  
// disambiguateTables returns an updated sql.Node with aliases de-duplicated,
// and an updated alias mapping with new conflicts and tables added.
//
// Every renameable table in n whose lower-cased name is already a key of
// used is renamed to "<name>_<k>", where k is the smallest integer greater
// than used[name] for which "<name>_<k>" is not yet in used. used is updated
// in place: a clashing name's count is incremented, each generated name is
// added with count 0, and a first-seen name is added with count 0. Column
// references to renamed tables are then rewritten by table id via
// renameExpressionTables.
//
// On error the returned node must not be used, and used may already have
// been modified.
func disambiguateTables(used map[string]int, n sql.Node) (sql.Node, map[string]int, error) {
	// rename maps the id of each renamed table to its new name; it drives the
	// id-based rewrite of column references after the traversal.
	rename := make(map[sql.TableId]string)
	n, _, err := transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
		switch n := c.Node.(type) {
		case sql.RenameableNode:
			name := strings.ToLower(n.Name())
			if _, ok := c.Parent.(sql.RenameableNode); ok {
				// skip checking when: TableAlias(ResolvedTable)
				return n, transform.SameTree, nil
			}
			if cnt, ok := used[name]; ok {
				// Name clash: bump the clash count, then probe
				// name_<cnt+1>, name_<cnt+2>, ... for the first unused name.
				used[name] = cnt + 1
				newName := name
				for ok {
					cnt++
					newName = fmt.Sprintf("%s_%d", name, cnt)
					_, ok = used[newName]

				}
				used[newName] = 0

				tin, ok := n.(plan.TableIdNode)
				if !ok {
					return n, transform.SameTree, fmt.Errorf("expected sql.Renameable to implement plan.TableIdNode")
				}
				rename[tin.Id()] = newName
				return n.WithName(newName), transform.NewTree, nil
			} else {
				// First occurrence of this name: reserve it.
				used[name] = 0
			}
			return n, transform.NewTree, nil
		default:
			// Rewrite column references whose table name has a positive
			// count in used to "<table>_<count>".
			// NOTE(review): the suffix here is the clash count, which differs
			// from the suffix actually chosen above when "<name>_<count>" was
			// already taken. renameExpressionTables re-fixes fields by table
			// id afterwards, but references to other tables with the same name
			// (e.g. correlated outer references, or names renamed in an
			// earlier call) are also matched. NodeExprs also walks this node's
			// whole subtree, so lower expressions may be visited once per
			// ancestor. TODO confirm intended.
			return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
				switch e := e.(type) {
				case *expression.GetField:
					if cnt, ok := used[strings.ToLower(e.Table())]; ok && cnt > 0 {
						return e.WithTable(fmt.Sprintf("%s_%d", e.Table(), cnt)), transform.NewTree, nil
					}
				default:
				}
				return e, transform.NewTree, nil
			})
		}
	})
	if err != nil {
		return nil, nil, err
	}
	if len(rename) > 0 {
		n, _, err = renameExpressionTables(n, rename)
	}
	return n, used, err
}
   238  
   239  // renameExpressionTables renames table references recursively. We use
   240  // table ids to avoid improperly renaming tables in lower scopes with the
   241  // same name.
   242  func renameExpressionTables(n sql.Node, rename map[sql.TableId]string) (sql.Node, transform.TreeIdentity, error) {
   243  	return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   244  		switch e := e.(type) {
   245  		case *expression.GetField:
   246  			if to, ok := rename[e.TableId()]; ok {
   247  				return e.WithTable(to), transform.NewTree, nil
   248  			}
   249  		case *plan.Subquery:
   250  			newQ, same, err := renameExpressionTables(e.Query, rename)
   251  			if !same || err != nil {
   252  				return e, same, err
   253  			}
   254  			return e.WithQuery(newQ), transform.NewTree, nil
   255  		default:
   256  		}
   257  		return e, transform.NewTree, nil
   258  	})
   259  }
   260  
   261  // getHighestProjection returns a set of projection expressions responsible
   262  // for the input node's schema, or false if an aggregate or set type is
   263  // found (which we cannot generate named projections for yet).
   264  func getHighestProjection(n sql.Node) (sql.Expression, bool, error) {
   265  	sch := n.Schema()
   266  	for n != nil {
   267  		if !sch.Equals(n.Schema()) {
   268  			break
   269  		}
   270  		var proj []sql.Expression
   271  		switch nn := n.(type) {
   272  		case *plan.Project:
   273  			proj = nn.Projections
   274  		case *plan.JoinNode:
   275  			left, ok, err := getHighestProjection(nn.Left())
   276  			if err != nil {
   277  				return nil, false, err
   278  			}
   279  			if !ok {
   280  				return nil, false, nil
   281  			}
   282  			right, ok, err := getHighestProjection(nn.Right())
   283  			if err != nil {
   284  				return nil, false, err
   285  			}
   286  			if !ok {
   287  				return nil, false, nil
   288  			}
   289  			switch e := left.(type) {
   290  			case expression.Tuple:
   291  				proj = append(proj, e.Children()...)
   292  			default:
   293  				proj = append(proj, e)
   294  			}
   295  			switch e := right.(type) {
   296  			case expression.Tuple:
   297  				proj = append(proj, e.Children()...)
   298  			default:
   299  				proj = append(proj, e)
   300  			}
   301  		case *plan.GroupBy:
   302  			// todo(max): could make better effort to get column ids from these,
   303  			// but real fix is also giving synthesized projection column ids
   304  			// in binder
   305  			proj = nn.SelectedExprs
   306  		case *plan.Window:
   307  			proj = nn.SelectExprs
   308  		case *plan.SetOp:
   309  			return nil, false, nil
   310  		case plan.TableIdNode:
   311  			colset := nn.Columns()
   312  			idx := 0
   313  			sch := n.Schema()
   314  			for id, hasNext := colset.Next(1); hasNext; id, hasNext = colset.Next(id + 1) {
   315  				col := sch[idx]
   316  				proj = append(proj, expression.NewGetFieldWithTable(int(id), int(nn.Id()), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable))
   317  				idx++
   318  			}
   319  		default:
   320  			if len(nn.Children()) == 1 {
   321  				n = nn.Children()[0]
   322  				continue
   323  			}
   324  		}
   325  		if proj == nil {
   326  			break
   327  		}
   328  		projCopy := make([]sql.Expression, len(proj))
   329  		copy(projCopy, proj)
   330  		for i, p := range projCopy {
   331  			if a, ok := p.(*expression.Alias); ok {
   332  				if a.Unreferencable() || a.Id() == 0 {
   333  					return nil, false, nil
   334  				}
   335  				projCopy[i] = expression.NewGetField(int(a.Id()), a.Type(), a.Name(), a.IsNullable())
   336  			}
   337  		}
   338  		if len(projCopy) == 1 {
   339  			return projCopy[0], true, nil
   340  		}
   341  		return expression.NewTuple(projCopy...), true, nil
   342  	}
   343  	return nil, false, fmt.Errorf("failed to find decorrelation projection")
   344  }