github.com/dolthub/go-mysql-server@v0.18.0/enginetest/histogram_test.go (about)

     1  package enginetest
     2  
     3  import (
     4  	"container/heap"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"sort"
    11  	"testing"
    12  
    13  	"github.com/stretchr/testify/require"
    14  	"golang.org/x/exp/rand"
    15  
    16  	"github.com/dolthub/go-mysql-server/memory"
    17  	"github.com/dolthub/go-mysql-server/sql"
    18  	"github.com/dolthub/go-mysql-server/sql/expression"
    19  	"github.com/dolthub/go-mysql-server/sql/plan"
    20  	"github.com/dolthub/go-mysql-server/sql/rowexec"
    21  	"github.com/dolthub/go-mysql-server/sql/stats"
    22  	"github.com/dolthub/go-mysql-server/sql/types"
    23  )
    24  
    25  func init() {
    26  	rand.Seed(0)
    27  }
    28  
    29  type statsTest struct {
    30  	name       string
    31  	tableGen   func(*sql.Context, *memory.Database, int, sql.TableId, sql.ColumnId, ...interface{}) *plan.ResolvedTable
    32  	args1      []interface{}
    33  	args2      []interface{}
    34  	leftOrd    []int
    35  	rightOrd   []int
    36  	leftTypes  []sql.Type
    37  	rightTypes []sql.Type
    38  	err        float64
    39  }
    40  
    41  func TestNormDist(t *testing.T) {
    42  	tests := []struct {
    43  		name  string
    44  		mean1 float64
    45  		std1  float64
    46  		mean2 float64
    47  		std2  float64
    48  	}{
    49  		{
    50  			name:  "same table",
    51  			mean1: 0,
    52  			std1:  10,
    53  			mean2: 0,
    54  			std2:  10,
    55  		},
    56  		{
    57  			name:  "same mean, different std",
    58  			mean1: 0,
    59  			std1:  10,
    60  			mean2: 0,
    61  			std2:  2,
    62  		},
    63  		//
    64  		{
    65  			name:  "similar mean, different std1",
    66  			mean1: 1,
    67  			std1:  10,
    68  			mean2: 0,
    69  			std2:  2,
    70  		},
    71  		{
    72  			name:  "same mean, different std2",
    73  			mean1: 0,
    74  			std1:  8,
    75  			mean2: 0,
    76  			std2:  2,
    77  		},
    78  		{
    79  			name:  "same mean, different std3",
    80  			mean1: 0,
    81  			std1:  7,
    82  			mean2: 0,
    83  			std2:  2,
    84  		},
    85  		{
    86  			name:  "similar mean, different std4",
    87  			mean1: 1,
    88  			std1:  7,
    89  			mean2: 0,
    90  			std2:  2,
    91  		},
    92  		{
    93  			name:  "similar mean, different std5",
    94  			mean1: 2,
    95  			std1:  7,
    96  			mean2: 0,
    97  			std2:  2,
    98  		},
    99  		{
   100  			name:  "similar mean, different std6",
   101  			mean1: 3,
   102  			std1:  7,
   103  			mean2: 0,
   104  			std2:  2,
   105  		},
   106  		{
   107  			name:  "similar mean, different std7",
   108  			mean1: 4,
   109  			std1:  7,
   110  			mean2: 0,
   111  			std2:  2,
   112  		},
   113  		{
   114  			name:  "similar mean, different std8",
   115  			mean1: 4,
   116  			std1:  7,
   117  			mean2: 0,
   118  			std2:  3,
   119  		},
   120  		{
   121  			name:  "similar mean, different std9",
   122  			mean1: 5,
   123  			std1:  7,
   124  			mean2: 0,
   125  			std2:  3,
   126  		},
   127  	}
   128  
   129  	var statTests []statsTest
   130  	for _, t := range tests {
   131  		st := statsTest{
   132  			name: t.name,
   133  			tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable {
   134  				mean := args[0].(float64)
   135  				std := args[1].(float64)
   136  				xyz := makeTable(db, fmt.Sprintf("xyz%d", tab), tab, col)
   137  				err := normalDistForTable(ctx, xyz, cnt, mean, std)
   138  				if err != nil {
   139  					panic(err)
   140  				}
   141  				return xyz
   142  			},
   143  			leftOrd:    []int{1},
   144  			rightOrd:   []int{1},
   145  			leftTypes:  []sql.Type{types.Int64, types.Int64, types.Int64},
   146  			rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64},
   147  			err:        1,
   148  			args1:      []interface{}{t.mean1, t.std1},
   149  			args2:      []interface{}{t.mean2, t.std2},
   150  		}
   151  		statTests = append(statTests, st)
   152  	}
   153  
   154  	debug := false
   155  	runStatsSuite(t, statTests, 100, 5, debug)
   156  	runStatsSuite(t, statTests, 100, 10, debug)
   157  	runStatsSuite(t, statTests, 100, 20, debug)
   158  	runStatsSuite(t, statTests, 500, 10, debug)
   159  	runStatsSuite(t, statTests, 500, 20, debug)
   160  }
   161  
   162  func TestExpDist(t *testing.T) {
   163  	tests := []struct {
   164  		name    string
   165  		lambda1 float64
   166  		lambda2 float64
   167  	}{
   168  		{
   169  			name:    "same table",
   170  			lambda1: 1,
   171  			lambda2: 1,
   172  		},
   173  		{
   174  			name:    ".5/1.5",
   175  			lambda1: .5,
   176  			lambda2: 1.5,
   177  		},
   178  		{
   179  			name:    ".25/3",
   180  			lambda1: .25,
   181  			lambda2: 3,
   182  		},
   183  	}
   184  
   185  	var statTests []statsTest
   186  	for _, tt := range tests {
   187  		st := statsTest{
   188  			name: tt.name,
   189  			tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable {
   190  				xyz := makeTable(db, "xyz", tab, col)
   191  				err := expDistForTable(ctx, xyz, cnt, args[0].(float64))
   192  				if err != nil {
   193  					panic(err)
   194  				}
   195  				return xyz
   196  			},
   197  			leftOrd:    []int{1},
   198  			rightOrd:   []int{1},
   199  			leftTypes:  []sql.Type{types.Int64, types.Int64, types.Int64},
   200  			rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64},
   201  			args1:      []interface{}{tt.lambda1},
   202  			args2:      []interface{}{tt.lambda2},
   203  			err:        1,
   204  		}
   205  		statTests = append(statTests, st)
   206  	}
   207  
   208  	debug := false
   209  	runStatsSuite(t, statTests, 100, 5, debug)
   210  	runStatsSuite(t, statTests, 100, 10, debug)
   211  	runStatsSuite(t, statTests, 100, 20, debug)
   212  	runStatsSuite(t, statTests, 500, 10, debug)
   213  	runStatsSuite(t, statTests, 500, 20, debug)
   214  }
   215  
   216  func TestMultiDist(t *testing.T) {
   217  	tests := []statsTest{
   218  		{
   219  			name: "uniform dist int",
   220  			tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable {
   221  				xyz := makeTable(db, "xyz", tab, col)
   222  				err := uniformDistForTable(ctx, xyz, cnt)
   223  				if err != nil {
   224  					panic(err)
   225  				}
   226  				return xyz
   227  			},
   228  			leftOrd:    []int{1},
   229  			rightOrd:   []int{1},
   230  			leftTypes:  []sql.Type{types.Int64, types.Int64, types.Int64},
   231  			rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64},
   232  			err:        .1,
   233  		},
   234  	}
   235  
   236  	runStatsSuite(t, tests, 1000, 10, false)
   237  }
   238  
   239  // runStatsSuite will parse each statsTest and (1) generate 2 tables for a
   240  // join, (2) compute histograms for the tables on the join index, (3) use
   241  // the stats join algo to simulate a join estimate, and (4) compare the
   242  // estimate to the actual result set count.
   243  func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug bool) {
   244  	for i, tt := range tests {
   245  		t.Run(fmt.Sprintf("%s: , rows: %d, buckets: %d", tt.name, rowCnt, bucketCnt), func(t *testing.T) {
   246  			db := memory.NewDatabase(fmt.Sprintf("test%d", i))
   247  			pro := memory.NewDBProvider(db)
   248  
   249  			xyz := tt.tableGen(newContext(pro), db, rowCnt, sql.TableId(i*2), 1, tt.args1...)
   250  			wuv := tt.tableGen(newContext(pro), db, rowCnt, sql.TableId(i*2+1), sql.ColumnId(len(tt.leftTypes)+1), tt.args2...)
   251  
   252  			// join the histograms on a particular field
   253  			var joinOps []sql.Expression
   254  			for i, l := range tt.leftOrd {
   255  				r := tt.rightOrd[i]
   256  				joinOps = append(joinOps, eq(l, r+len(tt.leftTypes)))
   257  			}
   258  
   259  			exp, err := expectedResultSize(newContext(pro), xyz, wuv, joinOps, debug)
   260  			require.NoError(t, err)
   261  
   262  			// get histograms for tables
   263  			xHist, err := testHistogram(newContext(pro), xyz, tt.leftOrd, bucketCnt)
   264  			require.NoError(t, err)
   265  
   266  			wHist, err := testHistogram(newContext(pro), wuv, tt.rightOrd, bucketCnt)
   267  			require.NoError(t, err)
   268  
   269  			if debug {
   270  				log.Printf("xyz:\n%s\n", sql.Histogram(xHist).DebugString())
   271  				log.Printf("wuv:\n%s\n", sql.Histogram(wHist).DebugString())
   272  			}
   273  
   274  			lStat := &stats.Statistic{Typs: []sql.Type{types.Int64}}
   275  			for _, b := range xHist {
   276  				lStat.Hist = append(lStat.Hist, b.(*stats.Bucket))
   277  			}
   278  			rStat := &stats.Statistic{Typs: []sql.Type{types.Int64}}
   279  			for _, b := range wHist {
   280  				rStat.Hist = append(rStat.Hist, b.(*stats.Bucket))
   281  			}
   282  
   283  			res, err := stats.Join(stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), 1, debug)
   284  			require.NoError(t, err)
   285  			if debug {
   286  				log.Printf("join %s\n", res.Histogram().DebugString())
   287  			}
   288  
   289  			denom := float64(exp)
   290  			if cmp := float64(res.RowCount()); cmp > denom {
   291  				denom = cmp
   292  			}
   293  			delta := float64(exp-int(res.RowCount())) / denom
   294  			if delta < 0 {
   295  				delta = -delta
   296  			}
   297  			if debug {
   298  				log.Println(res.RowCount(), exp, delta)
   299  			}
   300  
   301  			// This compares the error percentage for our estimate to an
   302  			// error threshold specified in the statTest. The error bounds
   303  			// are loose and mostly useful for debugging at this point.
   304  			require.Less(t, delta, tt.err, "%d/%d/%.2f\nleft %s\nright %s", res.RowCount(), exp, delta, sql.Histogram(xHist).DebugString(), sql.Histogram(wHist).DebugString())
   305  		})
   306  	}
   307  }
   308  
   309  func testHistogram(ctx *sql.Context, table *plan.ResolvedTable, fields []int, buckets int) ([]sql.HistogramBucket, error) {
   310  	var cnt uint64
   311  	if st, ok := table.UnderlyingTable().(sql.StatisticsTable); ok {
   312  		var err error
   313  		cnt, _, err = st.RowCount(ctx)
   314  		if err != nil {
   315  			return nil, err
   316  		}
   317  	}
   318  	if cnt == 0 {
   319  		return nil, fmt.Errorf("found zero row count for table")
   320  	}
   321  
   322  	i, err := rowexec.DefaultBuilder.Build(ctx, table, nil)
   323  	rows, err := sql.RowIterToRows(ctx, i)
   324  	if err != nil {
   325  		return nil, err
   326  	}
   327  
   328  	sch := table.Schema()
   329  
   330  	keyVals := make([]sql.Row, len(rows))
   331  	for i, row := range rows {
   332  		for _, ord := range fields {
   333  			keyVals[i] = append(keyVals[i], row[ord])
   334  		}
   335  	}
   336  
   337  	cmp := func(i, j int) int {
   338  		k := 0
   339  		for k < len(fields) && keyVals[i][k] == keyVals[j][k] {
   340  			k++
   341  		}
   342  		if k == len(fields) {
   343  			return 0
   344  		}
   345  		col := sch[fields[k]]
   346  		cmp, _ := col.Type.Compare(keyVals[i][k], keyVals[j][k])
   347  		return cmp
   348  	}
   349  
   350  	sort.Slice(keyVals, func(i, j int) bool { return cmp(i, j) <= 0 })
   351  
   352  	var types []sql.Type
   353  	for _, i := range fields {
   354  		types = append(types, sch[i].Type)
   355  	}
   356  
   357  	var histogram []sql.HistogramBucket
   358  	rowsPerBucket := int(cnt) / buckets
   359  	currentBucket := &stats.Bucket{DistinctCnt: 1}
   360  	mcvCnt := 3
   361  	currentCnt := 0
   362  	mcvs := stats.NewSqlHeap(mcvCnt)
   363  	for i, row := range keyVals {
   364  		currentCnt++
   365  		currentBucket.RowCnt++
   366  		if i > 0 {
   367  			if cmp(i, i-1) != 0 {
   368  				currentBucket.DistinctCnt++
   369  				heap.Push(mcvs, stats.NewHeapRow(keyVals[i-1], currentCnt))
   370  				currentCnt = 1
   371  			}
   372  		}
   373  		if currentBucket.RowCnt > uint64(rowsPerBucket) {
   374  			currentBucket.BoundVal = row
   375  			currentBucket.BoundCnt = uint64(currentCnt)
   376  			heap.Push(mcvs, stats.NewHeapRow(row, currentCnt))
   377  			currentBucket.McvVals = mcvs.Array()
   378  			currentBucket.McvsCnt = mcvs.Counts()
   379  			histogram = append(histogram, currentBucket)
   380  			currentBucket = &stats.Bucket{DistinctCnt: 1}
   381  			mcvs = stats.NewSqlHeap(mcvCnt)
   382  			currentCnt = 0
   383  		}
   384  	}
   385  	currentBucket.BoundVal = keyVals[len(keyVals)-1]
   386  	currentBucket.BoundCnt = uint64(currentCnt)
   387  	currentBucket.McvVals = mcvs.Array()
   388  	currentBucket.McvsCnt = mcvs.Counts()
   389  	histogram = append(histogram, currentBucket)
   390  	return histogram, nil
   391  }
   392  
   393  func eq(idx1, idx2 int) *expression.Equals {
   394  	return expression.NewEquals(
   395  		expression.NewGetField(idx1, types.Int64, "", false),
   396  		expression.NewGetField(idx2, types.Int64, "", false))
   397  }
   398  
   399  func childSchema(source string) sql.PrimaryKeySchema {
   400  	return sql.NewPrimaryKeySchema(sql.Schema{
   401  		{Name: "x", Source: source, Type: types.Int64, Nullable: false},
   402  		{Name: "y", Source: source, Type: types.Int64, Nullable: true},
   403  		{Name: "z", Source: source, Type: types.Int64, Nullable: true},
   404  	}, 0)
   405  }
   406  
   407  func makeTable(db *memory.Database, name string, tabId sql.TableId, colId sql.ColumnId) *plan.ResolvedTable {
   408  	t := memory.NewTable(db, name, childSchema(name), nil)
   409  	t.EnablePrimaryKeyIndexes()
   410  	colset := sql.NewColSet(sql.ColumnId(colId), sql.ColumnId(colId+1), sql.ColumnId(colId+2))
   411  	return plan.NewResolvedTable(t, db, nil).WithId(sql.TableId(tabId)).WithColumns(colset).(*plan.ResolvedTable)
   412  }
   413  
   414  func newContext(provider *memory.DbProvider) *sql.Context {
   415  	return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider)))
   416  }
   417  
   418  func expectedResultSize(ctx *sql.Context, t1, t2 *plan.ResolvedTable, filters []sql.Expression, debug bool) (int, error) {
   419  	j := plan.NewJoin(t1, t2, plan.JoinTypeInner, expression.JoinAnd(filters...))
   420  	if debug {
   421  		fmt.Println(sql.DebugString(j))
   422  	}
   423  	i, err := rowexec.DefaultBuilder.Build(ctx, j, nil)
   424  	if err != nil {
   425  		return 0, err
   426  	}
   427  	cnt := 0
   428  	for {
   429  		_, err := i.Next(ctx)
   430  		if err == io.EOF {
   431  			break
   432  		}
   433  
   434  		if err != nil {
   435  			i.Close(ctx)
   436  			return 0, err
   437  		}
   438  		cnt++
   439  	}
   440  	return cnt, nil
   441  }
   442  
   443  func uniformDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int) error {
   444  	tab := rt.UnderlyingTable().(*memory.Table)
   445  	for i := 0; i < cnt; i++ {
   446  		row := sql.Row{int64(i), int64(i), int64(i)}
   447  		err := tab.Insert(ctx, row)
   448  		if err != nil {
   449  			return err
   450  		}
   451  	}
   452  	return nil
   453  }
   454  
   455  // TODO sample from exponential distribution
   456  func increasingHalfDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int) error {
   457  	tab := rt.UnderlyingTable().(*memory.Table)
   458  	for i := 0; i < 2*cnt; i++ {
   459  		for j := 0; j < (j/2)+1; j++ {
   460  			row := sql.Row{int64(i), int64(j), int64(j)}
   461  			err := tab.Insert(ctx, row)
   462  			if err != nil {
   463  				return err
   464  			}
   465  			i++
   466  		}
   467  	}
   468  	return nil
   469  }
   470  
   471  func expDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int, lambda float64) error {
   472  	tab := rt.UnderlyingTable().(*memory.Table)
   473  	iter := stats.NewExpDistIter(2, cnt, lambda)
   474  	var i int
   475  	for {
   476  		val, err := iter.Next(ctx)
   477  		if errors.Is(err, io.EOF) {
   478  			break
   479  		}
   480  		row := sql.Row{int64(val[0].(int))}
   481  		for _, v := range val[1:] {
   482  			row = append(row, int64(v.(float64)))
   483  		}
   484  		err = tab.Insert(ctx, row)
   485  		if err != nil {
   486  			return err
   487  		}
   488  		i++
   489  	}
   490  	return nil
   491  }
   492  
   493  func normalDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int, mean, std float64) error {
   494  	tab := rt.UnderlyingTable().(*memory.Table)
   495  	iter := stats.NewNormDistIter(2, cnt, mean, std)
   496  	var i int
   497  	for {
   498  		val, err := iter.Next(ctx)
   499  		if errors.Is(err, io.EOF) {
   500  			break
   501  		}
   502  		row := sql.Row{int64(i)}
   503  		for _, v := range val[1:] {
   504  			row = append(row, int64(v.(float64)))
   505  		}
   506  		err = tab.Insert(ctx, row)
   507  		if err != nil {
   508  			return err
   509  		}
   510  		i++
   511  	}
   512  	return nil
   513  }