github.com/dolthub/go-mysql-server@v0.18.0/sql/stats/join.go (about) 1 package stats 2 3 import ( 4 "container/heap" 5 "fmt" 6 "log" 7 "math" 8 "time" 9 10 "github.com/pkg/errors" 11 12 "github.com/dolthub/go-mysql-server/sql" 13 "github.com/dolthub/go-mysql-server/sql/types" 14 ) 15 16 var ErrJoinStringStatistics = errors.New("joining string histograms is unsupported") 17 18 // Join performs an alignment algorithm on two sets of statistics, and 19 // then pairwise estimates bucket cardinalities by joining most common 20 // values (mcvs) directly and assuming key uniformity otherwise. Only 21 // numeric types are supported. 22 func Join(s1, s2 sql.Statistic, prefixCnt int, debug bool) (sql.Statistic, error) { 23 cmp := func(row1, row2 sql.Row) (int, error) { 24 var keyCmp int 25 for i := 0; i < prefixCnt; i++ { 26 k1, _, err := s1.Types()[i].Promote().Convert(row1[i]) 27 if err != nil { 28 return 0, fmt.Errorf("incompatible types") 29 } 30 31 k2, _, err := s2.Types()[i].Promote().Convert(row2[i]) 32 if err != nil { 33 return 0, fmt.Errorf("incompatible types") 34 } 35 36 cmp, err := s1.Types()[i].Promote().Compare(k1, k2) 37 if err != nil { 38 return 0, err 39 } 40 if cmp == 0 { 41 continue 42 } 43 keyCmp = cmp 44 break 45 } 46 return keyCmp, nil 47 } 48 49 s1Buckets, err := mergeOverlappingBuckets(s1.Histogram(), s1.Types()) 50 if err != nil { 51 return nil, err 52 } 53 s2Buckets, err := mergeOverlappingBuckets(s2.Histogram(), s2.Types()) 54 if err != nil { 55 return nil, err 56 } 57 58 s1AliHist, s2AliHist, err := AlignBuckets(s1Buckets, s2Buckets, s1.LowerBound(), s2.LowerBound(), s1.Types()[:prefixCnt], s2.Types()[:prefixCnt], cmp) 59 if err != nil { 60 return nil, err 61 } 62 if debug { 63 log.Println("left", s1AliHist.DebugString()) 64 log.Println("right", s2AliHist.DebugString()) 65 } 66 67 newHist, err := joinAlignedStats(s1AliHist, s2AliHist, cmp) 68 ret := NewStatistic(0, 0, 0, s1.AvgSize(), time.Now(), s1.Qualifier(), s1.Columns(), s1.Types(), newHist, s1.IndexClass(), nil) 69 return UpdateCounts(ret), nil 70 } 71 72 // joinAlignedStats assumes |left| and |right| have the same number of 73 // buckets to estimate the join cardinality. Most common values (mcvs) adjust 74 // the estimates to account for outlier keys that are a disproportionately 75 // high fraction of the index. 76 func joinAlignedStats(left, right sql.Histogram, cmp func(sql.Row, sql.Row) (int, error)) ([]*Bucket, error) { 77 var newBuckets []*Bucket 78 newCnt := uint64(0) 79 for i := range left { 80 l := left[i] 81 r := right[i] 82 lDistinct := float64(l.DistinctCount()) 83 rDistinct := float64(r.DistinctCount()) 84 85 lRows := float64(l.RowCount()) 86 rRows := float64(r.RowCount()) 87 88 var rows uint64 89 90 // mcvs counted in isolation 91 // todo: should we assume non-match MCVs in smaller set 92 // contribute MCV count * average frequency from the larger? 93 var mcvMatch int 94 for i, key1 := range l.Mcvs() { 95 for j, key2 := range r.Mcvs() { 96 v, err := cmp(key1, key2) 97 if err != nil { 98 return nil, err 99 } 100 if v == 0 { 101 rows += l.McvCounts()[i] * r.McvCounts()[j] 102 lRows -= float64(l.McvCounts()[i]) 103 rRows -= float64(r.McvCounts()[j]) 104 lDistinct-- 105 rDistinct-- 106 mcvMatch++ 107 break 108 } 109 } 110 } 111 112 // true up negative approximations 113 lRows = math.Max(lRows, 0) 114 rRows = math.Max(rRows, 0) 115 lDistinct = math.Max(lDistinct, 0) 116 rDistinct = math.Max(rDistinct, 0) 117 118 // Selinger method on rest of buckets 119 maxDistinct := math.Max(lDistinct, rDistinct) 120 minDistinct := math.Min(lDistinct, rDistinct) 121 122 if maxDistinct > 0 { 123 rows += uint64(float64(lRows*rRows) / float64(maxDistinct)) 124 } 125 126 newCnt += rows 127 128 // TODO: something smarter with MCVs 129 mcvs := append(l.Mcvs(), r.Mcvs()...) 130 mcvCounts := append(l.McvCounts(), r.McvCounts()...) 131 132 newBucket := NewHistogramBucket( 133 rows, 134 uint64(minDistinct)+uint64(mcvMatch), // matched mcvs contribute back to result distinct count 135 uint64(float64(l.NullCount()*r.NullCount())/float64(maxDistinct)), 136 l.BoundCount()*r.BoundCount(), l.UpperBound(), mcvCounts, mcvs) 137 newBuckets = append(newBuckets, newBucket) 138 } 139 return newBuckets, nil 140 } 141 142 // AlignBuckets produces two histograms with the same number of buckets. 143 // Start by using upper bound keys to truncate histogram with a larger 144 // keyspace. Then for every misaligned pair of buckets, cut the one with the 145 // higher bound value on the smaller's key. We use a linear interpolation 146 // to divide keys when splitting. 147 func AlignBuckets(h1, h2 sql.Histogram, lBound1, lBound2 sql.Row, s1Types, s2Types []sql.Type, cmp func(sql.Row, sql.Row) (int, error)) (sql.Histogram, sql.Histogram, error) { 148 var numericTypes bool = true 149 for _, t := range s1Types { 150 if _, ok := t.(sql.NumberType); !ok { 151 numericTypes = false 152 break 153 } 154 } 155 156 if !numericTypes { 157 // todo(max): distance between two strings is difficult, 158 // but we could cut equal fractions depending on total 159 // cuts for a bucket 160 return nil, nil, ErrJoinStringStatistics 161 } 162 163 var leftRes sql.Histogram 164 var rightRes sql.Histogram 165 var leftStack []sql.HistogramBucket 166 var rightStack []sql.HistogramBucket 167 var nextL sql.HistogramBucket 168 var nextR sql.HistogramBucket 169 var keyCmp int 170 var err error 171 var reverse bool 172 173 swap := func() { 174 leftStack, rightStack = rightStack, leftStack 175 nextL, nextR = nextR, nextL 176 leftRes, rightRes = rightRes, leftRes 177 h1, h2 = h2, h1 178 reverse = !reverse 179 } 180 181 var state sjState = sjStateInit 182 for state != sjStateEOF { 183 switch state { 184 case sjStateInit: 185 // Merge adjacent overlapping buckets within each histogram. 186 // Truncate non-overlapping tail buckets between left and right. 187 // Reverse the buckets into stacks. 188 189 s1Hist, err := mergeOverlappingBuckets(h1, s1Types) 190 if err != nil { 191 return nil, nil, err 192 } 193 s2Hist, err := mergeOverlappingBuckets(h2, s2Types) 194 if err != nil { 195 return nil, nil, err 196 } 197 198 s1Last := s1Hist[len(s1Hist)-1].UpperBound() 199 s2Last := s2Hist[len(s2Hist)-1].UpperBound() 200 idx1, err := PrefixLtHist(s1Hist, s2Last, cmp) 201 if err != nil { 202 return nil, nil, err 203 } 204 idx2, err := PrefixLtHist(s2Hist, s1Last, cmp) 205 if err != nil { 206 return nil, nil, err 207 } 208 if idx1 < len(s1Hist) { 209 idx1++ 210 } 211 if idx2 < len(s2Hist) { 212 idx2++ 213 } 214 s1Hist = s1Hist[:idx1] 215 s2Hist = s2Hist[:idx2] 216 217 if lBound2 != nil { 218 idx, err := PrefixGteHist(s1Hist, lBound2, cmp) 219 if err != nil { 220 return nil, nil, err 221 } 222 s1Hist = s1Hist[idx:] 223 } 224 if lBound1 != nil { 225 idx, err := PrefixGteHist(s2Hist, lBound1, cmp) 226 if err != nil { 227 return nil, nil, err 228 } 229 s2Hist = s2Hist[idx:] 230 } 231 232 if len(s1Hist) == 0 || len(s2Hist) == 0 { 233 return nil, nil, nil 234 } 235 236 if len(s1Hist) == 0 || len(s2Hist) == 0 { 237 return nil, nil, nil 238 } 239 240 m := len(s1Hist) - 1 241 leftStack = make([]sql.HistogramBucket, m) 242 for i, b := range s1Hist { 243 if i == 0 { 244 nextL = b 245 continue 246 } 247 leftStack[m-i] = b 248 } 249 250 n := len(s2Hist) - 1 251 rightStack = make([]sql.HistogramBucket, n) 252 for i, b := range s2Hist { 253 if i == 0 { 254 nextR = b 255 continue 256 } 257 rightStack[n-i] = b 258 } 259 260 state = sjStateCmp 261 262 case sjStateCmp: 263 keyCmp, err = cmp(nextL.UpperBound(), nextR.UpperBound()) 264 if err != nil { 265 return nil, nil, err 266 } 267 switch keyCmp { 268 case 0: 269 state = sjStateInc 270 case 1: 271 state = sjStateCutLeft 272 case -1: 273 state = sjStateCutRight 274 } 275 276 case sjStateCutLeft: 277 // default cuts left 278 state = sjStateCut 279 280 case sjStateCutRight: 281 // switch to make left the cut target 282 swap() 283 state = sjStateCut 284 285 case sjStateCut: 286 state = sjStateInc 287 // The left bucket is longer than the right bucket. 288 // In the default case, we will cut the left bucket on 289 // the right boundary, and put the right remainder back 290 // on the stack. 291 292 if len(leftRes) == 0 { 293 // It is difficult to cut the first bucket because the 294 // lower bound is negative infinity. We instead extend the 295 // smaller side (right) by stealing form its precedeccors 296 // up to the left cutpoint. 297 298 if len(rightStack) == 0 { 299 continue 300 } 301 302 var peekR sql.HistogramBucket 303 for len(rightStack) > 0 { 304 // several right buckets might be less than the left cutpoint 305 peekR = rightStack[len(rightStack)-1] 306 rightStack = rightStack[:len(rightStack)-1] 307 keyCmp, err = cmp(peekR.UpperBound(), nextL.UpperBound()) 308 if err != nil { 309 return nil, nil, err 310 } 311 if keyCmp > 0 { 312 break 313 } 314 315 nextR = NewHistogramBucket( 316 uint64(float64(nextR.RowCount())+float64(peekR.RowCount())), 317 uint64(float64(nextR.DistinctCount())+float64(peekR.DistinctCount())), 318 uint64(float64(nextR.NullCount())+float64(peekR.NullCount())), 319 peekR.BoundCount(), peekR.UpperBound(), peekR.McvCounts(), peekR.Mcvs()) 320 } 321 322 // nextR < nextL < peekR 323 bucketMagnitude, err := euclideanDistance(nextR.UpperBound(), peekR.UpperBound(), len(s1Types)) 324 if err != nil { 325 return nil, nil, err 326 } 327 328 if bucketMagnitude == 0 { 329 peekR = nil 330 continue 331 } 332 333 // estimate midpoint 334 cutMagnitude, err := euclideanDistance(nextR.UpperBound(), nextL.UpperBound(), len(s1Types)) 335 if err != nil { 336 return nil, nil, err 337 } 338 339 cutFrac := cutMagnitude / bucketMagnitude 340 341 // lastL -> nextR 342 firstHalf := NewHistogramBucket( 343 uint64(float64(nextR.RowCount())+float64(peekR.RowCount())*cutFrac), 344 uint64(float64(nextR.DistinctCount())+float64(peekR.DistinctCount())*cutFrac), 345 uint64(float64(nextR.NullCount())+float64(peekR.NullCount())*cutFrac), 346 1, nextL.UpperBound(), nil, nil) 347 348 // nextR -> nextL 349 secondHalf := NewHistogramBucket( 350 uint64(float64(peekR.RowCount())*(1-cutFrac)), 351 uint64(float64(peekR.DistinctCount())*(1-cutFrac)), 352 uint64(float64(peekR.NullCount())*(1-cutFrac)), 353 peekR.BoundCount(), 354 peekR.UpperBound(), 355 peekR.McvCounts(), 356 peekR.Mcvs()) 357 358 nextR = firstHalf 359 rightStack = append(rightStack, secondHalf) 360 continue 361 } 362 363 // get left "distance" 364 bucketMagnitude, err := euclideanDistance(nextL.UpperBound(), leftRes[len(leftRes)-1].UpperBound(), len(s1Types)) 365 if err != nil { 366 return nil, nil, err 367 } 368 369 // estimate midpoint 370 cutMagnitude, err := euclideanDistance(nextL.UpperBound(), nextR.UpperBound(), len(s1Types)) 371 if err != nil { 372 return nil, nil, err 373 } 374 375 cutFrac := cutMagnitude / bucketMagnitude 376 377 // lastL -> nextR 378 firstHalf := NewHistogramBucket( 379 uint64(float64(nextL.RowCount())*(1-cutFrac)), 380 uint64(float64(nextL.DistinctCount())*(1-cutFrac)), 381 uint64(float64(nextL.NullCount())*(1-cutFrac)), 382 1, nextR.UpperBound(), nil, nil) 383 384 // nextR -> nextL 385 secondHalf := NewHistogramBucket( 386 uint64(float64(nextL.RowCount())*cutFrac), 387 uint64(float64(nextL.DistinctCount())*cutFrac), 388 uint64(float64(nextL.NullCount())*cutFrac), 389 nextL.BoundCount(), 390 nextL.UpperBound(), 391 nextL.McvCounts(), 392 nextL.Mcvs()) 393 394 nextL = firstHalf 395 leftStack = append(leftStack, secondHalf) 396 397 case sjStateInc: 398 leftRes = append(leftRes, nextL) 399 rightRes = append(rightRes, nextR) 400 401 nextL = nil 402 nextR = nil 403 404 if len(leftStack) > 0 { 405 nextL = leftStack[len(leftStack)-1] 406 leftStack = leftStack[:len(leftStack)-1] 407 } 408 if len(rightStack) > 0 { 409 nextR = rightStack[len(rightStack)-1] 410 rightStack = rightStack[:len(rightStack)-1] 411 } 412 413 state = sjStateCmp 414 415 if nextL == nil || nextR == nil { 416 state = sjStateExhaust 417 } 418 419 case sjStateExhaust: 420 state = sjStateEOF 421 422 if nextL == nil && nextR == nil { 423 continue 424 } 425 426 if nextL == nil { 427 // swap so right side is nil 428 swap() 429 } 430 431 // squash the trailing buckets into one 432 // TODO: cut the left side on the right's final bound when there is >1 left 433 leftStack = append(leftStack, nextL) 434 nextL = leftRes[len(leftRes)-1] 435 leftRes = leftRes[:len(leftRes)-1] 436 for len(leftStack) > 0 { 437 peekL := leftStack[len(leftStack)-1] 438 leftStack = leftStack[:len(leftStack)-1] 439 nextL = NewHistogramBucket( 440 uint64(float64(nextL.RowCount())+float64(peekL.RowCount())), 441 uint64(float64(nextL.DistinctCount())+float64(peekL.DistinctCount())), 442 uint64(float64(nextL.NullCount())+float64(peekL.NullCount())), 443 peekL.BoundCount(), peekL.UpperBound(), peekL.McvCounts(), peekL.Mcvs()) 444 } 445 leftRes = append(leftRes, nextL) 446 nextL = nil 447 448 } 449 } 450 451 if reverse { 452 leftRes, rightRes = rightRes, leftRes 453 } 454 return leftRes, rightRes, nil 455 } 456 457 // mergeMcvs combines two sets of most common values, merging the bound keys 458 // with the same value and keeping the top k of the merge result. 459 func mergeMcvs(mcvs1, mcvs2 []sql.Row, mcvCnts1, mcvCnts2 []uint64, cmp func(sql.Row, sql.Row) (int, error)) ([]sql.Row, []uint64, error) { 460 if len(mcvs1) < len(mcvs2) { 461 // mcvs2 is low 462 mcvs1, mcvs2 = mcvs2, mcvs1 463 mcvCnts1, mcvCnts2 = mcvCnts2, mcvCnts1 464 } 465 if len(mcvs2) == 0 { 466 return mcvs1, mcvCnts1, nil 467 } 468 469 ret := NewSqlHeap(len(mcvs2)) 470 seen := make(map[int]bool) 471 for i, row1 := range mcvs1 { 472 matched := -1 473 for j, row2 := range mcvs2 { 474 c, err := cmp(row1, row2) 475 if err != nil { 476 return nil, nil, err 477 } 478 if c == 0 { 479 matched = j 480 break 481 } 482 } 483 if matched > 0 { 484 seen[matched] = true 485 heap.Push(ret, NewHeapRow(mcvs1[i], int(mcvCnts1[i]+mcvCnts2[matched]))) 486 } else { 487 heap.Push(ret, NewHeapRow(mcvs1[i], int(mcvCnts1[i]))) 488 } 489 } 490 for j := range mcvs2 { 491 if !seen[j] { 492 heap.Push(ret, NewHeapRow(mcvs2[j], int(mcvCnts2[j]))) 493 494 } 495 } 496 return ret.Array(), ret.Counts(), nil 497 } 498 499 // mergeOverlappingBuckets folds bins with one element into the previous 500 // bucket when the bound keys match. 501 func mergeOverlappingBuckets(h sql.Histogram, types []sql.Type) (sql.Histogram, error) { 502 cmp := func(l, r sql.Row) (int, error) { 503 for i := 0; i < len(types); i++ { 504 cmp, err := types[i].Compare(l[i], r[i]) 505 if err != nil { 506 return 0, err 507 } 508 switch cmp { 509 case 0: 510 continue 511 case -1: 512 return -1, nil 513 case 1: 514 return 1, nil 515 } 516 } 517 return 0, nil 518 } 519 // |k| is the write position, |i| is the compare position 520 // |k| <= |i| 521 i := 0 522 k := 0 523 for i < len(h) { 524 h[k] = h[i] 525 i++ 526 if i >= len(h) { 527 k++ 528 break 529 } 530 mcvs, mcvCnts, err := mergeMcvs(h[i].Mcvs(), h[i-1].Mcvs(), h[i].McvCounts(), h[i-1].McvCounts(), cmp) 531 if err != nil { 532 return nil, err 533 } 534 for ; i < len(h) && h[i].DistinctCount() == 1; i++ { 535 eq, err := cmp(h[k].UpperBound(), h[i].UpperBound()) 536 if err != nil { 537 return nil, err 538 } 539 if eq != 0 { 540 break 541 } 542 h[k] = NewHistogramBucket( 543 h[k].RowCount()+h[i].RowCount(), 544 h[k].DistinctCount(), 545 h[k].NullCount()+h[i].NullCount(), 546 h[k].BoundCount()+h[i].BoundCount(), 547 h[k].UpperBound(), 548 mcvCnts, 549 mcvs) 550 } 551 k++ 552 } 553 return h[:k], nil 554 } 555 556 type sjState int8 557 558 const ( 559 sjStateUnknown = iota 560 sjStateInit 561 sjStateCmp 562 sjStateCutLeft 563 sjStateCutRight 564 sjStateCut 565 sjStateInc 566 sjStateExhaust 567 sjStateEOF 568 ) 569 570 // euclideanDistance is a vectorwise sum of squares distance between 571 // two numeric types. 572 func euclideanDistance(row1, row2 sql.Row, prefixLen int) (float64, error) { 573 var distSq float64 574 for i := 0; i < prefixLen; i++ { 575 v1, _, err := types.Float64.Convert(row1[i]) 576 if err != nil { 577 return 0, err 578 } 579 v2, _, err := types.Float64.Convert(row2[i]) 580 if err != nil { 581 return 0, err 582 } 583 f1 := v1.(float64) 584 f2 := v2.(float64) 585 distSq += f1*f1 - 2*f1*f2 + f2*f2 586 } 587 return math.Sqrt(distSq), nil 588 }