github.com/dolthub/go-mysql-server@v0.18.0/sql/stats/distributions.go (about) 1 package stats 2 3 import ( 4 "io" 5 "math" 6 "math/rand" 7 8 "github.com/dolthub/go-mysql-server/sql" 9 ) 10 11 func NewNormDistIter(colCnt, rowCnt int, mean, std float64) sql.RowIter { 12 return &normDistIter{cols: colCnt, cnt: rowCnt, std: std, mean: mean} 13 } 14 15 func NewExpDistIter(colCnt, rowCnt int, lambda float64) sql.RowIter { 16 return &expDistIter{cols: colCnt, cnt: rowCnt, lambda: lambda} 17 } 18 19 type normDistIter struct { 20 i int 21 cols int 22 cnt int 23 std, mean float64 24 } 25 26 var _ sql.RowIter = (*normDistIter)(nil) 27 28 func (d *normDistIter) Next(*sql.Context) (sql.Row, error) { 29 if d.i > d.cnt { 30 return nil, io.EOF 31 } 32 d.i++ 33 var ret sql.Row 34 ret = append(ret, d.i) 35 for i := 0; i < d.cols; i++ { 36 val := rand.NormFloat64()*d.std + d.mean 37 if math.IsNaN(val) || math.IsInf(val, 0) { 38 val = math.MaxInt 39 } 40 ret = append(ret, val) 41 } 42 return ret, nil 43 } 44 45 func (d *normDistIter) Close(*sql.Context) error { 46 return nil 47 } 48 49 type expDistIter struct { 50 i int 51 cols int 52 cnt int 53 lambda float64 54 } 55 56 var _ sql.RowIter = (*expDistIter)(nil) 57 58 func (d *expDistIter) Next(*sql.Context) (sql.Row, error) { 59 if d.i > d.cnt { 60 return nil, io.EOF 61 } 62 d.i++ 63 var ret sql.Row 64 ret = append(ret, d.i) 65 for i := 0; i < d.cols; i++ { 66 val := -math.Log2(rand.NormFloat64()) / d.lambda 67 if math.IsNaN(val) || math.IsInf(val, 0) { 68 val = math.MaxInt32 69 } 70 ret = append(ret, val) 71 } 72 return ret, nil 73 } 74 75 func (d *expDistIter) Close(*sql.Context) error { 76 return nil 77 }