
     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package memo
    17  import (
    18  	"fmt"
    19  	"math"
    21  	""
    22  	""
    24  	""
    25  )
    27  const (
    28  	// reference
    29  	cpuCostFactor             = 0.01
    30  	seqIOCostFactor           = 1
    31  	randIOCostFactor          = 1.3
    32  	memCostFactor             = 2
    33  	concatCostFactor          = 0.75
    34  	degeneratePenalty         = 2.0
    35  	optimisticJoinSel         = .10
    36  	biasFactor                = 1e5
    37  	defaultFilterSelectivity  = .75
    38  	perKeyCostReductionFactor = 0.5
    39  	defaultTableSize          = 100
    40  )
    42  func NewDefaultCoster() Coster {
    43  	return &coster{}
    44  }
    46  type coster struct{}
    48  var _ Coster = (*coster)(nil)
    50  func (c *coster) EstimateCost(ctx *sql.Context, n RelExpr, s sql.StatsProvider) (float64, error) {
    51  	return c.costRel(ctx, n, s)
    52  }
// costRel returns the estimated compute cost for a given physical
// operator. Two physical operators in the same expression group will have
// the same input and output cardinalities, but different evaluation costs.
//
// Costs are unitless weights built from the package-level cost factors
// (cpuCostFactor, seqIOCostFactor, randIOCostFactor, memCostFactor, ...)
// multiplied by row-count estimates taken from expression-group stats.
//
// Supported inputs are *Project, *Distinct, *Filter and any JoinRel. A
// JoinRel whose join type matches none of the cases below returns an
// "unhandled join type" error. This includes a lookup-type join that is
// neither *LookupJoin nor *ConcatJoin. Any other RelExpr type panics.
//
// s is consulted only to fetch a base table's row count when a join child's
// best plan is an *IndexScan. If that lookup fails, defaultTableSize is used.
func (c *coster) costRel(ctx *sql.Context, n RelExpr, s sql.StatsProvider) (float64, error) {
	switch n := n.(type) {
	case *Project:
		// CPU-only: one cpuCostFactor per input row.
		return float64(n.Child.RelProps.GetStats().RowCount()) * cpuCostFactor, nil
	case *Distinct:
		// Per input row: CPU plus 0.75 of a memory unit (presumably for
		// tracking rows already seen).
		return float64(n.Child.RelProps.GetStats().RowCount()) * (cpuCostFactor + .75*memCostFactor), nil
	case *Filter:
		// CPU per input row per filter expression; zero filters cost nothing.
		return float64(n.Child.RelProps.GetStats().RowCount()) * cpuCostFactor * float64(len(n.Filters)), nil
	case JoinRel:
		jp := n.JoinPrivate()
		// Child cardinality estimates, floored at 1 so the multiplicative
		// formulas below never collapse to zero.
		lBest := math.Max(1, float64(jp.Left.RelProps.GetStats().RowCount()))
		rBest := math.Max(1, float64(jp.Right.RelProps.GetStats().RowCount()))

		// if a child is an index scan, the table scan will be more expensive
		// lTableScan/rTableScan default to the child estimates above. When a
		// child's best plan is an *IndexScan, they are replaced by the base
		// table's full row count from s (or defaultTableSize if that lookup
		// fails; the error is intentionally not propagated). Only the hash
		// and merge cases below read these values.
		var err error
		lTableScan := uint64(lBest)
		rTableScan := uint64(rBest)

		if iScan, ok := jp.Left.Best.(*IndexScan); ok {
			lTableScan, err = s.RowCount(ctx, iScan.Table.Database().Name(), iScan.Table.Name())
			if err != nil {
				lTableScan = defaultTableSize
			}
		}
		if iScan, ok := jp.Right.Best.(*IndexScan); ok {
			rTableScan, err = s.RowCount(ctx, iScan.Table.Database().Name(), iScan.Table.Name())
			if err != nil {
				rTableScan = defaultTableSize
			}
		}

		// Estimated output cardinality of this join group, floored at 1.
		selfJoinCard := math.Max(1, float64(n.Group().RelProps.GetStats().RowCount()))

		// Cases are evaluated in order; the first matching join-type
		// predicate determines the formula.
		switch {
		case jp.Op.IsInner():
			// Seq IO and CPU over every left x right pair.
			// arbitrary +1 penalty, prefer lookup
			return (lBest*rBest+1)*seqIOCostFactor + (lBest*rBest)*cpuCostFactor, nil
		case jp.Op.IsDegenerate():
			// Same per-pair cost as inner (without the +1), scaled by degeneratePenalty.
			return ((lBest*rBest)*seqIOCostFactor + (lBest*rBest)*cpuCostFactor) * degeneratePenalty, nil
		case jp.Op.IsHash():
			// TODO hash has to load whole table into memory, really bad for big right sides
			if jp.Op.IsPartial() {
				// Partial (presumably semi/anti) hash join: IO+CPU over
				// left x half of right, then halved again.
				cost := lBest * (rBest / 2.0) * (seqIOCostFactor + cpuCostFactor)
				return cost * .5, nil
			}
			// Scan left once. Read the right side (its full table count when it
			// is an index scan) with an extra memory charge. Add CPU per output row.
			return lBest*(seqIOCostFactor+cpuCostFactor) + float64(rTableScan)*(seqIOCostFactor+memCostFactor) + selfJoinCard*cpuCostFactor, nil

		case jp.Op.IsLateral():
			// Per-pair cost like inner, but -1 instead of the +1 penalty.
			return (lBest*rBest-1)*seqIOCostFactor + (lBest*rBest)*cpuCostFactor, nil

		case jp.Op.IsMerge():
			// TODO memory overhead when not injective
			// TODO lose index scan benefits, need to read whole table
			// NOTE: assumes every merge-type JoinRel is a *MergeJoin; the
			// unchecked type assertions below panic otherwise.

			if !n.(*MergeJoin).Injective {
				// Injective is guaranteed to never iterate over multiple rows in memory.
				// Otherwise O(k) where k is the key with the highest number of matches.
				// Each comparison reduces the expected number of collisions on the comparator.
				// The penalty is max(0, 4-CmpCnt): fewer comparator columns add more.
				// TODO: better cost estimate for memory overhead
				mergeCmtAdjustment := math.Max(0, 4-float64(n.(*MergeJoin).CmpCnt))
				selfJoinCard += mergeCmtAdjustment
			}

			// cost is full left scan + full rightScan plus compute/memory overhead
			// for this merge filter's cardinality
			// TODO: estimate memory overhead
			return float64(lTableScan+rTableScan)*(seqIOCostFactor+cpuCostFactor) + cpuCostFactor*selfJoinCard, nil
		case jp.Op.IsLookup():
			// TODO added overhead for right lookups
			switch n := n.(type) {
			case *LookupJoin:
				if !n.Injective {
					// partial index completion is undesirable
					// TODO don't do this when we have stats
					// selfJoinCard is shifted by indexCoverageAdjustment and clamped at 0.
					selfJoinCard = math.Max(0, selfJoinCard+float64(indexCoverageAdjustment(n.Lookup)))
				}

				// read the whole left table and randIO into table equivalent to
				// this join's output cardinality estimate
				return lBest*seqIOCostFactor + selfJoinCard*(randIOCostFactor+seqIOCostFactor), nil
			case *ConcatJoin:
				return c.costConcatJoin(ctx, n, s)
			}
			// Any other lookup-type node falls out here and reaches the error below.
		case jp.Op.IsRange():
			// Each left row is assumed to overlap perKeyCostReductionFactor of the right rows.
			expectedNumberOfOverlappingJoins := rBest * perKeyCostReductionFactor
			return lBest * expectedNumberOfOverlappingJoins * (seqIOCostFactor), nil
		case jp.Op.IsPartial():
			// Scan left, plus IO+CPU over left x half of right.
			return lBest*seqIOCostFactor + lBest*(rBest/2.0)*(seqIOCostFactor+cpuCostFactor), nil
		case jp.Op.IsFullOuter():
			// Left-outer formula scaled by degeneratePenalty.
			return ((lBest*rBest-1)*seqIOCostFactor + (lBest*rBest)*cpuCostFactor) * degeneratePenalty, nil
		case jp.Op.IsLeftOuter():
			return (lBest*rBest-1)*seqIOCostFactor + (lBest*rBest)*cpuCostFactor, nil
		default:
			// Unmatched join types fall through to the error below.
		}
		return 0, fmt.Errorf("unhandled join type: %T (%s)", n, jp.Op)
	default:
		// Non-join, non-project/distinct/filter expressions are a programmer
		// error here; this panics rather than returning an error.
		panic(fmt.Sprintf("coster does not support type: %T", n))
	}
}
   157  // isInjectiveMerge determines whether either of a merge join's child indexes returns only unique values for the merge
   158  // comparator.
   159  func isInjectiveMerge(n *MergeJoin, leftCompareExprs, rightCompareExprs []sql.Expression) bool {
   160  	{
   161  		keyExprs, nullmask := keyExprsForIndexFromTupleComparison(n.Left.RelProps.tableNodes[0].Id(), n.InnerScan.Index.Cols(), leftCompareExprs, rightCompareExprs)
   162  		if isInjectiveLookup(n.InnerScan.Index, n.JoinBase, keyExprs, nullmask) {
   163  			return true
   164  		}
   165  	}
   166  	{
   167  		keyExprs, nullmask := keyExprsForIndexFromTupleComparison(n.Right.RelProps.tableNodes[0].Id(), n.OuterScan.Index.Cols(), leftCompareExprs, rightCompareExprs)
   168  		if isInjectiveLookup(n.OuterScan.Index, n.JoinBase, keyExprs, nullmask) {
   169  			return true
   170  		}
   171  	}
   172  	return false
   173  }
   175  func keyExprsForIndexFromTupleComparison(tabId sql.TableId, idxExprs []sql.ColumnId, leftExprs []sql.Expression, rightExprs []sql.Expression) ([]sql.Expression, []bool) {
   176  	var keyExprs []sql.Expression
   177  	var nullmask []bool
   178  	for _, col := range idxExprs {
   179  		key, nullable := keyForExprFromTupleComparison(col, tabId, leftExprs, rightExprs)
   180  		if key == nil {
   181  			break
   182  		}
   183  		keyExprs = append(keyExprs, key)
   184  		nullmask = append(nullmask, nullable)
   185  	}
   186  	if len(keyExprs) == 0 {
   187  		return nil, nil
   188  	}
   189  	return keyExprs, nullmask
   190  }
   192  // keyForExpr returns an equivalence or constant value to satisfy the
   193  // lookup index expression.
   194  func keyForExprFromTupleComparison(targetCol sql.ColumnId, tabId sql.TableId, leftExprs []sql.Expression, rightExprs []sql.Expression) (sql.Expression, bool) {
   195  	for i, leftExpr := range leftExprs {
   196  		rightExpr := rightExprs[i]
   198  		var key sql.Expression
   199  		if ref, ok := leftExpr.(*expression.GetField); ok && ref.Id() == targetCol {
   200  			key = rightExpr
   201  		} else if ref, ok := rightExpr.(*expression.GetField); ok && ref.Id() == targetCol {
   202  			key = leftExpr
   203  		} else {
   204  			continue
   205  		}
   206  		// expression key can be arbitrarily complex (or simple), but cannot
   207  		// reference the lookup table
   208  		if !exprReferencesTable(key, tabId) {
   209  			return key, false
   210  		}
   212  	}
   213  	return nil, false
   214  }
   216  // TODO need a way to map memo groups to table ids (or names if this doesn't work)
   217  func exprReferencesTable(e sql.Expression, tabId sql.TableId) bool {
   218  	return transform.InspectExpr(e, func(e sql.Expression) bool {
   219  		gf, _ := e.(*expression.GetField)
   220  		if gf != nil && gf.TableId() == tabId {
   221  			return true
   222  		}
   223  		return false
   224  	})
   225  }
   227  func (c *coster) costConcatJoin(_ *sql.Context, n *ConcatJoin, _ sql.StatsProvider) (float64, error) {
   228  	l := float64(n.Left.RelProps.GetStats().RowCount())
   229  	var sel float64
   230  	for _, l := range n.Concat {
   231  		lookup := l
   232  		sel += lookupJoinSelectivity(lookup, n.JoinBase)
   233  	}
   234  	return l*sel*concatCostFactor*(randIOCostFactor+cpuCostFactor) - float64(n.Right.RelProps.GetStats().RowCount())*seqIOCostFactor, nil
   235  }
   237  // lookupJoinSelectivity estimates the selectivity of a join condition with n lhs rows and m rhs rows.
   238  // A join with a selectivity of k will return k*(n*m) rows.
   239  // Special case: A join with a selectivity of 0 will return n rows.
   240  func lookupJoinSelectivity(l *IndexScan, joinBase *JoinBase) float64 {
   241  	if isInjectiveLookup(l.Index, joinBase, l.Table.Expressions(), l.Table.NullMask()) {
   242  		return 0
   243  	}
   244  	return math.Pow(perKeyCostReductionFactor, float64(len(l.Table.Expressions()))) * optimisticJoinSel
   245  }
   247  // isInjectiveLookup returns whether every lookup with the given key expressions is guarenteed to return
   248  // at most one row.
   249  func isInjectiveLookup(idx *Index, joinBase *JoinBase, keyExprs []sql.Expression, nullMask []bool) bool {
   250  	if !idx.SqlIdx().IsUnique() {
   251  		return false
   252  	}
   254  	joinFds := joinBase.Group().RelProps.FuncDeps()
   256  	var notNull sql.ColSet
   257  	var constCols sql.ColSet
   258  	for i, nullable := range nullMask {
   259  		cols, _, nullRej := getExprScalarProps(keyExprs[i])
   260  		onCols := joinFds.EquivalenceClosure(cols)
   261  		if !nullable {
   262  			if nullRej {
   263  				// columns with nulls will be filtered out
   264  				// TODO double-checking nullRejecting might be redundant
   265  				notNull = notNull.Union(onCols)
   266  			}
   267  		}
   268  		// from the perspective of the secondary table, lookup keys
   269  		// will be constant
   270  		constCols = constCols.Union(onCols)
   271  	}
   273  	fds := sql.NewLookupFDs(joinBase.Right.RelProps.FuncDeps(), idx.ColSet(), notNull, constCols, joinFds.Equiv())
   274  	return fds.HasMax1Row()
   275  }
   277  func NewInnerBiasedCoster() Coster {
   278  	return &innerBiasedCoster{coster: &coster{}}
   279  }
   281  type innerBiasedCoster struct {
   282  	*coster
   283  }
   285  func (c *innerBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   286  	switch r.(type) {
   287  	case *InnerJoin:
   288  		return -biasFactor, nil
   289  	default:
   290  		return c.costRel(ctx, r, s)
   291  	}
   292  }
   294  func NewHashBiasedCoster() Coster {
   295  	return &hashBiasedCoster{coster: &coster{}}
   296  }
   298  type hashBiasedCoster struct {
   299  	*coster
   300  }
   302  func (c *hashBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   303  	switch r.(type) {
   304  	case *HashJoin:
   305  		return -biasFactor, nil
   306  	default:
   307  		return c.costRel(ctx, r, s)
   308  	}
   309  }
   311  func NewLookupBiasedCoster() Coster {
   312  	return &lookupBiasedCoster{coster: &coster{}}
   313  }
   315  type lookupBiasedCoster struct {
   316  	*coster
   317  }
   319  func (c *lookupBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   320  	switch r.(type) {
   321  	case *LookupJoin, *ConcatJoin:
   322  		return -biasFactor, nil
   323  	default:
   324  		return c.costRel(ctx, r, s)
   325  	}
   326  }
   328  func NewMergeBiasedCoster() Coster {
   329  	return &mergeBiasedCoster{coster: &coster{}}
   330  }
   332  type mergeBiasedCoster struct {
   333  	*coster
   334  }
   336  func (c *mergeBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   337  	switch r.(type) {
   338  	case *MergeJoin:
   339  		return -biasFactor, nil
   340  	default:
   341  		return c.costRel(ctx, r, s)
   342  	}
   343  }
   345  type partialBiasedCoster struct {
   346  	*coster
   347  }
   349  func NewPartialBiasedCoster() Coster {
   350  	return &partialBiasedCoster{coster: &coster{}}
   351  }
   353  func (c *partialBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   354  	switch r.(type) {
   355  	case *AntiJoin, *SemiJoin:
   356  		return -biasFactor, nil
   357  	default:
   358  		return c.costRel(ctx, r, s)
   359  	}
   360  }
   362  type rangeHeapBiasedCoster struct {
   363  	*coster
   364  }
   366  func NewRangeHeapBiasedCoster() Coster {
   367  	return &rangeHeapBiasedCoster{coster: &coster{}}
   368  }
   370  func (c *rangeHeapBiasedCoster) EstimateCost(ctx *sql.Context, r RelExpr, s sql.StatsProvider) (float64, error) {
   371  	switch r.(type) {
   372  	case *RangeHeapJoin:
   373  		return -biasFactor, nil
   374  	default:
   375  		return c.costRel(ctx, r, s)
   376  	}
   377  }