github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/memo_gen.go (about) 1 package support 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "os" 8 "strings" 9 10 "gopkg.in/yaml.v3" 11 ) 12 13 //go:generate go run ../optgen/main.go -out ../../../sql/memo/memo.og.go -pkg memo memo 14 15 type MemoExprs struct { 16 Exprs []ExprDef `yaml:"exprs"` 17 } 18 19 type ExprDef struct { 20 Name string `yaml:"name"` 21 SourceType string `yaml:"sourceType"` 22 Join bool `yaml:"join"` 23 Attrs [][2]string `yaml:"attrs"` 24 Unary bool `yaml:"unary"` 25 SkipExec bool `yaml:"skipExec"` 26 Binary bool `yaml:"binary"` 27 SkipName bool `yaml:"skipName"` 28 SkipTableId bool `yaml:"skipTableId"` 29 } 30 31 func DecodeMemoExprs(path string) (MemoExprs, error) { 32 contents, err := os.ReadFile(path) 33 if err != nil { 34 return MemoExprs{}, err 35 } 36 dec := yaml.NewDecoder(bytes.NewReader(contents)) 37 dec.KnownFields(true) 38 var res MemoExprs 39 return res, dec.Decode(&res) 40 } 41 42 var _ GenDefs = (*MemoExprs)(nil) 43 44 type MemoGen struct { 45 defines []ExprDef 46 w io.Writer 47 } 48 49 func (g *MemoGen) Generate(defines GenDefs, w io.Writer) { 50 g.defines = defines.(MemoExprs).Exprs 51 52 g.w = w 53 54 g.genImport() 55 for _, define := range g.defines { 56 g.genType(define) 57 g.genRelInterfaces(define) 58 59 g.genStringer(define) 60 if define.SourceType != "" { 61 g.genSourceRelInterface(define) 62 } 63 if define.Join { 64 g.genJoinRelInterface(define) 65 } else if define.Binary { 66 g.genBinaryGroupInterface(define) 67 } else if define.Unary { 68 g.genUnaryGroupInterface(define) 69 } else { 70 g.genChildlessGroupInterface(define) 71 } 72 } 73 g.genFormatters(g.defines) 74 75 } 76 77 func (g *MemoGen) genImport() { 78 fmt.Fprintf(g.w, "import (\n") 79 fmt.Fprintf(g.w, " \"fmt\"\n") 80 fmt.Fprintf(g.w, " \"strings\"\n") 81 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql\"\n") 82 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/plan\"\n") 83 fmt.Fprintf(g.w, ")\n\n") 84 } 85 86 func (g *MemoGen) genType(define ExprDef) { 87 fmt.Fprintf(g.w, "type %s struct {\n", strings.Title(define.Name)) 88 if define.SourceType != "" { 89 fmt.Fprintf(g.w, " *sourceBase\n") 90 fmt.Fprintf(g.w, " Table %s\n", define.SourceType) 91 } else if define.Join { 92 fmt.Fprintf(g.w, " *JoinBase\n") 93 } else if define.Unary { 94 fmt.Fprintf(g.w, " *relBase\n") 95 fmt.Fprintf(g.w, " Child *ExprGroup\n") 96 } else if define.Binary { 97 fmt.Fprintf(g.w, " *relBase\n") 98 fmt.Fprintf(g.w, " Left *ExprGroup\n") 99 fmt.Fprintf(g.w, " Right *ExprGroup\n") 100 } 101 for _, attr := range define.Attrs { 102 fmt.Fprintf(g.w, " %s %s\n", strings.Title(attr[0]), attr[1]) 103 } 104 105 fmt.Fprintf(g.w, "}\n\n") 106 } 107 108 func (g *MemoGen) genRelInterfaces(define ExprDef) { 109 fmt.Fprintf(g.w, "var _ RelExpr = (*%s)(nil)\n", define.Name) 110 if define.SourceType != "" { 111 fmt.Fprintf(g.w, "var _ SourceRel = (*%s)(nil)\n", define.Name) 112 } else if define.Join { 113 fmt.Fprintf(g.w, "var _ JoinRel = (*%s)(nil)\n", define.Name) 114 } else if define.Unary || define.Binary { 115 } else { 116 panic("unreachable") 117 } 118 fmt.Fprintf(g.w, "\n") 119 } 120 121 func (g *MemoGen) genScalarInterfaces(define ExprDef) { 122 fmt.Fprintf(g.w, "var _ ScalarExpr = (*%s)(nil)\n", define.Name) 123 124 fmt.Fprintf(g.w, "\n") 125 126 fmt.Fprintf(g.w, "func (r *%s) ExprId() ScalarExprId {\n", define.Name) 127 fmt.Fprintf(g.w, " return ScalarExpr%s\n", strings.Title(define.Name)) 128 fmt.Fprintf(g.w, "}\n\n") 129 } 130 131 func (g *MemoGen) genStringer(define ExprDef) { 132 fmt.Fprintf(g.w, "func (r *%s) String() string {\n", define.Name) 133 fmt.Fprintf(g.w, " return FormatExpr(r)\n") 134 fmt.Fprintf(g.w, "}\n\n") 135 } 136 137 func (g *MemoGen) genSourceRelInterface(define ExprDef) { 138 fmt.Fprintf(g.w, "func (r *%s) Name() string {\n", define.Name) 139 if !define.SkipName { 140 fmt.Fprintf(g.w, " return strings.ToLower(r.Table.Name())\n") 141 } else { 142 fmt.Fprintf(g.w, " return \"\"\n") 143 } 144 fmt.Fprintf(g.w, "}\n\n") 145 146 fmt.Fprintf(g.w, "func (r *%s) TableId() sql.TableId {\n", define.Name) 147 fmt.Fprintf(g.w, " return TableIdForSource(r.g.Id)\n") 148 fmt.Fprintf(g.w, "}\n\n") 149 150 fmt.Fprintf(g.w, "func (r *%s) TableIdNode() plan.TableIdNode {\n", define.Name) 151 if define.SkipTableId { 152 fmt.Fprintf(g.w, " return nil\n") 153 } else { 154 fmt.Fprintf(g.w, " return r.Table\n") 155 } 156 fmt.Fprintf(g.w, "}\n\n") 157 158 fmt.Fprintf(g.w, "func (r *%s) OutputCols() sql.Schema {\n", define.Name) 159 fmt.Fprintf(g.w, " return r.Table.Schema()\n") 160 fmt.Fprintf(g.w, "}\n\n") 161 } 162 163 func (g *MemoGen) genJoinRelInterface(define ExprDef) { 164 fmt.Fprintf(g.w, "func (r *%s) JoinPrivate() *JoinBase {\n", define.Name) 165 fmt.Fprintf(g.w, " return r.JoinBase\n") 166 fmt.Fprintf(g.w, "}\n\n") 167 } 168 169 func (g *MemoGen) genBinaryGroupInterface(define ExprDef) { 170 fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) 171 fmt.Fprintf(g.w, " return []*ExprGroup{r.Left, r.Right}\n") 172 fmt.Fprintf(g.w, "}\n\n") 173 } 174 175 func (g *MemoGen) genChildlessGroupInterface(define ExprDef) { 176 fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) 177 fmt.Fprintf(g.w, " return nil\n") 178 fmt.Fprintf(g.w, "}\n\n") 179 } 180 181 func (g *MemoGen) genUnaryGroupInterface(define ExprDef) { 182 fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) 183 fmt.Fprintf(g.w, " return []*ExprGroup{r.Child}\n") 184 fmt.Fprintf(g.w, "}\n\n") 185 186 fmt.Fprintf(g.w, "func (r *%s) outputCols() sql.ColSet {\n", define.Name) 187 switch define.Name { 188 case "Project": 189 fmt.Fprintf(g.w, " return getProjectColset(r)\n") 190 191 default: 192 fmt.Fprintf(g.w, " return r.Child.RelProps.OutputCols()\n") 193 } 194 195 fmt.Fprintf(g.w, "}\n\n") 196 197 } 198 199 func (g *MemoGen) genFormatters(defines []ExprDef) { 200 // printer 201 fmt.Fprintf(g.w, "func FormatExpr(r exprType) string {\n") 202 fmt.Fprintf(g.w, " switch r := r.(type) {\n") 203 for _, d := range defines { 204 loweredName := strings.ToLower(d.Name) 205 fmt.Fprintf(g.w, " case *%s:\n", d.Name) 206 if loweredName == "indexscan" { 207 fmt.Fprintf(g.w, " if r.Alias != \"\" {\n") 208 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%s\", r.Alias)\n", loweredName) 209 fmt.Fprintf(g.w, " }\n") 210 } 211 if d.SourceType != "" { 212 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%s\", r.Name())\n", loweredName) 213 } else if d.Join || d.Binary { 214 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s %%d %%d\", r.Left.Id, r.Right.Id)\n", loweredName) 215 } else if d.Unary { 216 fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%d\", r.Child.Id)\n", loweredName) 217 } else { 218 panic("unreachable") 219 } 220 } 221 fmt.Fprintf(g.w, " default:\n") 222 fmt.Fprintf(g.w, " panic(fmt.Sprintf(\"unknown RelExpr type: %%T\", r))\n") 223 fmt.Fprintf(g.w, " }\n") 224 fmt.Fprintf(g.w, "}\n\n") 225 226 // to sqlNode 227 fmt.Fprintf(g.w, "func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) {\n") 228 fmt.Fprintf(g.w, " var result sql.Node\n") 229 fmt.Fprintf(g.w, " var err error\n\n") 230 fmt.Fprintf(g.w, " switch r := r.(type) {\n") 231 for _, d := range defines { 232 if d.SkipExec { 233 continue 234 } 235 fmt.Fprintf(g.w, " case *%s:\n", d.Name) 236 fmt.Fprintf(g.w, " result, err = b.build%s(r, children...)\n", strings.Title(d.Name)) 237 } 238 fmt.Fprintf(g.w, " default:\n") 239 fmt.Fprintf(g.w, " panic(fmt.Sprintf(\"unknown RelExpr type: %%T\", r))\n") 240 fmt.Fprintf(g.w, " }\n\n") 241 fmt.Fprintf(g.w, " if err != nil {\n") 242 fmt.Fprintf(g.w, " return nil, err\n") 243 fmt.Fprintf(g.w, " }\n\n") 244 fmt.Fprintf(g.w, "if withDescribeStats, ok := result.(sql.WithDescribeStats); ok {\n") 245 fmt.Fprintf(g.w, " withDescribeStats.SetDescribeStats(*DescribeStats(r))\n") 246 fmt.Fprintf(g.w, "}\n") 247 fmt.Fprintf(g.w, " result, err = r.Group().finalize(result)\n") 248 fmt.Fprintf(g.w, " if err != nil {\n") 249 fmt.Fprintf(g.w, " return nil, err\n") 250 fmt.Fprintf(g.w, " }\n") 251 fmt.Fprintf(g.w, " return result, nil\n") 252 fmt.Fprintf(g.w, "}\n\n") 253 }