github.com/dolthub/go-mysql-server@v0.18.0/sql/stats/join.go (about)

     1  package stats
     2  
     3  import (
     4  	"container/heap"
     5  	"fmt"
     6  	"log"
     7  	"math"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  
    12  	"github.com/dolthub/go-mysql-server/sql"
    13  	"github.com/dolthub/go-mysql-server/sql/types"
    14  )
    15  
// ErrJoinStringStatistics is returned by AlignBuckets when one of the left
// side's prefix types is not a sql.NumberType.
var ErrJoinStringStatistics = errors.New("joining string histograms is unsupported")
    17  
    18  // Join performs an alignment algorithm on two sets of statistics, and
    19  // then pairwise estimates bucket cardinalities by joining most common
    20  // values (mcvs) directly and assuming key uniformity otherwise. Only
    21  // numeric types are supported.
    22  func Join(s1, s2 sql.Statistic, prefixCnt int, debug bool) (sql.Statistic, error) {
    23  	cmp := func(row1, row2 sql.Row) (int, error) {
    24  		var keyCmp int
    25  		for i := 0; i < prefixCnt; i++ {
    26  			k1, _, err := s1.Types()[i].Promote().Convert(row1[i])
    27  			if err != nil {
    28  				return 0, fmt.Errorf("incompatible types")
    29  			}
    30  
    31  			k2, _, err := s2.Types()[i].Promote().Convert(row2[i])
    32  			if err != nil {
    33  				return 0, fmt.Errorf("incompatible types")
    34  			}
    35  
    36  			cmp, err := s1.Types()[i].Promote().Compare(k1, k2)
    37  			if err != nil {
    38  				return 0, err
    39  			}
    40  			if cmp == 0 {
    41  				continue
    42  			}
    43  			keyCmp = cmp
    44  			break
    45  		}
    46  		return keyCmp, nil
    47  	}
    48  
    49  	s1Buckets, err := mergeOverlappingBuckets(s1.Histogram(), s1.Types())
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	s2Buckets, err := mergeOverlappingBuckets(s2.Histogram(), s2.Types())
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	s1AliHist, s2AliHist, err := AlignBuckets(s1Buckets, s2Buckets, s1.LowerBound(), s2.LowerBound(), s1.Types()[:prefixCnt], s2.Types()[:prefixCnt], cmp)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if debug {
    63  		log.Println("left", s1AliHist.DebugString())
    64  		log.Println("right", s2AliHist.DebugString())
    65  	}
    66  
    67  	newHist, err := joinAlignedStats(s1AliHist, s2AliHist, cmp)
    68  	ret := NewStatistic(0, 0, 0, s1.AvgSize(), time.Now(), s1.Qualifier(), s1.Columns(), s1.Types(), newHist, s1.IndexClass(), nil)
    69  	return UpdateCounts(ret), nil
    70  }
    71  
    72  // joinAlignedStats assumes |left| and |right| have the same number of
    73  // buckets to estimate the join cardinality. Most common values (mcvs) adjust
    74  // the estimates to account for outlier keys that are a disproportionately
    75  // high fraction of the index.
    76  func joinAlignedStats(left, right sql.Histogram, cmp func(sql.Row, sql.Row) (int, error)) ([]*Bucket, error) {
    77  	var newBuckets []*Bucket
    78  	newCnt := uint64(0)
    79  	for i := range left {
    80  		l := left[i]
    81  		r := right[i]
    82  		lDistinct := float64(l.DistinctCount())
    83  		rDistinct := float64(r.DistinctCount())
    84  
    85  		lRows := float64(l.RowCount())
    86  		rRows := float64(r.RowCount())
    87  
    88  		var rows uint64
    89  
    90  		// mcvs counted in isolation
    91  		// todo: should we assume non-match MCVs in smaller set
    92  		// contribute MCV count * average frequency from the larger?
    93  		var mcvMatch int
    94  		for i, key1 := range l.Mcvs() {
    95  			for j, key2 := range r.Mcvs() {
    96  				v, err := cmp(key1, key2)
    97  				if err != nil {
    98  					return nil, err
    99  				}
   100  				if v == 0 {
   101  					rows += l.McvCounts()[i] * r.McvCounts()[j]
   102  					lRows -= float64(l.McvCounts()[i])
   103  					rRows -= float64(r.McvCounts()[j])
   104  					lDistinct--
   105  					rDistinct--
   106  					mcvMatch++
   107  					break
   108  				}
   109  			}
   110  		}
   111  
   112  		// true up negative approximations
   113  		lRows = math.Max(lRows, 0)
   114  		rRows = math.Max(rRows, 0)
   115  		lDistinct = math.Max(lDistinct, 0)
   116  		rDistinct = math.Max(rDistinct, 0)
   117  
   118  		// Selinger method on rest of buckets
   119  		maxDistinct := math.Max(lDistinct, rDistinct)
   120  		minDistinct := math.Min(lDistinct, rDistinct)
   121  
   122  		if maxDistinct > 0 {
   123  			rows += uint64(float64(lRows*rRows) / float64(maxDistinct))
   124  		}
   125  
   126  		newCnt += rows
   127  
   128  		// TODO: something smarter with MCVs
   129  		mcvs := append(l.Mcvs(), r.Mcvs()...)
   130  		mcvCounts := append(l.McvCounts(), r.McvCounts()...)
   131  
   132  		newBucket := NewHistogramBucket(
   133  			rows,
   134  			uint64(minDistinct)+uint64(mcvMatch), // matched mcvs contribute back to result distinct count
   135  			uint64(float64(l.NullCount()*r.NullCount())/float64(maxDistinct)),
   136  			l.BoundCount()*r.BoundCount(), l.UpperBound(), mcvCounts, mcvs)
   137  		newBuckets = append(newBuckets, newBucket)
   138  	}
   139  	return newBuckets, nil
   140  }
   141  
   142  // AlignBuckets produces two histograms with the same number of buckets.
   143  // Start by using upper bound keys to truncate histogram with a larger
   144  // keyspace. Then for every misaligned pair of buckets, cut the one with the
   145  // higher bound value on the smaller's key. We use a linear interpolation
   146  // to divide keys when splitting.
   147  func AlignBuckets(h1, h2 sql.Histogram, lBound1, lBound2 sql.Row, s1Types, s2Types []sql.Type, cmp func(sql.Row, sql.Row) (int, error)) (sql.Histogram, sql.Histogram, error) {
   148  	var numericTypes bool = true
   149  	for _, t := range s1Types {
   150  		if _, ok := t.(sql.NumberType); !ok {
   151  			numericTypes = false
   152  			break
   153  		}
   154  	}
   155  
   156  	if !numericTypes {
   157  		// todo(max): distance between two strings is difficult,
   158  		// but we could cut equal fractions depending on total
   159  		// cuts for a bucket
   160  		return nil, nil, ErrJoinStringStatistics
   161  	}
   162  
   163  	var leftRes sql.Histogram
   164  	var rightRes sql.Histogram
   165  	var leftStack []sql.HistogramBucket
   166  	var rightStack []sql.HistogramBucket
   167  	var nextL sql.HistogramBucket
   168  	var nextR sql.HistogramBucket
   169  	var keyCmp int
   170  	var err error
   171  	var reverse bool
   172  
   173  	swap := func() {
   174  		leftStack, rightStack = rightStack, leftStack
   175  		nextL, nextR = nextR, nextL
   176  		leftRes, rightRes = rightRes, leftRes
   177  		h1, h2 = h2, h1
   178  		reverse = !reverse
   179  	}
   180  
   181  	var state sjState = sjStateInit
   182  	for state != sjStateEOF {
   183  		switch state {
   184  		case sjStateInit:
   185  			// Merge adjacent overlapping buckets within each histogram.
   186  			// Truncate non-overlapping tail buckets between left and right.
   187  			// Reverse the buckets into stacks.
   188  
   189  			s1Hist, err := mergeOverlappingBuckets(h1, s1Types)
   190  			if err != nil {
   191  				return nil, nil, err
   192  			}
   193  			s2Hist, err := mergeOverlappingBuckets(h2, s2Types)
   194  			if err != nil {
   195  				return nil, nil, err
   196  			}
   197  
   198  			s1Last := s1Hist[len(s1Hist)-1].UpperBound()
   199  			s2Last := s2Hist[len(s2Hist)-1].UpperBound()
   200  			idx1, err := PrefixLtHist(s1Hist, s2Last, cmp)
   201  			if err != nil {
   202  				return nil, nil, err
   203  			}
   204  			idx2, err := PrefixLtHist(s2Hist, s1Last, cmp)
   205  			if err != nil {
   206  				return nil, nil, err
   207  			}
   208  			if idx1 < len(s1Hist) {
   209  				idx1++
   210  			}
   211  			if idx2 < len(s2Hist) {
   212  				idx2++
   213  			}
   214  			s1Hist = s1Hist[:idx1]
   215  			s2Hist = s2Hist[:idx2]
   216  
   217  			if lBound2 != nil {
   218  				idx, err := PrefixGteHist(s1Hist, lBound2, cmp)
   219  				if err != nil {
   220  					return nil, nil, err
   221  				}
   222  				s1Hist = s1Hist[idx:]
   223  			}
   224  			if lBound1 != nil {
   225  				idx, err := PrefixGteHist(s2Hist, lBound1, cmp)
   226  				if err != nil {
   227  					return nil, nil, err
   228  				}
   229  				s2Hist = s2Hist[idx:]
   230  			}
   231  
   232  			if len(s1Hist) == 0 || len(s2Hist) == 0 {
   233  				return nil, nil, nil
   234  			}
   235  
   236  			if len(s1Hist) == 0 || len(s2Hist) == 0 {
   237  				return nil, nil, nil
   238  			}
   239  
   240  			m := len(s1Hist) - 1
   241  			leftStack = make([]sql.HistogramBucket, m)
   242  			for i, b := range s1Hist {
   243  				if i == 0 {
   244  					nextL = b
   245  					continue
   246  				}
   247  				leftStack[m-i] = b
   248  			}
   249  
   250  			n := len(s2Hist) - 1
   251  			rightStack = make([]sql.HistogramBucket, n)
   252  			for i, b := range s2Hist {
   253  				if i == 0 {
   254  					nextR = b
   255  					continue
   256  				}
   257  				rightStack[n-i] = b
   258  			}
   259  
   260  			state = sjStateCmp
   261  
   262  		case sjStateCmp:
   263  			keyCmp, err = cmp(nextL.UpperBound(), nextR.UpperBound())
   264  			if err != nil {
   265  				return nil, nil, err
   266  			}
   267  			switch keyCmp {
   268  			case 0:
   269  				state = sjStateInc
   270  			case 1:
   271  				state = sjStateCutLeft
   272  			case -1:
   273  				state = sjStateCutRight
   274  			}
   275  
   276  		case sjStateCutLeft:
   277  			// default cuts left
   278  			state = sjStateCut
   279  
   280  		case sjStateCutRight:
   281  			// switch to make left the cut target
   282  			swap()
   283  			state = sjStateCut
   284  
   285  		case sjStateCut:
   286  			state = sjStateInc
   287  			// The left bucket is longer than the right bucket.
   288  			// In the default case, we will cut the left bucket on
   289  			// the right boundary, and put the right remainder back
   290  			// on the stack.
   291  
   292  			if len(leftRes) == 0 {
   293  				// It is difficult to cut the first bucket because the
   294  				// lower bound is negative infinity. We instead extend the
   295  				// smaller side (right) by stealing form its precedeccors
   296  				// up to the left cutpoint.
   297  
   298  				if len(rightStack) == 0 {
   299  					continue
   300  				}
   301  
   302  				var peekR sql.HistogramBucket
   303  				for len(rightStack) > 0 {
   304  					// several right buckets might be less than the left cutpoint
   305  					peekR = rightStack[len(rightStack)-1]
   306  					rightStack = rightStack[:len(rightStack)-1]
   307  					keyCmp, err = cmp(peekR.UpperBound(), nextL.UpperBound())
   308  					if err != nil {
   309  						return nil, nil, err
   310  					}
   311  					if keyCmp > 0 {
   312  						break
   313  					}
   314  
   315  					nextR = NewHistogramBucket(
   316  						uint64(float64(nextR.RowCount())+float64(peekR.RowCount())),
   317  						uint64(float64(nextR.DistinctCount())+float64(peekR.DistinctCount())),
   318  						uint64(float64(nextR.NullCount())+float64(peekR.NullCount())),
   319  						peekR.BoundCount(), peekR.UpperBound(), peekR.McvCounts(), peekR.Mcvs())
   320  				}
   321  
   322  				// nextR < nextL < peekR
   323  				bucketMagnitude, err := euclideanDistance(nextR.UpperBound(), peekR.UpperBound(), len(s1Types))
   324  				if err != nil {
   325  					return nil, nil, err
   326  				}
   327  
   328  				if bucketMagnitude == 0 {
   329  					peekR = nil
   330  					continue
   331  				}
   332  
   333  				// estimate midpoint
   334  				cutMagnitude, err := euclideanDistance(nextR.UpperBound(), nextL.UpperBound(), len(s1Types))
   335  				if err != nil {
   336  					return nil, nil, err
   337  				}
   338  
   339  				cutFrac := cutMagnitude / bucketMagnitude
   340  
   341  				// lastL -> nextR
   342  				firstHalf := NewHistogramBucket(
   343  					uint64(float64(nextR.RowCount())+float64(peekR.RowCount())*cutFrac),
   344  					uint64(float64(nextR.DistinctCount())+float64(peekR.DistinctCount())*cutFrac),
   345  					uint64(float64(nextR.NullCount())+float64(peekR.NullCount())*cutFrac),
   346  					1, nextL.UpperBound(), nil, nil)
   347  
   348  				// nextR -> nextL
   349  				secondHalf := NewHistogramBucket(
   350  					uint64(float64(peekR.RowCount())*(1-cutFrac)),
   351  					uint64(float64(peekR.DistinctCount())*(1-cutFrac)),
   352  					uint64(float64(peekR.NullCount())*(1-cutFrac)),
   353  					peekR.BoundCount(),
   354  					peekR.UpperBound(),
   355  					peekR.McvCounts(),
   356  					peekR.Mcvs())
   357  
   358  				nextR = firstHalf
   359  				rightStack = append(rightStack, secondHalf)
   360  				continue
   361  			}
   362  
   363  			// get left "distance"
   364  			bucketMagnitude, err := euclideanDistance(nextL.UpperBound(), leftRes[len(leftRes)-1].UpperBound(), len(s1Types))
   365  			if err != nil {
   366  				return nil, nil, err
   367  			}
   368  
   369  			// estimate midpoint
   370  			cutMagnitude, err := euclideanDistance(nextL.UpperBound(), nextR.UpperBound(), len(s1Types))
   371  			if err != nil {
   372  				return nil, nil, err
   373  			}
   374  
   375  			cutFrac := cutMagnitude / bucketMagnitude
   376  
   377  			// lastL -> nextR
   378  			firstHalf := NewHistogramBucket(
   379  				uint64(float64(nextL.RowCount())*(1-cutFrac)),
   380  				uint64(float64(nextL.DistinctCount())*(1-cutFrac)),
   381  				uint64(float64(nextL.NullCount())*(1-cutFrac)),
   382  				1, nextR.UpperBound(), nil, nil)
   383  
   384  			// nextR -> nextL
   385  			secondHalf := NewHistogramBucket(
   386  				uint64(float64(nextL.RowCount())*cutFrac),
   387  				uint64(float64(nextL.DistinctCount())*cutFrac),
   388  				uint64(float64(nextL.NullCount())*cutFrac),
   389  				nextL.BoundCount(),
   390  				nextL.UpperBound(),
   391  				nextL.McvCounts(),
   392  				nextL.Mcvs())
   393  
   394  			nextL = firstHalf
   395  			leftStack = append(leftStack, secondHalf)
   396  
   397  		case sjStateInc:
   398  			leftRes = append(leftRes, nextL)
   399  			rightRes = append(rightRes, nextR)
   400  
   401  			nextL = nil
   402  			nextR = nil
   403  
   404  			if len(leftStack) > 0 {
   405  				nextL = leftStack[len(leftStack)-1]
   406  				leftStack = leftStack[:len(leftStack)-1]
   407  			}
   408  			if len(rightStack) > 0 {
   409  				nextR = rightStack[len(rightStack)-1]
   410  				rightStack = rightStack[:len(rightStack)-1]
   411  			}
   412  
   413  			state = sjStateCmp
   414  
   415  			if nextL == nil || nextR == nil {
   416  				state = sjStateExhaust
   417  			}
   418  
   419  		case sjStateExhaust:
   420  			state = sjStateEOF
   421  
   422  			if nextL == nil && nextR == nil {
   423  				continue
   424  			}
   425  
   426  			if nextL == nil {
   427  				// swap so right side is nil
   428  				swap()
   429  			}
   430  
   431  			// squash the trailing buckets into one
   432  			// TODO: cut the left side on the right's final bound when there is >1 left
   433  			leftStack = append(leftStack, nextL)
   434  			nextL = leftRes[len(leftRes)-1]
   435  			leftRes = leftRes[:len(leftRes)-1]
   436  			for len(leftStack) > 0 {
   437  				peekL := leftStack[len(leftStack)-1]
   438  				leftStack = leftStack[:len(leftStack)-1]
   439  				nextL = NewHistogramBucket(
   440  					uint64(float64(nextL.RowCount())+float64(peekL.RowCount())),
   441  					uint64(float64(nextL.DistinctCount())+float64(peekL.DistinctCount())),
   442  					uint64(float64(nextL.NullCount())+float64(peekL.NullCount())),
   443  					peekL.BoundCount(), peekL.UpperBound(), peekL.McvCounts(), peekL.Mcvs())
   444  			}
   445  			leftRes = append(leftRes, nextL)
   446  			nextL = nil
   447  
   448  		}
   449  	}
   450  
   451  	if reverse {
   452  		leftRes, rightRes = rightRes, leftRes
   453  	}
   454  	return leftRes, rightRes, nil
   455  }
   456  
   457  // mergeMcvs combines two sets of most common values, merging the bound keys
   458  // with the same value and keeping the top k of the merge result.
   459  func mergeMcvs(mcvs1, mcvs2 []sql.Row, mcvCnts1, mcvCnts2 []uint64, cmp func(sql.Row, sql.Row) (int, error)) ([]sql.Row, []uint64, error) {
   460  	if len(mcvs1) < len(mcvs2) {
   461  		// mcvs2 is low
   462  		mcvs1, mcvs2 = mcvs2, mcvs1
   463  		mcvCnts1, mcvCnts2 = mcvCnts2, mcvCnts1
   464  	}
   465  	if len(mcvs2) == 0 {
   466  		return mcvs1, mcvCnts1, nil
   467  	}
   468  
   469  	ret := NewSqlHeap(len(mcvs2))
   470  	seen := make(map[int]bool)
   471  	for i, row1 := range mcvs1 {
   472  		matched := -1
   473  		for j, row2 := range mcvs2 {
   474  			c, err := cmp(row1, row2)
   475  			if err != nil {
   476  				return nil, nil, err
   477  			}
   478  			if c == 0 {
   479  				matched = j
   480  				break
   481  			}
   482  		}
   483  		if matched > 0 {
   484  			seen[matched] = true
   485  			heap.Push(ret, NewHeapRow(mcvs1[i], int(mcvCnts1[i]+mcvCnts2[matched])))
   486  		} else {
   487  			heap.Push(ret, NewHeapRow(mcvs1[i], int(mcvCnts1[i])))
   488  		}
   489  	}
   490  	for j := range mcvs2 {
   491  		if !seen[j] {
   492  			heap.Push(ret, NewHeapRow(mcvs2[j], int(mcvCnts2[j])))
   493  
   494  		}
   495  	}
   496  	return ret.Array(), ret.Counts(), nil
   497  }
   498  
   499  // mergeOverlappingBuckets folds bins with one element into the previous
   500  // bucket when the bound keys match.
   501  func mergeOverlappingBuckets(h sql.Histogram, types []sql.Type) (sql.Histogram, error) {
   502  	cmp := func(l, r sql.Row) (int, error) {
   503  		for i := 0; i < len(types); i++ {
   504  			cmp, err := types[i].Compare(l[i], r[i])
   505  			if err != nil {
   506  				return 0, err
   507  			}
   508  			switch cmp {
   509  			case 0:
   510  				continue
   511  			case -1:
   512  				return -1, nil
   513  			case 1:
   514  				return 1, nil
   515  			}
   516  		}
   517  		return 0, nil
   518  	}
   519  	// |k| is the write position, |i| is the compare position
   520  	// |k| <= |i|
   521  	i := 0
   522  	k := 0
   523  	for i < len(h) {
   524  		h[k] = h[i]
   525  		i++
   526  		if i >= len(h) {
   527  			k++
   528  			break
   529  		}
   530  		mcvs, mcvCnts, err := mergeMcvs(h[i].Mcvs(), h[i-1].Mcvs(), h[i].McvCounts(), h[i-1].McvCounts(), cmp)
   531  		if err != nil {
   532  			return nil, err
   533  		}
   534  		for ; i < len(h) && h[i].DistinctCount() == 1; i++ {
   535  			eq, err := cmp(h[k].UpperBound(), h[i].UpperBound())
   536  			if err != nil {
   537  				return nil, err
   538  			}
   539  			if eq != 0 {
   540  				break
   541  			}
   542  			h[k] = NewHistogramBucket(
   543  				h[k].RowCount()+h[i].RowCount(),
   544  				h[k].DistinctCount(),
   545  				h[k].NullCount()+h[i].NullCount(),
   546  				h[k].BoundCount()+h[i].BoundCount(),
   547  				h[k].UpperBound(),
   548  				mcvCnts,
   549  				mcvs)
   550  		}
   551  		k++
   552  	}
   553  	return h[:k], nil
   554  }
   555  
   556  type sjState int8
   557  
   558  const (
   559  	sjStateUnknown = iota
   560  	sjStateInit
   561  	sjStateCmp
   562  	sjStateCutLeft
   563  	sjStateCutRight
   564  	sjStateCut
   565  	sjStateInc
   566  	sjStateExhaust
   567  	sjStateEOF
   568  )
   569  
   570  // euclideanDistance is a vectorwise sum of squares distance between
   571  // two numeric types.
   572  func euclideanDistance(row1, row2 sql.Row, prefixLen int) (float64, error) {
   573  	var distSq float64
   574  	for i := 0; i < prefixLen; i++ {
   575  		v1, _, err := types.Float64.Convert(row1[i])
   576  		if err != nil {
   577  			return 0, err
   578  		}
   579  		v2, _, err := types.Float64.Convert(row2[i])
   580  		if err != nil {
   581  			return 0, err
   582  		}
   583  		f1 := v1.(float64)
   584  		f2 := v2.(float64)
   585  		distSq += f1*f1 - 2*f1*f2 + f2*f2
   586  	}
   587  	return math.Sqrt(distSq), nil
   588  }