github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/agg_gen.go (about)

     1  package support
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"strings"
     9  
    10  	"gopkg.in/yaml.v3"
    11  )
    12  
    13  type AggDefs struct {
    14  	UnaryAggs []AggDef `yaml:"unaryAggs"`
    15  }
    16  
    17  type AggDef struct {
    18  	Name     string `yaml:"name"`
    19  	SqlName  string `yaml:"sqlName"`
    20  	Desc     string `yaml:"desc"`
    21  	RetType  string `yaml:"retType"` // must be valid sql.Type
    22  	Nullable bool   `yaml:"nullable"`
    23  }
    24  
    25  var _ GenDefs = ([]AggDef)(nil)
    26  
    27  type AggGen struct {
    28  	defines []AggDef
    29  	w       io.Writer
    30  }
    31  
    32  func DecodeUnaryAggDefs(path string) (AggDefs, error) {
    33  	contents, err := os.ReadFile(path)
    34  	if err != nil {
    35  		return AggDefs{}, err
    36  	}
    37  	dec := yaml.NewDecoder(bytes.NewReader(contents))
    38  	dec.KnownFields(true)
    39  	var res AggDefs
    40  	return res, dec.Decode(&res)
    41  }
    42  
    43  func (g *AggGen) Generate(defines GenDefs, w io.Writer) {
    44  	g.defines = defines.(AggDefs).UnaryAggs
    45  
    46  	g.w = w
    47  
    48  	fmt.Fprintf(g.w, "import (\n")
    49  	fmt.Fprintf(g.w, "    \"fmt\"\n")
    50  	fmt.Fprintf(g.w, "    \"github.com/dolthub/go-mysql-server/sql/types\"\n")
    51  	fmt.Fprintf(g.w, "    \"github.com/dolthub/go-mysql-server/sql\"\n")
    52  	fmt.Fprintf(g.w, "    \"github.com/dolthub/go-mysql-server/sql/expression\"\n")
    53  	fmt.Fprintf(g.w, "    \"github.com/dolthub/go-mysql-server/sql/transform\"\n")
    54  	fmt.Fprintf(g.w, ")\n\n")
    55  
    56  	for _, define := range g.defines {
    57  		g.genAggType(define)
    58  		g.genAggInterfaces(define)
    59  		g.genAggConstructor(define)
    60  		g.genAggPropAccessors(define)
    61  		g.genAggStringer(define)
    62  		g.genAggWithWindow(define)
    63  		g.genAggWithChildren(define)
    64  		g.genAggWithId(define)
    65  		g.genAggNewBuffer(define)
    66  		g.genAggWindowConstructor(define)
    67  	}
    68  }
    69  
    70  func (g *AggGen) genAggType(define AggDef) {
    71  	fmt.Fprintf(g.w, "type %s struct{\n", define.Name)
    72  	fmt.Fprintf(g.w, "    unaryAggBase\n")
    73  	fmt.Fprintf(g.w, "}\n\n")
    74  }
    75  
    76  func (g *AggGen) genAggInterfaces(define AggDef) {
    77  	fmt.Fprintf(g.w, "var _ sql.FunctionExpression = (*%s)(nil)\n", define.Name)
    78  	fmt.Fprintf(g.w, "var _ sql.Aggregation = (*%s)(nil)\n", define.Name)
    79  	fmt.Fprintf(g.w, "var _ sql.WindowAdaptableExpression = (*%s)(nil)\n", define.Name)
    80  	fmt.Fprintf(g.w, "\n")
    81  
    82  }
    83  
    84  func (g *AggGen) genAggConstructor(define AggDef) {
    85  	fmt.Fprintf(g.w, "func New%s(e sql.Expression) *%s {\n", define.Name, define.Name)
    86  	fmt.Fprintf(g.w, "    return &%s{\n", define.Name)
    87  	fmt.Fprintf(g.w, "        unaryAggBase{\n")
    88  	fmt.Fprintf(g.w, "            UnaryExpression: expression.UnaryExpression{Child: e},\n")
    89  	fmt.Fprintf(g.w, "            functionName: \"%s\",\n", define.Name)
    90  	fmt.Fprintf(g.w, "            description: \"%s\",\n", define.Desc)
    91  	fmt.Fprintf(g.w, "        },\n")
    92  	fmt.Fprintf(g.w, "    }\n")
    93  	fmt.Fprintf(g.w, "}\n\n")
    94  }
    95  
    96  func (g *AggGen) genAggPropAccessors(define AggDef) {
    97  	retType := "a.Child.Type()"
    98  	if define.RetType != "" {
    99  		retType = define.RetType
   100  	}
   101  	fmt.Fprintf(g.w, "func (a *%s) Type() sql.Type {\n", define.Name)
   102  	fmt.Fprintf(g.w, "    return %s\n", retType)
   103  	fmt.Fprintf(g.w, "}\n\n")
   104  
   105  	fmt.Fprintf(g.w, "func (a *%s) IsNullable() bool {\n", define.Name)
   106  	fmt.Fprintf(g.w, "    return %t\n", define.Nullable)
   107  	fmt.Fprintf(g.w, "}\n\n")
   108  }
   109  
   110  func (g *AggGen) genAggStringer(define AggDef) {
   111  	sqlName := define.Name
   112  	if define.SqlName != "" {
   113  		sqlName = define.SqlName
   114  	}
   115  	fmt.Fprintf(g.w, "func (a *%s) String() string {\n", define.Name)
   116  	fmt.Fprintf(g.w, "  if a.window != nil {\n")
   117  	fmt.Fprintf(g.w, "    pr := sql.NewTreePrinter()\n")
   118  	fmt.Fprintf(g.w, "    _ = pr.WriteNode(\"%s\")\n	", strings.ToUpper(sqlName))
   119  	fmt.Fprintf(g.w, "    children := []string{a.window.String(), a.Child.String()}\n")
   120  	fmt.Fprintf(g.w, "    pr.WriteChildren(children...)\n")
   121  	fmt.Fprintf(g.w, "    return pr.String()\n")
   122  	fmt.Fprintf(g.w, "  }\n")
   123  	fmt.Fprintf(g.w, "  return fmt.Sprintf(\"%s(%%s)\", a.Child)\n", strings.ToUpper(sqlName))
   124  	fmt.Fprintf(g.w, "}\n\n")
   125  
   126  	fmt.Fprintf(g.w, "func (a *%s) DebugString() string {\n", define.Name)
   127  	fmt.Fprintf(g.w, "  if a.window != nil {\n")
   128  	fmt.Fprintf(g.w, "    pr := sql.NewTreePrinter()\n")
   129  	fmt.Fprintf(g.w, "    _ = pr.WriteNode(\"%s\")\n	", strings.ToUpper(sqlName))
   130  	fmt.Fprintf(g.w, "    children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)}\n")
   131  	fmt.Fprintf(g.w, "    pr.WriteChildren(children...)\n")
   132  	fmt.Fprintf(g.w, "    return pr.String()\n")
   133  	fmt.Fprintf(g.w, "  }\n")
   134  	fmt.Fprintf(g.w, "  return fmt.Sprintf(\"%s(%%s)\", sql.DebugString(a.Child))\n", strings.ToUpper(sqlName))
   135  	fmt.Fprintf(g.w, "}\n\n")
   136  }
   137  
   138  func (g *AggGen) genAggWithChildren(define AggDef) {
   139  	fmt.Fprintf(g.w, "func (a *%s) WithChildren(children ...sql.Expression) (sql.Expression, error) {\n", define.Name)
   140  	fmt.Fprintf(g.w, "    res, err := a.unaryAggBase.WithChildren(children...)\n")
   141  	fmt.Fprintf(g.w, "    return &%s{unaryAggBase: *res.(*unaryAggBase)}, err\n", define.Name)
   142  	fmt.Fprintf(g.w, "}\n\n")
   143  }
   144  
   145  func (g *AggGen) genAggWithId(define AggDef) {
   146  	fmt.Fprintf(g.w, "func (a *%s) WithId(id sql.ColumnId) sql.IdExpression {\n", define.Name)
   147  	fmt.Fprintf(g.w, "    res := a.unaryAggBase.WithId(id)\n")
   148  	fmt.Fprintf(g.w, "    return &%s{unaryAggBase: *res.(*unaryAggBase)}\n", define.Name)
   149  	fmt.Fprintf(g.w, "}\n\n")
   150  }
   151  
   152  func (g *AggGen) genAggWithWindow(define AggDef) {
   153  	fmt.Fprintf(g.w, "func (a *%s) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression {\n", define.Name)
   154  	fmt.Fprintf(g.w, "    res := a.unaryAggBase.WithWindow(window)\n")
   155  	fmt.Fprintf(g.w, "    return &%s{unaryAggBase: *res.(*unaryAggBase)}\n", define.Name)
   156  	fmt.Fprintf(g.w, "}\n\n")
   157  }
   158  
   159  func (g *AggGen) genAggWindowConstructor(define AggDef) {
   160  	fmt.Fprintf(g.w, "func (a *%s) NewWindowFunction() (sql.WindowFunction, error) {\n", define.Name)
   161  	fmt.Fprintf(g.w, "    child, err := transform.Clone(a.Child)\n")
   162  	fmt.Fprintf(g.w, "    if err != nil {\n")
   163  	fmt.Fprintf(g.w, "        return nil, err\n")
   164  	fmt.Fprintf(g.w, "    }\n")
   165  	fmt.Fprintf(g.w, "    return New%sAgg(child).WithWindow(a.Window())\n", define.Name)
   166  	fmt.Fprintf(g.w, "}\n\n")
   167  }
   168  
   169  func (g *AggGen) genAggNewBuffer(define AggDef) {
   170  	fmt.Fprintf(g.w, "func (a *%s) NewBuffer() (sql.AggregationBuffer, error) {\n", define.Name)
   171  	fmt.Fprintf(g.w, "    child, err := transform.Clone(a.Child)\n")
   172  	fmt.Fprintf(g.w, "    if err != nil {\n")
   173  	fmt.Fprintf(g.w, "        return nil, err\n")
   174  	fmt.Fprintf(g.w, "    }\n")
   175  	fmt.Fprintf(g.w, "    return New%sBuffer(child), nil\n", define.Name)
   176  	fmt.Fprintf(g.w, "}\n\n")
   177  }