// Source: github.com/dolthub/go-mysql-server@v0.18.0/sql/memo/exec_builder.go

     1  package memo
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/dolthub/go-mysql-server/sql"
     7  	"github.com/dolthub/go-mysql-server/sql/expression"
     8  	"github.com/dolthub/go-mysql-server/sql/plan"
     9  	"github.com/dolthub/go-mysql-server/sql/types"
    10  )
    11  
    12  type ExecBuilder struct{}
    13  
    14  func NewExecBuilder() *ExecBuilder {
    15  	return &ExecBuilder{}
    16  }
    17  
    18  func (b *ExecBuilder) buildRel(r RelExpr, children ...sql.Node) (sql.Node, error) {
    19  	n, err := buildRelExpr(b, r, children...)
    20  	if err != nil {
    21  		return nil, err
    22  	}
    23  
    24  	return b.buildDistinct(n, r.Distinct())
    25  }
    26  
    27  func (b *ExecBuilder) buildInnerJoin(j *InnerJoin, children ...sql.Node) (sql.Node, error) {
    28  	if len(j.Filter) == 0 {
    29  		return plan.NewCrossJoin(children[0], children[1]), nil
    30  	}
    31  	filters := b.buildFilterConjunction(j.Filter...)
    32  
    33  	return plan.NewInnerJoin(children[0], children[1], filters), nil
    34  }
    35  
    36  func (b *ExecBuilder) buildCrossJoin(j *CrossJoin, children ...sql.Node) (sql.Node, error) {
    37  	return plan.NewCrossJoin(children[0], children[1]), nil
    38  }
    39  
    40  func (b *ExecBuilder) buildLeftJoin(j *LeftJoin, children ...sql.Node) (sql.Node, error) {
    41  	filters := b.buildFilterConjunction(j.Filter...)
    42  	return plan.NewLeftOuterJoin(children[0], children[1], filters), nil
    43  }
    44  
    45  func (b *ExecBuilder) buildFullOuterJoin(j *FullOuterJoin, children ...sql.Node) (sql.Node, error) {
    46  	filters := b.buildFilterConjunction(j.Filter...)
    47  	return plan.NewFullOuterJoin(children[0], children[1], filters), nil
    48  }
    49  
    50  func (b *ExecBuilder) buildSemiJoin(j *SemiJoin, children ...sql.Node) (sql.Node, error) {
    51  	filters := b.buildFilterConjunction(j.Filter...)
    52  	left := children[0]
    53  	return plan.NewJoin(left, children[1], j.Op, filters), nil
    54  }
    55  
    56  func (b *ExecBuilder) buildAntiJoin(j *AntiJoin, children ...sql.Node) (sql.Node, error) {
    57  	filters := b.buildFilterConjunction(j.Filter...)
    58  	return plan.NewJoin(children[0], children[1], j.Op, filters), nil
    59  }
    60  
    61  func (b *ExecBuilder) buildLookupJoin(j *LookupJoin, children ...sql.Node) (sql.Node, error) {
    62  	left := children[0]
    63  	right, err := b.buildIndexScan(j.Lookup, children[1])
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	filters := b.buildFilterConjunction(j.Filter...)
    68  	return plan.NewJoin(left, right, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
    69  }
    70  
    71  func (b *ExecBuilder) buildRangeHeap(sr *RangeHeap, children ...sql.Node) (ret sql.Node, err error) {
    72  	switch n := children[0].(type) {
    73  	case *plan.Distinct:
    74  		ret, err = b.buildRangeHeap(sr, n.Child)
    75  		ret = plan.NewDistinct(ret)
    76  	case *plan.OrderedDistinct:
    77  		ret, err = b.buildRangeHeap(sr, n.Child)
    78  		ret = plan.NewOrderedDistinct(ret)
    79  	case *plan.Filter:
    80  		ret, err = b.buildRangeHeap(sr, n.Child)
    81  		ret = plan.NewFilter(n.Expression, ret)
    82  	case *plan.Project:
    83  		ret, err = b.buildRangeHeap(sr, n.Child)
    84  		ret = plan.NewProject(n.Projections, ret)
    85  	case *plan.Limit:
    86  		ret, err = b.buildRangeHeap(sr, n.Child)
    87  		ret = plan.NewLimit(n.Limit, ret)
    88  	case *plan.Sort:
    89  		ret, err = b.buildRangeHeap(sr, n.Child)
    90  		ret = plan.NewSort(n.SortFields, ret)
    91  	default:
    92  		var childNode sql.Node
    93  		if sr.MinIndex != nil {
    94  			childNode, err = b.buildIndexScan(sr.MinIndex, children[0])
    95  		} else {
    96  			sortExpr := sr.MinExpr
    97  			if err != nil {
    98  				return nil, err
    99  			}
   100  			sf := []sql.SortField{{
   101  				Column:       sortExpr,
   102  				Order:        sql.Ascending,
   103  				NullOrdering: sql.NullsFirst,
   104  			}}
   105  			childNode = plan.NewSort(sf, n)
   106  		}
   107  
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		ret, err = plan.NewRangeHeap(
   112  			childNode,
   113  			sr.ValueCol,
   114  			sr.MinColRef,
   115  			sr.MaxColRef,
   116  			sr.RangeClosedOnLowerBound,
   117  			sr.RangeClosedOnUpperBound)
   118  	}
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	return ret, nil
   123  }
   124  
   125  func (b *ExecBuilder) buildRangeHeapJoin(j *RangeHeapJoin, children ...sql.Node) (sql.Node, error) {
   126  	var left sql.Node
   127  	var err error
   128  	if j.RangeHeap.ValueIndex != nil {
   129  		left, err = b.buildIndexScan(j.RangeHeap.ValueIndex)
   130  		if err != nil {
   131  			return nil, err
   132  		}
   133  	} else {
   134  		sortExpr := j.RangeHeap.ValueExpr
   135  		sf := []sql.SortField{{
   136  			Column:       sortExpr,
   137  			Order:        sql.Ascending,
   138  			NullOrdering: sql.NullsFirst,
   139  		}}
   140  		left = plan.NewSort(sf, children[0])
   141  	}
   142  
   143  	right, err := b.buildRangeHeap(j.RangeHeap, children[1])
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	filters := b.buildFilterConjunction(j.Filter...)
   148  	return plan.NewJoin(left, right, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
   149  }
   150  
   151  func (b *ExecBuilder) buildConcatJoin(j *ConcatJoin, children ...sql.Node) (sql.Node, error) {
   152  	var alias string
   153  	var name string
   154  	rightC := children[1]
   155  	switch n := rightC.(type) {
   156  	case *plan.TableAlias:
   157  		alias = n.Name()
   158  		name = n.Child.(sql.Nameable).Name()
   159  		rightC = n.Child
   160  	case *plan.ResolvedTable:
   161  		name = n.Name()
   162  	}
   163  
   164  	right, err := b.buildIndexScan(j.Concat[0], children[1])
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	for _, look := range j.Concat[1:] {
   169  		l, err := b.buildIndexScan(look, children[1])
   170  		if err != nil {
   171  			return nil, err
   172  		}
   173  		right = plan.NewTransformedNamedNode(plan.NewConcat(l, right), name)
   174  	}
   175  
   176  	if alias != "" {
   177  		// restore alias
   178  		right = plan.NewTableAlias(alias, right)
   179  	}
   180  
   181  	filters := b.buildFilterConjunction(j.Filter...)
   182  
   183  	return plan.NewJoin(children[0], right, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
   184  }
   185  
   186  func (b *ExecBuilder) buildHashJoin(j *HashJoin, children ...sql.Node) (sql.Node, error) {
   187  	leftProbeFilters := make([]sql.Expression, len(j.LeftAttrs))
   188  	for i := range j.LeftAttrs {
   189  		leftProbeFilters[i] = j.LeftAttrs[i]
   190  	}
   191  	leftProbeKey := expression.Tuple(leftProbeFilters)
   192  
   193  	tmpScope := j.g.m.scope
   194  	if tmpScope != nil {
   195  		tmpScope = tmpScope.NewScopeNoJoin()
   196  	}
   197  
   198  	rightEntryFilters := make([]sql.Expression, len(j.RightAttrs))
   199  	for i := range j.RightAttrs {
   200  		rightEntryFilters[i] = j.RightAttrs[i]
   201  	}
   202  	rightEntryKey := expression.Tuple(rightEntryFilters)
   203  
   204  	filters := b.buildFilterConjunction(j.Filter...)
   205  
   206  	outer := plan.NewHashLookup(children[1], rightEntryKey, leftProbeKey, j.Op)
   207  	inner := children[0]
   208  	return plan.NewJoin(inner, outer, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
   209  }
   210  
   211  func (b *ExecBuilder) buildIndexScan(i *IndexScan, children ...sql.Node) (sql.Node, error) {
   212  	// need keyExprs for whole range for every dimension
   213  
   214  	if len(children) == 0 {
   215  		if i.Alias != "" {
   216  			return plan.NewTableAlias(i.Alias, i.Table), nil
   217  		}
   218  		return i.Table, nil
   219  	}
   220  	var ret sql.Node
   221  	var err error
   222  	switch n := children[0].(type) {
   223  	case sql.TableNode:
   224  		if i.Alias != "" {
   225  			ret = plan.NewTableAlias(i.Alias, i.Table)
   226  		} else {
   227  			ret = i.Table
   228  		}
   229  	case *plan.TableAlias:
   230  		ret = plan.NewTableAlias(n.Name(), i.Table)
   231  	case *plan.IndexedTableAccess:
   232  		ret = i.Table
   233  	case *plan.Distinct:
   234  		ret, err = b.buildIndexScan(i, n.Child)
   235  		ret = plan.NewDistinct(ret)
   236  	case *plan.OrderedDistinct:
   237  		ret, err = b.buildIndexScan(i, n.Child)
   238  		ret = plan.NewOrderedDistinct(ret)
   239  	case *plan.Project:
   240  		ret, err = b.buildIndexScan(i, n.Child)
   241  		ret = plan.NewProject(n.Projections, ret)
   242  	case *plan.Filter:
   243  		ret, err = b.buildIndexScan(i, n.Child)
   244  		ret = plan.NewFilter(n.Expression, ret)
   245  	case *plan.Limit:
   246  		ret, err = b.buildIndexScan(i, n.Child)
   247  		ret = plan.NewLimit(n.Limit, ret)
   248  	case *plan.Sort:
   249  		ret, err = b.buildIndexScan(i, n.Child)
   250  		ret = plan.NewSort(n.SortFields, ret)
   251  	default:
   252  		return nil, fmt.Errorf("unexpected *indexScan child: %T", n)
   253  	}
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  	return ret, nil
   258  }
   259  
   260  func checkIndexTypeMismatch(idx sql.Index, rang sql.Range) bool {
   261  	for i, typ := range idx.ColumnExpressionTypes() {
   262  		if !types.Null.Equals(rang[i].Typ) && !typ.Type.Equals(rang[i].Typ) {
   263  			return true
   264  		}
   265  	}
   266  	return false
   267  }
   268  
   269  func (b *ExecBuilder) buildMergeJoin(j *MergeJoin, children ...sql.Node) (sql.Node, error) {
   270  	inner, err := b.buildIndexScan(j.InnerScan, children[0])
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	outer, err := b.buildIndexScan(j.OuterScan, children[1])
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	if j.SwapCmp {
   280  		switch cmp := j.Filter[0].(type) {
   281  		case *expression.Equals:
   282  			j.Filter[0] = expression.NewEquals(cmp.Right(), cmp.Left())
   283  		case *expression.LessThan:
   284  			j.Filter[0] = expression.NewGreaterThan(cmp.Right(), cmp.Left())
   285  		case *expression.LessThanOrEqual:
   286  			j.Filter[0] = expression.NewGreaterThanOrEqual(cmp.Right(), cmp.Left())
   287  		default:
   288  			return nil, fmt.Errorf("unexpected non-comparison condition in merge join, %T", cmp)
   289  		}
   290  	}
   291  	filters := b.buildFilterConjunction(j.Filter...)
   292  	return plan.NewJoin(inner, outer, j.Op, filters).WithScopeLen(j.g.m.scopeLen), nil
   293  }
   294  
   295  func (b *ExecBuilder) buildLateralJoin(j *LateralJoin, children ...sql.Node) (sql.Node, error) {
   296  	if len(j.Filter) == 0 {
   297  		return plan.NewCrossJoin(children[0], children[1]), nil
   298  	}
   299  	filters := b.buildFilterConjunction(j.Filter...)
   300  	return plan.NewJoin(children[0], children[1], j.Op.AsLateral(), filters), nil
   301  }
   302  
   303  func (b *ExecBuilder) buildSubqueryAlias(r *SubqueryAlias, children ...sql.Node) (sql.Node, error) {
   304  	return r.Table, nil
   305  }
   306  
   307  func (b *ExecBuilder) buildMax1Row(r *Max1Row, children ...sql.Node) (sql.Node, error) {
   308  	return plan.NewMax1Row(children[0], ""), nil
   309  }
   310  
   311  func (b *ExecBuilder) buildTableFunc(r *TableFunc, children ...sql.Node) (sql.Node, error) {
   312  	return r.Table, nil
   313  }
   314  
   315  func (b *ExecBuilder) buildRecursiveCte(r *RecursiveCte, children ...sql.Node) (sql.Node, error) {
   316  	return r.Table, nil
   317  }
   318  
   319  func (b *ExecBuilder) buildValues(r *Values, _ ...sql.Node) (sql.Node, error) {
   320  	return r.Table, nil
   321  }
   322  
   323  func (b *ExecBuilder) buildRecursiveTable(r *RecursiveTable, _ ...sql.Node) (sql.Node, error) {
   324  	return r.Table, nil
   325  }
   326  
   327  func (b *ExecBuilder) buildJSONTable(n *JSONTable, _ ...sql.Node) (sql.Node, error) {
   328  	return n.Table, nil
   329  }
   330  
   331  func (b *ExecBuilder) buildTableAlias(r *TableAlias, _ ...sql.Node) (sql.Node, error) {
   332  	return r.Table, nil
   333  }
   334  
   335  func (b *ExecBuilder) buildTableScan(r *TableScan, _ ...sql.Node) (sql.Node, error) {
   336  	return r.Table, nil
   337  }
   338  
   339  func (b *ExecBuilder) buildEmptyTable(r *EmptyTable, _ ...sql.Node) (sql.Node, error) {
   340  	return r.Table, nil
   341  }
   342  
   343  func (b *ExecBuilder) buildSetOp(r *SetOp, _ ...sql.Node) (sql.Node, error) {
   344  	return r.Table, nil
   345  }
   346  
   347  func (b *ExecBuilder) buildProject(r *Project, children ...sql.Node) (sql.Node, error) {
   348  	proj := make([]sql.Expression, len(r.Projections))
   349  	for i := range r.Projections {
   350  		proj[i] = r.Projections[i]
   351  	}
   352  	return plan.NewProject(proj, children[0]), nil
   353  }
   354  
   355  func (b *ExecBuilder) buildFilter(r *Filter, children ...sql.Node) (sql.Node, error) {
   356  	ret := plan.NewFilter(expression.JoinAnd(r.Filters...), children[0])
   357  	return ret, nil
   358  }
   359  
   360  func (b *ExecBuilder) buildDistinct(n sql.Node, d distinctOp) (sql.Node, error) {
   361  	switch d {
   362  	case HashDistinctOp:
   363  		return plan.NewDistinct(n), nil
   364  	case SortedDistinctOp:
   365  		return plan.NewOrderedDistinct(n), nil
   366  	case NoDistinctOp:
   367  		return n, nil
   368  	default:
   369  		return nil, fmt.Errorf("unexpected distinct operator: %d", d)
   370  	}
   371  }
   372  
   373  func (b *ExecBuilder) buildFilterConjunction(filters ...sql.Expression) sql.Expression {
   374  	if len(filters) == 0 {
   375  		return expression.NewLiteral(true, types.Boolean)
   376  	}
   377  	return expression.JoinAnd(filters...)
   378  }