github.com/dolthub/go-mysql-server@v0.18.0/enginetest/histogram_test.go (about) 1 package enginetest 2 3 import ( 4 "container/heap" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "log" 10 "sort" 11 "testing" 12 13 "github.com/stretchr/testify/require" 14 "golang.org/x/exp/rand" 15 16 "github.com/dolthub/go-mysql-server/memory" 17 "github.com/dolthub/go-mysql-server/sql" 18 "github.com/dolthub/go-mysql-server/sql/expression" 19 "github.com/dolthub/go-mysql-server/sql/plan" 20 "github.com/dolthub/go-mysql-server/sql/rowexec" 21 "github.com/dolthub/go-mysql-server/sql/stats" 22 "github.com/dolthub/go-mysql-server/sql/types" 23 ) 24 25 func init() { 26 rand.Seed(0) 27 } 28 29 type statsTest struct { 30 name string 31 tableGen func(*sql.Context, *memory.Database, int, sql.TableId, sql.ColumnId, ...interface{}) *plan.ResolvedTable 32 args1 []interface{} 33 args2 []interface{} 34 leftOrd []int 35 rightOrd []int 36 leftTypes []sql.Type 37 rightTypes []sql.Type 38 err float64 39 } 40 41 func TestNormDist(t *testing.T) { 42 tests := []struct { 43 name string 44 mean1 float64 45 std1 float64 46 mean2 float64 47 std2 float64 48 }{ 49 { 50 name: "same table", 51 mean1: 0, 52 std1: 10, 53 mean2: 0, 54 std2: 10, 55 }, 56 { 57 name: "same mean, different std", 58 mean1: 0, 59 std1: 10, 60 mean2: 0, 61 std2: 2, 62 }, 63 // 64 { 65 name: "similar mean, different std1", 66 mean1: 1, 67 std1: 10, 68 mean2: 0, 69 std2: 2, 70 }, 71 { 72 name: "same mean, different std2", 73 mean1: 0, 74 std1: 8, 75 mean2: 0, 76 std2: 2, 77 }, 78 { 79 name: "same mean, different std3", 80 mean1: 0, 81 std1: 7, 82 mean2: 0, 83 std2: 2, 84 }, 85 { 86 name: "similar mean, different std4", 87 mean1: 1, 88 std1: 7, 89 mean2: 0, 90 std2: 2, 91 }, 92 { 93 name: "similar mean, different std5", 94 mean1: 2, 95 std1: 7, 96 mean2: 0, 97 std2: 2, 98 }, 99 { 100 name: "similar mean, different std6", 101 mean1: 3, 102 std1: 7, 103 mean2: 0, 104 std2: 2, 105 }, 106 { 107 name: "similar mean, different std7", 108 mean1: 4, 109 std1: 7, 110 mean2: 0, 111 std2: 2, 112 }, 113 { 114 name: "similar mean, different std8", 115 mean1: 4, 116 std1: 7, 117 mean2: 0, 118 std2: 3, 119 }, 120 { 121 name: "similar mean, different std9", 122 mean1: 5, 123 std1: 7, 124 mean2: 0, 125 std2: 3, 126 }, 127 } 128 129 var statTests []statsTest 130 for _, t := range tests { 131 st := statsTest{ 132 name: t.name, 133 tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable { 134 mean := args[0].(float64) 135 std := args[1].(float64) 136 xyz := makeTable(db, fmt.Sprintf("xyz%d", tab), tab, col) 137 err := normalDistForTable(ctx, xyz, cnt, mean, std) 138 if err != nil { 139 panic(err) 140 } 141 return xyz 142 }, 143 leftOrd: []int{1}, 144 rightOrd: []int{1}, 145 leftTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 146 rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 147 err: 1, 148 args1: []interface{}{t.mean1, t.std1}, 149 args2: []interface{}{t.mean2, t.std2}, 150 } 151 statTests = append(statTests, st) 152 } 153 154 debug := false 155 runStatsSuite(t, statTests, 100, 5, debug) 156 runStatsSuite(t, statTests, 100, 10, debug) 157 runStatsSuite(t, statTests, 100, 20, debug) 158 runStatsSuite(t, statTests, 500, 10, debug) 159 runStatsSuite(t, statTests, 500, 20, debug) 160 } 161 162 func TestExpDist(t *testing.T) { 163 tests := []struct { 164 name string 165 lambda1 float64 166 lambda2 float64 167 }{ 168 { 169 name: "same table", 170 lambda1: 1, 171 lambda2: 1, 172 }, 173 { 174 name: ".5/1.5", 175 lambda1: .5, 176 lambda2: 1.5, 177 }, 178 { 179 name: ".25/3", 180 lambda1: .25, 181 lambda2: 3, 182 }, 183 } 184 185 var statTests []statsTest 186 for _, tt := range tests { 187 st := statsTest{ 188 name: tt.name, 189 tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable { 190 xyz := makeTable(db, "xyz", tab, col) 191 err := expDistForTable(ctx, xyz, cnt, args[0].(float64)) 192 if err != nil { 193 panic(err) 194 } 195 return xyz 196 }, 197 leftOrd: []int{1}, 198 rightOrd: []int{1}, 199 leftTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 200 rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 201 args1: []interface{}{tt.lambda1}, 202 args2: []interface{}{tt.lambda2}, 203 err: 1, 204 } 205 statTests = append(statTests, st) 206 } 207 208 debug := false 209 runStatsSuite(t, statTests, 100, 5, debug) 210 runStatsSuite(t, statTests, 100, 10, debug) 211 runStatsSuite(t, statTests, 100, 20, debug) 212 runStatsSuite(t, statTests, 500, 10, debug) 213 runStatsSuite(t, statTests, 500, 20, debug) 214 } 215 216 func TestMultiDist(t *testing.T) { 217 tests := []statsTest{ 218 { 219 name: "uniform dist int", 220 tableGen: func(ctx *sql.Context, db *memory.Database, cnt int, tab sql.TableId, col sql.ColumnId, args ...interface{}) *plan.ResolvedTable { 221 xyz := makeTable(db, "xyz", tab, col) 222 err := uniformDistForTable(ctx, xyz, cnt) 223 if err != nil { 224 panic(err) 225 } 226 return xyz 227 }, 228 leftOrd: []int{1}, 229 rightOrd: []int{1}, 230 leftTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 231 rightTypes: []sql.Type{types.Int64, types.Int64, types.Int64}, 232 err: .1, 233 }, 234 } 235 236 runStatsSuite(t, tests, 1000, 10, false) 237 } 238 239 // runStatsSuite will parse each statsTest and (1) generate 2 tables for a 240 // join, (2) compute histograms for the tables on the join index, (3) use 241 // the stats join algo to simulate a join estimate, and (4) compare the 242 // estimate to the actual result set count. 243 func runStatsSuite(t *testing.T, tests []statsTest, rowCnt, bucketCnt int, debug bool) { 244 for i, tt := range tests { 245 t.Run(fmt.Sprintf("%s: , rows: %d, buckets: %d", tt.name, rowCnt, bucketCnt), func(t *testing.T) { 246 db := memory.NewDatabase(fmt.Sprintf("test%d", i)) 247 pro := memory.NewDBProvider(db) 248 249 xyz := tt.tableGen(newContext(pro), db, rowCnt, sql.TableId(i*2), 1, tt.args1...) 250 wuv := tt.tableGen(newContext(pro), db, rowCnt, sql.TableId(i*2+1), sql.ColumnId(len(tt.leftTypes)+1), tt.args2...) 251 252 // join the histograms on a particular field 253 var joinOps []sql.Expression 254 for i, l := range tt.leftOrd { 255 r := tt.rightOrd[i] 256 joinOps = append(joinOps, eq(l, r+len(tt.leftTypes))) 257 } 258 259 exp, err := expectedResultSize(newContext(pro), xyz, wuv, joinOps, debug) 260 require.NoError(t, err) 261 262 // get histograms for tables 263 xHist, err := testHistogram(newContext(pro), xyz, tt.leftOrd, bucketCnt) 264 require.NoError(t, err) 265 266 wHist, err := testHistogram(newContext(pro), wuv, tt.rightOrd, bucketCnt) 267 require.NoError(t, err) 268 269 if debug { 270 log.Printf("xyz:\n%s\n", sql.Histogram(xHist).DebugString()) 271 log.Printf("wuv:\n%s\n", sql.Histogram(wHist).DebugString()) 272 } 273 274 lStat := &stats.Statistic{Typs: []sql.Type{types.Int64}} 275 for _, b := range xHist { 276 lStat.Hist = append(lStat.Hist, b.(*stats.Bucket)) 277 } 278 rStat := &stats.Statistic{Typs: []sql.Type{types.Int64}} 279 for _, b := range wHist { 280 rStat.Hist = append(rStat.Hist, b.(*stats.Bucket)) 281 } 282 283 res, err := stats.Join(stats.UpdateCounts(lStat), stats.UpdateCounts(rStat), 1, debug) 284 require.NoError(t, err) 285 if debug { 286 log.Printf("join %s\n", res.Histogram().DebugString()) 287 } 288 289 denom := float64(exp) 290 if cmp := float64(res.RowCount()); cmp > denom { 291 denom = cmp 292 } 293 delta := float64(exp-int(res.RowCount())) / denom 294 if delta < 0 { 295 delta = -delta 296 } 297 if debug { 298 log.Println(res.RowCount(), exp, delta) 299 } 300 301 // This compares the error percentage for our estimate to an 302 // error threshold specified in the statTest. The error bounds 303 // are loose and mostly useful for debugging at this point. 304 require.Less(t, delta, tt.err, "%d/%d/%.2f\nleft %s\nright %s", res.RowCount(), exp, delta, sql.Histogram(xHist).DebugString(), sql.Histogram(wHist).DebugString()) 305 }) 306 } 307 } 308 309 func testHistogram(ctx *sql.Context, table *plan.ResolvedTable, fields []int, buckets int) ([]sql.HistogramBucket, error) { 310 var cnt uint64 311 if st, ok := table.UnderlyingTable().(sql.StatisticsTable); ok { 312 var err error 313 cnt, _, err = st.RowCount(ctx) 314 if err != nil { 315 return nil, err 316 } 317 } 318 if cnt == 0 { 319 return nil, fmt.Errorf("found zero row count for table") 320 } 321 322 i, err := rowexec.DefaultBuilder.Build(ctx, table, nil) 323 rows, err := sql.RowIterToRows(ctx, i) 324 if err != nil { 325 return nil, err 326 } 327 328 sch := table.Schema() 329 330 keyVals := make([]sql.Row, len(rows)) 331 for i, row := range rows { 332 for _, ord := range fields { 333 keyVals[i] = append(keyVals[i], row[ord]) 334 } 335 } 336 337 cmp := func(i, j int) int { 338 k := 0 339 for k < len(fields) && keyVals[i][k] == keyVals[j][k] { 340 k++ 341 } 342 if k == len(fields) { 343 return 0 344 } 345 col := sch[fields[k]] 346 cmp, _ := col.Type.Compare(keyVals[i][k], keyVals[j][k]) 347 return cmp 348 } 349 350 sort.Slice(keyVals, func(i, j int) bool { return cmp(i, j) <= 0 }) 351 352 var types []sql.Type 353 for _, i := range fields { 354 types = append(types, sch[i].Type) 355 } 356 357 var histogram []sql.HistogramBucket 358 rowsPerBucket := int(cnt) / buckets 359 currentBucket := &stats.Bucket{DistinctCnt: 1} 360 mcvCnt := 3 361 currentCnt := 0 362 mcvs := stats.NewSqlHeap(mcvCnt) 363 for i, row := range keyVals { 364 currentCnt++ 365 currentBucket.RowCnt++ 366 if i > 0 { 367 if cmp(i, i-1) != 0 { 368 currentBucket.DistinctCnt++ 369 heap.Push(mcvs, stats.NewHeapRow(keyVals[i-1], currentCnt)) 370 currentCnt = 1 371 } 372 } 373 if currentBucket.RowCnt > uint64(rowsPerBucket) { 374 currentBucket.BoundVal = row 375 currentBucket.BoundCnt = uint64(currentCnt) 376 heap.Push(mcvs, stats.NewHeapRow(row, currentCnt)) 377 currentBucket.McvVals = mcvs.Array() 378 currentBucket.McvsCnt = mcvs.Counts() 379 histogram = append(histogram, currentBucket) 380 currentBucket = &stats.Bucket{DistinctCnt: 1} 381 mcvs = stats.NewSqlHeap(mcvCnt) 382 currentCnt = 0 383 } 384 } 385 currentBucket.BoundVal = keyVals[len(keyVals)-1] 386 currentBucket.BoundCnt = uint64(currentCnt) 387 currentBucket.McvVals = mcvs.Array() 388 currentBucket.McvsCnt = mcvs.Counts() 389 histogram = append(histogram, currentBucket) 390 return histogram, nil 391 } 392 393 func eq(idx1, idx2 int) *expression.Equals { 394 return expression.NewEquals( 395 expression.NewGetField(idx1, types.Int64, "", false), 396 expression.NewGetField(idx2, types.Int64, "", false)) 397 } 398 399 func childSchema(source string) sql.PrimaryKeySchema { 400 return sql.NewPrimaryKeySchema(sql.Schema{ 401 {Name: "x", Source: source, Type: types.Int64, Nullable: false}, 402 {Name: "y", Source: source, Type: types.Int64, Nullable: true}, 403 {Name: "z", Source: source, Type: types.Int64, Nullable: true}, 404 }, 0) 405 } 406 407 func makeTable(db *memory.Database, name string, tabId sql.TableId, colId sql.ColumnId) *plan.ResolvedTable { 408 t := memory.NewTable(db, name, childSchema(name), nil) 409 t.EnablePrimaryKeyIndexes() 410 colset := sql.NewColSet(sql.ColumnId(colId), sql.ColumnId(colId+1), sql.ColumnId(colId+2)) 411 return plan.NewResolvedTable(t, db, nil).WithId(sql.TableId(tabId)).WithColumns(colset).(*plan.ResolvedTable) 412 } 413 414 func newContext(provider *memory.DbProvider) *sql.Context { 415 return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider))) 416 } 417 418 func expectedResultSize(ctx *sql.Context, t1, t2 *plan.ResolvedTable, filters []sql.Expression, debug bool) (int, error) { 419 j := plan.NewJoin(t1, t2, plan.JoinTypeInner, expression.JoinAnd(filters...)) 420 if debug { 421 fmt.Println(sql.DebugString(j)) 422 } 423 i, err := rowexec.DefaultBuilder.Build(ctx, j, nil) 424 if err != nil { 425 return 0, err 426 } 427 cnt := 0 428 for { 429 _, err := i.Next(ctx) 430 if err == io.EOF { 431 break 432 } 433 434 if err != nil { 435 i.Close(ctx) 436 return 0, err 437 } 438 cnt++ 439 } 440 return cnt, nil 441 } 442 443 func uniformDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int) error { 444 tab := rt.UnderlyingTable().(*memory.Table) 445 for i := 0; i < cnt; i++ { 446 row := sql.Row{int64(i), int64(i), int64(i)} 447 err := tab.Insert(ctx, row) 448 if err != nil { 449 return err 450 } 451 } 452 return nil 453 } 454 455 // TODO sample from exponential distribution 456 func increasingHalfDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int) error { 457 tab := rt.UnderlyingTable().(*memory.Table) 458 for i := 0; i < 2*cnt; i++ { 459 for j := 0; j < (j/2)+1; j++ { 460 row := sql.Row{int64(i), int64(j), int64(j)} 461 err := tab.Insert(ctx, row) 462 if err != nil { 463 return err 464 } 465 i++ 466 } 467 } 468 return nil 469 } 470 471 func expDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int, lambda float64) error { 472 tab := rt.UnderlyingTable().(*memory.Table) 473 iter := stats.NewExpDistIter(2, cnt, lambda) 474 var i int 475 for { 476 val, err := iter.Next(ctx) 477 if errors.Is(err, io.EOF) { 478 break 479 } 480 row := sql.Row{int64(val[0].(int))} 481 for _, v := range val[1:] { 482 row = append(row, int64(v.(float64))) 483 } 484 err = tab.Insert(ctx, row) 485 if err != nil { 486 return err 487 } 488 i++ 489 } 490 return nil 491 } 492 493 func normalDistForTable(ctx *sql.Context, rt *plan.ResolvedTable, cnt int, mean, std float64) error { 494 tab := rt.UnderlyingTable().(*memory.Table) 495 iter := stats.NewNormDistIter(2, cnt, mean, std) 496 var i int 497 for { 498 val, err := iter.Next(ctx) 499 if errors.Is(err, io.EOF) { 500 break 501 } 502 row := sql.Row{int64(i)} 503 for _, v := range val[1:] { 504 row = append(row, int64(v.(float64))) 505 } 506 err = tab.Insert(ctx, row) 507 if err != nil { 508 return err 509 } 510 i++ 511 } 512 return nil 513 }