github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/agg_gen.go (about) 1 package support 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "os" 8 "strings" 9 10 "gopkg.in/yaml.v3" 11 ) 12 13 type AggDefs struct { 14 UnaryAggs []AggDef `yaml:"unaryAggs"` 15 } 16 17 type AggDef struct { 18 Name string `yaml:"name"` 19 SqlName string `yaml:"sqlName"` 20 Desc string `yaml:"desc"` 21 RetType string `yaml:"retType"` // must be valid sql.Type 22 Nullable bool `yaml:"nullable"` 23 } 24 25 var _ GenDefs = ([]AggDef)(nil) 26 27 type AggGen struct { 28 defines []AggDef 29 w io.Writer 30 } 31 32 func DecodeUnaryAggDefs(path string) (AggDefs, error) { 33 contents, err := os.ReadFile(path) 34 if err != nil { 35 return AggDefs{}, err 36 } 37 dec := yaml.NewDecoder(bytes.NewReader(contents)) 38 dec.KnownFields(true) 39 var res AggDefs 40 return res, dec.Decode(&res) 41 } 42 43 func (g *AggGen) Generate(defines GenDefs, w io.Writer) { 44 g.defines = defines.(AggDefs).UnaryAggs 45 46 g.w = w 47 48 fmt.Fprintf(g.w, "import (\n") 49 fmt.Fprintf(g.w, " \"fmt\"\n") 50 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/types\"\n") 51 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql\"\n") 52 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/expression\"\n") 53 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/transform\"\n") 54 fmt.Fprintf(g.w, ")\n\n") 55 56 for _, define := range g.defines { 57 g.genAggType(define) 58 g.genAggInterfaces(define) 59 g.genAggConstructor(define) 60 g.genAggPropAccessors(define) 61 g.genAggStringer(define) 62 g.genAggWithWindow(define) 63 g.genAggWithChildren(define) 64 g.genAggWithId(define) 65 g.genAggNewBuffer(define) 66 g.genAggWindowConstructor(define) 67 } 68 } 69 70 func (g *AggGen) genAggType(define AggDef) { 71 fmt.Fprintf(g.w, "type %s struct{\n", define.Name) 72 fmt.Fprintf(g.w, " unaryAggBase\n") 73 fmt.Fprintf(g.w, "}\n\n") 74 } 75 76 func (g *AggGen) genAggInterfaces(define AggDef) { 77 fmt.Fprintf(g.w, "var _ sql.FunctionExpression = (*%s)(nil)\n", define.Name) 78 fmt.Fprintf(g.w, "var _ sql.Aggregation = (*%s)(nil)\n", define.Name) 79 fmt.Fprintf(g.w, "var _ sql.WindowAdaptableExpression = (*%s)(nil)\n", define.Name) 80 fmt.Fprintf(g.w, "\n") 81 82 } 83 84 func (g *AggGen) genAggConstructor(define AggDef) { 85 fmt.Fprintf(g.w, "func New%s(e sql.Expression) *%s {\n", define.Name, define.Name) 86 fmt.Fprintf(g.w, " return &%s{\n", define.Name) 87 fmt.Fprintf(g.w, " unaryAggBase{\n") 88 fmt.Fprintf(g.w, " UnaryExpression: expression.UnaryExpression{Child: e},\n") 89 fmt.Fprintf(g.w, " functionName: \"%s\",\n", define.Name) 90 fmt.Fprintf(g.w, " description: \"%s\",\n", define.Desc) 91 fmt.Fprintf(g.w, " },\n") 92 fmt.Fprintf(g.w, " }\n") 93 fmt.Fprintf(g.w, "}\n\n") 94 } 95 96 func (g *AggGen) genAggPropAccessors(define AggDef) { 97 retType := "a.Child.Type()" 98 if define.RetType != "" { 99 retType = define.RetType 100 } 101 fmt.Fprintf(g.w, "func (a *%s) Type() sql.Type {\n", define.Name) 102 fmt.Fprintf(g.w, " return %s\n", retType) 103 fmt.Fprintf(g.w, "}\n\n") 104 105 fmt.Fprintf(g.w, "func (a *%s) IsNullable() bool {\n", define.Name) 106 fmt.Fprintf(g.w, " return %t\n", define.Nullable) 107 fmt.Fprintf(g.w, "}\n\n") 108 } 109 110 func (g *AggGen) genAggStringer(define AggDef) { 111 sqlName := define.Name 112 if define.SqlName != "" { 113 sqlName = define.SqlName 114 } 115 fmt.Fprintf(g.w, "func (a *%s) String() string {\n", define.Name) 116 fmt.Fprintf(g.w, " if a.window != nil {\n") 117 fmt.Fprintf(g.w, " pr := sql.NewTreePrinter()\n") 118 fmt.Fprintf(g.w, " _ = pr.WriteNode(\"%s\")\n ", strings.ToUpper(sqlName)) 119 fmt.Fprintf(g.w, " children := []string{a.window.String(), a.Child.String()}\n") 120 fmt.Fprintf(g.w, " pr.WriteChildren(children...)\n") 121 fmt.Fprintf(g.w, " return pr.String()\n") 122 fmt.Fprintf(g.w, " }\n") 123 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s(%%s)\", a.Child)\n", strings.ToUpper(sqlName)) 124 fmt.Fprintf(g.w, "}\n\n") 125 126 fmt.Fprintf(g.w, "func (a *%s) DebugString() string {\n", define.Name) 127 fmt.Fprintf(g.w, " if a.window != nil {\n") 128 fmt.Fprintf(g.w, " pr := sql.NewTreePrinter()\n") 129 fmt.Fprintf(g.w, " _ = pr.WriteNode(\"%s\")\n ", strings.ToUpper(sqlName)) 130 fmt.Fprintf(g.w, " children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)}\n") 131 fmt.Fprintf(g.w, " pr.WriteChildren(children...)\n") 132 fmt.Fprintf(g.w, " return pr.String()\n") 133 fmt.Fprintf(g.w, " }\n") 134 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s(%%s)\", sql.DebugString(a.Child))\n", strings.ToUpper(sqlName)) 135 fmt.Fprintf(g.w, "}\n\n") 136 } 137 138 func (g *AggGen) genAggWithChildren(define AggDef) { 139 fmt.Fprintf(g.w, "func (a *%s) WithChildren(children ...sql.Expression) (sql.Expression, error) {\n", define.Name) 140 fmt.Fprintf(g.w, " res, err := a.unaryAggBase.WithChildren(children...)\n") 141 fmt.Fprintf(g.w, " return &%s{unaryAggBase: *res.(*unaryAggBase)}, err\n", define.Name) 142 fmt.Fprintf(g.w, "}\n\n") 143 } 144 145 func (g *AggGen) genAggWithId(define AggDef) { 146 fmt.Fprintf(g.w, "func (a *%s) WithId(id sql.ColumnId) sql.IdExpression {\n", define.Name) 147 fmt.Fprintf(g.w, " res := a.unaryAggBase.WithId(id)\n") 148 fmt.Fprintf(g.w, " return &%s{unaryAggBase: *res.(*unaryAggBase)}\n", define.Name) 149 fmt.Fprintf(g.w, "}\n\n") 150 } 151 152 func (g *AggGen) genAggWithWindow(define AggDef) { 153 fmt.Fprintf(g.w, "func (a *%s) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression {\n", define.Name) 154 fmt.Fprintf(g.w, " res := a.unaryAggBase.WithWindow(window)\n") 155 fmt.Fprintf(g.w, " return &%s{unaryAggBase: *res.(*unaryAggBase)}\n", define.Name) 156 fmt.Fprintf(g.w, "}\n\n") 157 } 158 159 func (g *AggGen) genAggWindowConstructor(define AggDef) { 160 fmt.Fprintf(g.w, "func (a *%s) NewWindowFunction() (sql.WindowFunction, error) {\n", define.Name) 161 fmt.Fprintf(g.w, " child, err := transform.Clone(a.Child)\n") 162 fmt.Fprintf(g.w, " if err != nil {\n") 163 fmt.Fprintf(g.w, " return nil, err\n") 164 fmt.Fprintf(g.w, " }\n") 165 fmt.Fprintf(g.w, " return New%sAgg(child).WithWindow(a.Window())\n", define.Name) 166 fmt.Fprintf(g.w, "}\n\n") 167 } 168 169 func (g *AggGen) genAggNewBuffer(define AggDef) { 170 fmt.Fprintf(g.w, "func (a *%s) NewBuffer() (sql.AggregationBuffer, error) {\n", define.Name) 171 fmt.Fprintf(g.w, " child, err := transform.Clone(a.Child)\n") 172 fmt.Fprintf(g.w, " if err != nil {\n") 173 fmt.Fprintf(g.w, " return nil, err\n") 174 fmt.Fprintf(g.w, " }\n") 175 fmt.Fprintf(g.w, " return New%sBuffer(child), nil\n", define.Name) 176 fmt.Fprintf(g.w, "}\n\n") 177 }