github.com/dolthub/go-mysql-server@v0.18.0/memory/stats.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memory
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"math"
    22  	"math/rand"
    23  	"sort"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql/stats"
    28  
    29  	"github.com/dolthub/go-mysql-server/sql"
    30  )
    31  
    32  func NewStatsProv() *StatsProv {
    33  	return &StatsProv{
    34  		colStats: make(map[statsKey]sql.Statistic),
    35  	}
    36  }
    37  
// statsKey is the string key under which a statistic is stored in StatsProv.
// Keys built in this file have the form "<qualifier>.(<col1>,<col2>,...)".
type statsKey string

// StatsProv is an in-memory sql.StatsProvider that keeps statistics in a map.
type StatsProv struct {
	// colStats maps a statsKey to the statistic stored for it.
	colStats map[statsKey]sql.Statistic
}

// Compile-time check that *StatsProv implements sql.StatsProvider.
var _ sql.StatsProvider = (*StatsProv)(nil)
    45  
    46  func (s *StatsProv) RefreshTableStats(ctx *sql.Context, table sql.Table, db string) error {
    47  	// non-Dolt would sample the table to get estimate of unique and histogram
    48  	iat, ok := table.(sql.IndexAddressableTable)
    49  	if !ok {
    50  		return nil
    51  	}
    52  	indexes, err := iat.GetIndexes(ctx)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	ordinals := make(map[string]int)
    58  	for i, c := range table.Schema() {
    59  		ordinals[strings.ToLower(c.Name)] = i
    60  	}
    61  
    62  	newStats := make(map[statsKey][]int)
    63  	tablePrefix := fmt.Sprintf("%s.", strings.ToLower(table.Name()))
    64  	for _, idx := range indexes {
    65  		cols := make([]string, len(idx.Expressions()))
    66  		for i, c := range idx.Expressions() {
    67  			cols[i] = strings.TrimPrefix(strings.ToLower(c), tablePrefix)
    68  		}
    69  		for i := 1; i < len(cols)+1; i++ {
    70  			pref := cols[:i]
    71  			key := statsKey(fmt.Sprintf("%s.%s.%s.(%s)", strings.ToLower(db), strings.ToLower(idx.Table()), strings.ToLower(idx.ID()), strings.Join(pref, ",")))
    72  			if _, ok := newStats[key]; !ok {
    73  				ords := make([]int, len(pref))
    74  				for i, c := range pref {
    75  					ords[i] = ordinals[c]
    76  				}
    77  				newStats[key] = ords
    78  			}
    79  		}
    80  	}
    81  	return s.estimateStats(ctx, table, newStats)
    82  }
    83  
    84  func (s *StatsProv) estimateStats(ctx *sql.Context, table sql.Table, keys map[statsKey][]int) error {
    85  	sample, err := s.reservoirSample(ctx, table)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	var dataLen uint64
    91  	var rowCount uint64
    92  	if statsTab, ok := table.(sql.StatisticsTable); ok {
    93  		rowCount, _, err = statsTab.RowCount(ctx)
    94  		if err != nil {
    95  			return err
    96  		}
    97  		dataLen, err = statsTab.DataLength(ctx)
    98  		if err != nil {
    99  			return err
   100  		}
   101  	}
   102  
   103  	indexes := make(map[string]sql.Index)
   104  	if iat, ok := table.(sql.IndexAddressableTable); ok {
   105  		idxs, err := iat.GetIndexes(ctx)
   106  		if err != nil {
   107  			return err
   108  		}
   109  		for _, idx := range idxs {
   110  			indexes[strings.ToLower(idx.ID())] = idx
   111  		}
   112  	}
   113  
   114  	sch := table.Schema()
   115  	for key, ordinals := range keys {
   116  		keyVals := make([]sql.Row, len(sample))
   117  		for i, row := range sample {
   118  			for _, ord := range ordinals {
   119  				keyVals[i] = append(keyVals[i], row[ord])
   120  			}
   121  		}
   122  		sort.Slice(keyVals, func(i, j int) bool {
   123  			k := 0
   124  			for k < len(ordinals) && keyVals[i][k] == keyVals[j][k] {
   125  				k++
   126  			}
   127  			if k == len(ordinals) {
   128  				return true
   129  			}
   130  			col := sch[ordinals[k]]
   131  			cmp, _ := col.Type.Compare(keyVals[i][k], keyVals[j][k])
   132  			return cmp <= 0
   133  		})
   134  
   135  		// quick and dirty histogram buckets
   136  		bucketCnt := 20
   137  		if len(keyVals) < bucketCnt {
   138  			bucketCnt = len(keyVals)
   139  		}
   140  		offset := len(keyVals) / bucketCnt
   141  		perBucket := int(rowCount) / bucketCnt
   142  		buckets := make([]*stats.Bucket, bucketCnt)
   143  		for i := range buckets {
   144  			var upperBound []interface{}
   145  			for _, v := range keyVals[i*offset] {
   146  				upperBound = append(upperBound, v)
   147  			}
   148  			buckets[i] = stats.NewHistogramBucket(uint64(perBucket), uint64(perBucket), 0, 1, upperBound, nil, nil)
   149  		}
   150  
   151  		// columns and types
   152  		var cols []string
   153  		var types []sql.Type
   154  		for _, i := range ordinals {
   155  			cols = append(cols, sch[i].Name)
   156  			types = append(types, sch[i].Type)
   157  		}
   158  
   159  		qual, err := sql.NewQualifierFromString(string(key))
   160  		if err != nil {
   161  			return err
   162  		}
   163  
   164  		stat := stats.NewStatistic(rowCount, rowCount, 0, dataLen, time.Now(), qual, cols, types, buckets, sql.IndexClassDefault, nil)
   165  
   166  		// functional dependencies
   167  		fds, idxCols, err := stats.IndexFds(table.Name(), sch, indexes[strings.ToLower(qual.Index())])
   168  		if err != nil {
   169  			return err
   170  		}
   171  		ret := stat.WithFuncDeps(fds)
   172  		ret = ret.WithColSet(idxCols)
   173  		s.colStats[key] = ret
   174  	}
   175  	return nil
   176  }
   177  
   178  // reservoirSample selects a random subset of values from the table.
   179  // Algorithm L from: https://dl.acm.org/doi/pdf/10.1145/198429.198435
   180  func (s *StatsProv) reservoirSample(ctx *sql.Context, table sql.Table) ([]sql.Row, error) {
   181  	// read through table
   182  	var maxQueue float64 = 4000
   183  	var queue []sql.Row
   184  	partIter, err := table.Partitions(ctx)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  
   189  	updateW := func(w float64) float64 {
   190  		return w * math.Exp(math.Log(rand.Float64())/maxQueue)
   191  	}
   192  	updateI := func(i int, w float64) int {
   193  		return i + int(math.Floor(math.Log(rand.Float64())/math.Log(1-w))) + 1
   194  	}
   195  	w := updateW(1)
   196  	i := updateI(0, w)
   197  	j := 0
   198  	for {
   199  		part, err := partIter.Next(ctx)
   200  		if errors.Is(err, io.EOF) {
   201  			break
   202  		} else if err != nil {
   203  			return nil, err
   204  		}
   205  		rowIter, err := table.PartitionRows(ctx, part)
   206  		if err != nil {
   207  			return nil, err
   208  		}
   209  		for {
   210  			row, err := rowIter.Next(ctx)
   211  			if errors.Is(err, io.EOF) {
   212  				break
   213  			} else if err != nil {
   214  				return nil, err
   215  			}
   216  			if len(queue) < int(maxQueue) {
   217  				queue = append(queue, row)
   218  				j++
   219  				continue
   220  			}
   221  
   222  			if j == i {
   223  				// random swap
   224  				pos := rand.Intn(int(maxQueue))
   225  				queue[pos] = row
   226  				// update
   227  				w = updateW(w)
   228  				i = updateI(i, w)
   229  			}
   230  			j++
   231  		}
   232  	}
   233  	return queue, nil
   234  }
   235  
   236  func (s *StatsProv) GetTableStats(ctx *sql.Context, db, table string) ([]sql.Statistic, error) {
   237  	pref := fmt.Sprintf("%s.%s", strings.ToLower(db), strings.ToLower(table))
   238  	var ret []sql.Statistic
   239  	for key, stats := range s.colStats {
   240  		if strings.HasPrefix(string(key), pref) {
   241  			ret = append(ret, stats)
   242  		}
   243  	}
   244  	return ret, nil
   245  }
   246  
   247  func (s *StatsProv) SetStats(ctx *sql.Context, stats sql.Statistic) error {
   248  	key := statsKey(fmt.Sprintf("%s.(%s)", stats.Qualifier(), strings.Join(stats.Columns(), ",")))
   249  	s.colStats[key] = stats
   250  	return nil
   251  }
   252  
   253  func (s *StatsProv) GetStats(ctx *sql.Context, qual sql.StatQualifier, cols []string) (sql.Statistic, bool) {
   254  	key := statsKey(fmt.Sprintf("%s.(%s)", qual, strings.Join(cols, ",")))
   255  	if stats, ok := s.colStats[key]; ok {
   256  		return stats, false
   257  	}
   258  	return nil, false
   259  }
   260  
   261  func (s *StatsProv) DropStats(ctx *sql.Context, qual sql.StatQualifier, cols []string) error {
   262  	colsSuff := strings.Join(cols, ",") + ")"
   263  	for key, _ := range s.colStats {
   264  		if strings.HasPrefix(string(key), qual.String()) && strings.HasSuffix(string(key), colsSuff) {
   265  			delete(s.colStats, key)
   266  		}
   267  	}
   268  	return nil
   269  }
   270  
   271  func (s *StatsProv) RowCount(ctx *sql.Context, db, table string) (uint64, error) {
   272  	pref := fmt.Sprintf("%s.%s", strings.ToLower(db), strings.ToLower(table))
   273  	var cnt uint64
   274  	for key, stats := range s.colStats {
   275  		if strings.HasPrefix(string(key), pref) {
   276  			if stats.RowCount() > cnt {
   277  				cnt = stats.RowCount()
   278  			}
   279  		}
   280  	}
   281  	return cnt, nil
   282  }
   283  
   284  func (s *StatsProv) DataLength(ctx *sql.Context, db, table string) (uint64, error) {
   285  	pref := fmt.Sprintf("%s.%s", db, table)
   286  	var size uint64
   287  	for key, stats := range s.colStats {
   288  		if strings.HasPrefix(string(key), pref) {
   289  			if stats.AvgSize() > size {
   290  				size = stats.AvgSize()
   291  			}
   292  		}
   293  	}
   294  	return size, nil
   295  }
   296  
   297  func (s *StatsProv) DropDbStats(ctx *sql.Context, db string, flush bool) error {
   298  	for key, _ := range s.colStats {
   299  		if strings.HasPrefix(string(key), db) {
   300  			delete(s.colStats, key)
   301  		}
   302  	}
   303  	return nil
   304  }