github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/framer_gen.go (about) 1 package support 2 3 import ( 4 "fmt" 5 "io" 6 "math" 7 "strings" 8 ) 9 10 type FramerGen struct { 11 w io.Writer 12 defs []frameDef 13 limit int 14 } 15 16 func (g *FramerGen) Generate(defines GenDefs, w io.Writer) { 17 g.w = w 18 if g.limit == 0 { 19 g.limit = math.MaxInt32 20 } 21 g.defs = getDefs(g.limit) 22 g.generate() 23 } 24 25 func (g *FramerGen) generate() { 26 g.genImports() 27 for _, def := range g.defs { 28 g.genFramerType(def) 29 g.genNewFramer(def) 30 } 31 } 32 33 func (g *FramerGen) genImports() { 34 fmt.Fprintf(g.w, "import (\n") 35 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql\"\n") 36 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql/expression\"\n") 37 fmt.Fprintf(g.w, ")\n\n") 38 } 39 40 func (g *FramerGen) genFramerType(def frameDef) { 41 fmt.Fprintf(g.w, "type %sFramer struct {\n", def.Name()) 42 switch def.unit { 43 case rows: 44 fmt.Fprintf(g.w, " rowFramerBase\n") 45 case rang: 46 fmt.Fprintf(g.w, " rangeFramerBase\n") 47 } 48 fmt.Fprintf(g.w, "}\n\n") 49 50 fmt.Fprintf(g.w, "var _ sql.WindowFramer = (*%sFramer)(nil)\n\n", def.Name()) 51 52 } 53 54 func (g *FramerGen) genNewFramer(def frameDef) { 55 framerName := fmt.Sprintf("%sFramer", def.Name()) 56 fmt.Fprintf(g.w, "func New%sFramer(frame sql.WindowFrame, window *sql.WindowDefinition) (sql.WindowFramer, error) {\n", def.Name()) 57 58 for _, a := range def.Args() { 59 switch a.argType() { 60 case "sql.Expression": 61 switch def.unit { 62 case rows: 63 fmt.Fprintf(g.w, " %s, err := expression.LiteralToInt(frame.%s())\n", a, strings.Title(a.String())) 64 fmt.Fprintf(g.w, " if err != nil {\n") 65 fmt.Fprintf(g.w, " return nil, err\n") 66 fmt.Fprintf(g.w, " }\n") 67 case rang: 68 fmt.Fprintf(g.w, " %s := frame.%s()\n", a, strings.Title(a.String())) 69 } 70 case "bool": 71 fmt.Fprintf(g.w, " %s := true\n", a) 72 } 73 } 74 75 orderByRequired := def.unit == rang && 76 ((def.start != unboundedPreceding && def.start != startCurrentRow) || 77 (def.end != unboundedFollowing && def.end != endCurrentRow)) 78 79 if orderByRequired { 80 fmt.Fprintf(g.w, " if len(window.OrderBy) != 1 {\n") 81 fmt.Fprintf(g.w, " return nil, ErrRangeInvalidOrderBy.New(len(window.OrderBy.ToExpressions()))\n") 82 fmt.Fprintf(g.w, " }\n") 83 } 84 85 if def.unit == rang { 86 fmt.Fprintf(g.w, " var orderBy sql.Expression\n") 87 fmt.Fprintf(g.w, " if len(window.OrderBy) > 0 {\n") 88 fmt.Fprintf(g.w, " orderBy = window.OrderBy.ToExpressions()[0]\n") 89 fmt.Fprintf(g.w, " }\n") 90 } 91 92 fmt.Fprintf(g.w, " return &%s{\n", framerName) 93 switch def.unit { 94 case rows: 95 fmt.Fprintf(g.w, " rowFramerBase{\n") 96 case rang: 97 fmt.Fprintf(g.w, " rangeFramerBase{\n") 98 fmt.Fprintf(g.w, " orderBy: orderBy,\n") 99 } 100 101 for _, a := range def.Args() { 102 fmt.Fprintf(g.w, " %s: %s,\n", a, a) 103 } 104 105 fmt.Fprintf(g.w, " },\n") 106 fmt.Fprintf(g.w, " }, nil\n") 107 fmt.Fprintf(g.w, "}\n\n") 108 }