github.com/dolthub/go-mysql-server@v0.18.0/memory/exponential_dist_table.go (about)

     1  package memory
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/dolthub/go-mysql-server/sql"
     7  	"github.com/dolthub/go-mysql-server/sql/expression"
     8  	"github.com/dolthub/go-mysql-server/sql/stats"
     9  	"github.com/dolthub/go-mysql-server/sql/types"
    10  )
    11  
    12  var _ sql.TableFunction = ExponentialDistTable{}
    13  var _ sql.CollationCoercible = ExponentialDistTable{}
    14  var _ sql.ExecSourceRel = ExponentialDistTable{}
    15  var _ sql.TableNode = ExponentialDistTable{}
    16  
    17  // ExponentialDistTable a simple table function that returns samples
    18  // from a parameterized exponential distribution.
    19  type ExponentialDistTable struct {
    20  	db     sql.Database
    21  	name   string
    22  	rowCnt int
    23  	colCnt int
    24  	lambda float64
    25  }
    26  
    27  func (s ExponentialDistTable) UnderlyingTable() sql.Table {
    28  	return s
    29  }
    30  
    31  func (s ExponentialDistTable) NewInstance(_ *sql.Context, db sql.Database, args []sql.Expression) (sql.Node, error) {
    32  	if len(args) != 3 {
    33  		return nil, fmt.Errorf("exponential_dist table expects 2 arguments: (cols, rows, lambda)")
    34  	}
    35  	colCntLit, ok := args[0].(*expression.Literal)
    36  	if !ok {
    37  		return nil, fmt.Errorf("normal_dist table expects arguments to be literal expressions")
    38  	}
    39  	colCnt, inBounds, _ := types.Int64.Convert(colCntLit.Value())
    40  	if !inBounds {
    41  		return nil, fmt.Errorf("normal_dist table expects 1st argument to be column count")
    42  	}
    43  	rowCntLit, ok := args[1].(*expression.Literal)
    44  	if !ok {
    45  		return nil, fmt.Errorf("normal_dist table expects arguments to be literal expressions")
    46  	}
    47  	rowCnt, inBounds, _ := types.Int64.Convert(rowCntLit.Value())
    48  	if !inBounds {
    49  		return nil, fmt.Errorf("normal_dist table expects 2nd argument to be row count")
    50  	}
    51  	lambdaLit, ok := args[2].(*expression.Literal)
    52  	if !ok {
    53  		return nil, fmt.Errorf("exponential_dist table expects arguments to be literal expressions")
    54  	}
    55  	lambda, inBounds, _ := types.Float64.Convert(lambdaLit.Value())
    56  	if !inBounds {
    57  		return nil, fmt.Errorf("exponential_dist table expects 3rd argument to be row count")
    58  	}
    59  	return ExponentialDistTable{db: db, colCnt: int(colCnt.(int64)), rowCnt: int(rowCnt.(int64)), lambda: lambda.(float64)}, nil
    60  }
    61  
    62  func (s ExponentialDistTable) Resolved() bool {
    63  	return true
    64  }
    65  
    66  func (s ExponentialDistTable) IsReadOnly() bool {
    67  	return true
    68  }
    69  
    70  func (s ExponentialDistTable) String() string {
    71  	return "normal_dist"
    72  }
    73  
    74  func (s ExponentialDistTable) DebugString() string {
    75  	pr := sql.NewTreePrinter()
    76  	_ = pr.WriteNode("normal_dist")
    77  	children := []string{
    78  		fmt.Sprintf("columns: %d", s.colCnt),
    79  		fmt.Sprintf("rows: %d", s.rowCnt),
    80  		fmt.Sprintf("lambda: %f.2", s.lambda),
    81  	}
    82  	_ = pr.WriteChildren(children...)
    83  	return pr.String()
    84  }
    85  
    86  func (s ExponentialDistTable) Schema() sql.Schema {
    87  	var sch sql.Schema
    88  	for i := 0; i < s.colCnt+1; i++ {
    89  		sch = append(sch, &sql.Column{
    90  			DatabaseSource: s.db.Name(),
    91  			Source:         s.Name(),
    92  			Name:           fmt.Sprintf("col%d", i),
    93  			Type:           types.Int64,
    94  		})
    95  	}
    96  	return sch
    97  }
    98  
    99  func (s ExponentialDistTable) Children() []sql.Node {
   100  	return []sql.Node{}
   101  }
   102  
   103  func (s ExponentialDistTable) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIter, error) {
   104  	return stats.NewExpDistIter(s.colCnt, s.rowCnt, s.lambda), nil
   105  }
   106  
   107  func (s ExponentialDistTable) WithChildren(_ ...sql.Node) (sql.Node, error) {
   108  	return s, nil
   109  }
   110  
   111  func (s ExponentialDistTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
   112  	return true
   113  }
   114  
   115  // CollationCoercibility implements the interface sql.CollationCoercible.
   116  func (ExponentialDistTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   117  	return sql.Collation_binary, 5
   118  }
   119  
   120  // Collation implements the sql.Table interface.
   121  func (ExponentialDistTable) Collation() sql.CollationID {
   122  	return sql.Collation_Default
   123  }
   124  
   125  func (s ExponentialDistTable) Expressions() []sql.Expression {
   126  	return []sql.Expression{}
   127  }
   128  
   129  func (s ExponentialDistTable) WithExpressions(e ...sql.Expression) (sql.Node, error) {
   130  	return s, nil
   131  }
   132  
   133  func (s ExponentialDistTable) Database() sql.Database {
   134  	return s.db
   135  }
   136  
   137  func (s ExponentialDistTable) WithDatabase(_ sql.Database) (sql.Node, error) {
   138  	return s, nil
   139  }
   140  
   141  func (s ExponentialDistTable) Name() string {
   142  	return "exponential_dist"
   143  }
   144  
   145  func (s ExponentialDistTable) Description() string {
   146  	return "exponential distribution"
   147  }
   148  
// Compile-time check that *SequenceTableFnRowIter (declared elsewhere in this
// package, not in this file) implements sql.RowIter.
// NOTE(review): this appears unrelated to ExponentialDistTable and looks
// copied from the sequence table function; consider moving it next to that
// type.
var _ sql.RowIter = (*SequenceTableFnRowIter)(nil)
   150  
   151  // Partitions is a sql.Table interface function that returns a partition of the data. This data has a single partition.
   152  func (s ExponentialDistTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
   153  	return sql.PartitionsToPartitionIter(&sequencePartition{min: 0, max: int64(s.rowCnt) - 1}), nil
   154  }
   155  
   156  // PartitionRows is a sql.Table interface function that takes a partition and returns all rows in that partition.
   157  // This table has a partition for just schema changes, one for just data changes, and one for both.
   158  func (s ExponentialDistTable) PartitionRows(ctx *sql.Context, _ sql.Partition) (sql.RowIter, error) {
   159  	return stats.NewExpDistIter(s.colCnt, s.rowCnt, s.lambda), nil
   160  }