github.com/dolthub/go-mysql-server@v0.18.0/sql/memo/expr_group.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memo
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/plan"
    23  )
    24  
// ExprGroup is a linked list of plans that return the same result set
// defined by row count and schema.
//
// The list is singly linked: First is the head, and every plan reaches the
// following alternative through its Next method.
//
// Fields:
//   - m is the Memo handed to newExprGroup when the group was created.
//   - _children is not read or written in the visible part of this file.
//   - RelProps is built by newExprGroup from the group's first relational
//     expression; it is left nil when the group is created from an
//     expression that is not a RelExpr.
//   - First is the head of the plan list; Prepend pushes new plans in
//     front of it.
//   - Best and Cost hold the plan chosen by updateBest and the group cost
//     it was chosen with. fixTableScanPath can overwrite Best without
//     touching Cost.
//   - Id is the identifier handed to newExprGroup.
//   - Done is set by fixTableScanPath once it has pinned Best.
//   - HintOk is set by fixTableScanPath when it pins Best to a source
//     relation that is not an IndexScan.
type ExprGroup struct {
	m         *Memo
	_children []*ExprGroup
	RelProps  *relProps
	First     RelExpr
	Best      RelExpr

	Id GroupId

	Cost   float64
	Done   bool
	HintOk bool
}
    40  
    41  func newExprGroup(m *Memo, id GroupId, expr exprType) *ExprGroup {
    42  	// bit of circularity: |grp| references |rel|, |rel| references |grp|,
    43  	// and |relProps| references |rel| and |grp| info.
    44  	grp := &ExprGroup{
    45  		m:  m,
    46  		Id: id,
    47  	}
    48  	expr.SetGroup(grp)
    49  	switch e := expr.(type) {
    50  	case RelExpr:
    51  		grp.First = e
    52  		grp.RelProps = newRelProps(e)
    53  	}
    54  	return grp
    55  }
    56  
    57  // Prepend adds a new plan to an expression group at the beginning of
    58  // the list, to avoid recursive exploration steps (like adding indexed joins).
    59  func (e *ExprGroup) Prepend(rel RelExpr) {
    60  	first := e.First
    61  	e.First = rel
    62  	rel.SetNext(first)
    63  }
    64  
    65  // children returns a unioned list of child ExprGroup for
    66  // every logical plan in this group.
    67  func (e *ExprGroup) children() []*ExprGroup {
    68  	relExpr, ok := e.First.(RelExpr)
    69  	if !ok {
    70  		return e.children()
    71  	}
    72  	n := relExpr
    73  	children := make([]*ExprGroup, 0)
    74  	for n != nil {
    75  		children = append(children, n.Children()...)
    76  		n = n.Next()
    77  	}
    78  	return children
    79  }
    80  
    81  // updateBest updates a group's Best to the given expression or a hinted
    82  // operator if the hinted plan is not found. Join operator is applied as
    83  // a local rather than global property.
    84  func (e *ExprGroup) updateBest(n RelExpr, grpCost float64) {
    85  	if e.Best == nil || grpCost < e.Cost {
    86  		e.Best = n
    87  		e.Cost = grpCost
    88  	}
    89  }
    90  
    91  func (e *ExprGroup) finalize(node sql.Node) (sql.Node, error) {
    92  	props := e.RelProps
    93  	var result = node
    94  	if props.sort != nil {
    95  		result = plan.NewSort(props.sort, result)
    96  	}
    97  	if props.Limit != nil {
    98  		result = plan.NewLimit(props.Limit, result)
    99  	}
   100  	return result, nil
   101  }
   102  
   103  // fixConflicts edits the children of a new best plan to account
   104  // for implementation correctness, like conflicting table lookups
   105  // and sorting. For example, a merge join with a filter child that
   106  // could alternatively be implemented as an indexScan should reject
   107  // the static indexScan to maintain the merge join's correctness.
   108  func (e *ExprGroup) fixConflicts() {
   109  	switch n := e.Best.(type) {
   110  	case *MergeJoin:
   111  		// todo: we should permit conflicting static indexScans with same index IDs
   112  		n.Left.findIndexScanConflict()
   113  		n.Right.findIndexScanConflict()
   114  	case *LookupJoin:
   115  		// LOOKUP_JOIN is more performant than INNER_JOIN with static indexScan
   116  		n.Right.findIndexScanConflict()
   117  	}
   118  
   119  	for _, g := range e.Best.Children() {
   120  		g.fixConflicts()
   121  	}
   122  }
   123  
   124  // findIndexScanConflict prevents indexScans from replacing filter nodes
   125  // for certain query plans that require different indexes or use indexes
   126  // in a special way.
   127  func (e *ExprGroup) findIndexScanConflict() {
   128  	e.fixTableScanPath()
   129  }
   130  
   131  // fixTableScanPath updates the intermediate group's |best| plan to
   132  // the path leading to a tableScan leaf.
   133  func (e *ExprGroup) fixTableScanPath() bool {
   134  	n := e.First
   135  	for n != nil {
   136  		src, ok := n.(SourceRel)
   137  		if !ok {
   138  			// not a source, try to find path through children
   139  			for _, c := range n.Children() {
   140  				if c.fixTableScanPath() {
   141  					// found path, update best
   142  					e.Best = n
   143  					n.SetDistinct(NoDistinctOp)
   144  					e.Done = true
   145  					return true
   146  				}
   147  			}
   148  			n = n.Next()
   149  			continue
   150  		}
   151  		_, ok = src.(*IndexScan)
   152  		if ok {
   153  			n = n.Next()
   154  			continue
   155  		}
   156  		// is a source, not an indexScan
   157  		n.SetDistinct(NoDistinctOp)
   158  		e.Best = n
   159  		e.HintOk = true
   160  		e.Done = true
   161  		return true
   162  	}
   163  	return false
   164  }
   165  
   166  func (e *ExprGroup) String() string {
   167  	b := strings.Builder{}
   168  	n := e.First
   169  	sep := ""
   170  	for n != nil {
   171  		b.WriteString(sep)
   172  		b.WriteString(fmt.Sprintf("(%s", FormatExpr(n)))
   173  		if e.Best != nil {
   174  			cost := n.Cost()
   175  			if cost == 0 {
   176  				// if source relation we want the cardinality
   177  				cost = float64(n.Group().RelProps.GetStats().RowCount())
   178  			}
   179  			b.WriteString(fmt.Sprintf(" %.1f", n.Cost()))
   180  
   181  			childCost := 0.0
   182  			for _, c := range n.Children() {
   183  				childCost += c.Cost
   184  			}
   185  			if e.Cost == n.Cost()+childCost {
   186  				b.WriteString(")*")
   187  			} else {
   188  				b.WriteString(")")
   189  			}
   190  		} else {
   191  			b.WriteString(")")
   192  		}
   193  		sep = " "
   194  		n = n.Next()
   195  	}
   196  	return b.String()
   197  }