github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/frame_gen.go (about) 1 package support 2 3 import ( 4 "fmt" 5 "io" 6 "math" 7 "strings" 8 ) 9 10 //go:generate stringer -type=frameExtent 11 12 type frameExtent int 13 14 const ( 15 unboundedPreceding frameExtent = iota 16 startNPreceding 17 startCurrentRow 18 startNFollowing 19 unknown 20 endNPreceding 21 endCurrentRow 22 endNFollowing 23 unboundedFollowing 24 ) 25 26 var frameExtents = []frameExtent{ 27 unboundedPreceding, 28 startNPreceding, 29 startCurrentRow, 30 startNFollowing, 31 endNPreceding, 32 endCurrentRow, 33 endNFollowing, 34 unboundedFollowing, 35 } 36 37 func (e frameExtent) argType() string { 38 switch e { 39 case unboundedPreceding, startCurrentRow, endCurrentRow, unboundedFollowing: 40 return "bool" 41 case startNPreceding, startNFollowing, endNPreceding, endNFollowing: 42 return "sql.Expression" 43 } 44 panic(fmt.Sprintf("invalid frameExtent: %v", e)) 45 } 46 47 func (e frameExtent) Arg() map[string]string { 48 return map[string]string{e.String(): e.argType()} 49 } 50 51 func (e frameExtent) cond() string { 52 switch e { 53 case unboundedPreceding, startCurrentRow, endCurrentRow, unboundedFollowing: 54 return fmt.Sprintf("%s", e.String()) 55 case startNPreceding, startNFollowing, endNPreceding, endNFollowing: 56 return fmt.Sprintf("%s != nil", e.String()) 57 } 58 panic(fmt.Sprintf("invalid frameExtent: %v", e)) 59 } 60 61 type frameUnit int 62 63 const ( 64 rows frameUnit = iota 65 rang 66 ) 67 68 var frameUnits = []frameUnit{rows, rang} 69 70 func (b frameUnit) String() string { 71 switch b { 72 case rows: 73 return "Rows" 74 case rang: 75 return "Range" 76 } 77 return "" 78 } 79 80 type frameBound int 81 82 const ( 83 startBound frameBound = iota 84 endBound 85 ) 86 87 var implicitRightBound = []frameBound{startBound, endBound} 88 89 func (b frameBound) String() string { 90 switch b { 91 case startBound: 92 return "Start" 93 case endBound: 94 return "End" 95 } 96 return "" 97 } 98 99 type frameDef struct { 100 start frameExtent 101 end frameExtent 102 unit frameUnit 103 op int 104 } 105 106 func (d *frameDef) Name() string { 107 start := strings.ReplaceAll(strings.Title(d.start.String()), startBound.String(), "") 108 end := strings.ReplaceAll(strings.Title(d.end.String()), endBound.String(), "") 109 return fmt.Sprintf("%s%sTo%s", d.unit, start, end) 110 } 111 112 func (d *frameDef) OpName() string { 113 return fmt.Sprintf("%sTo%s", d.start.String(), d.end.String()) 114 } 115 116 func (d *frameDef) valid() bool { 117 switch { 118 case d.end == unknown || d.start == unknown: 119 return false 120 case d.end < d.start: 121 return false 122 case d.end < unknown: 123 return false 124 case d.start > unknown: 125 return false 126 } 127 return true 128 } 129 130 func (d *frameDef) Args() []frameExtent { 131 return []frameExtent{d.start, d.end} 132 } 133 134 func (d *frameDef) CondArgs() string { 135 return fmt.Sprintf("is%s && %s && %s", d.unit, d.start.cond(), d.end.cond()) 136 } 137 138 func (d *frameDef) SigArgs() string { 139 sb := strings.Builder{} 140 i := 0 141 for _, a := range d.Args() { 142 if a.argType() == "bool" { 143 continue 144 } 145 if i > 0 { 146 sb.WriteString(", ") 147 } 148 sb.WriteString(fmt.Sprintf("%s %s", a, a.argType())) 149 i++ 150 } 151 return sb.String() 152 } 153 154 type FrameGen struct { 155 w io.Writer 156 defs []frameDef 157 limit int 158 } 159 160 func (g *FrameGen) Generate(defines GenDefs, w io.Writer) { 161 g.w = w 162 if g.limit == 0 { 163 g.limit = math.MaxInt32 164 } 165 g.defs = getDefs(g.limit) 166 g.generate() 167 } 168 169 func getDefs(limit int) []frameDef { 170 i := 0 171 defs := make([]frameDef, 0) 172 for _, unit := range frameUnits { 173 for _, start := range frameExtents { 174 for _, end := range frameExtents { 175 def := frameDef{unit: unit, start: start, end: end, op: i} 176 if !def.valid() { 177 continue 178 } 179 if i >= limit { 180 return defs 181 } 182 defs = append(defs, def) 183 i++ 184 } 185 } 186 } 187 return defs 188 } 189 190 func (g *FrameGen) generate() { 191 g.genImports() 192 for _, def := range g.defs { 193 g.genFrameType(def) 194 g.genNewFrame(def) 195 g.genFrameAccessors(def) 196 g.genNewFramer(def) 197 } 198 } 199 200 func (g *FrameGen) genImports() { 201 fmt.Fprintf(g.w, "import (\n") 202 fmt.Fprintf(g.w, " \"github.com/dolthub/go-mysql-server/sql\"\n") 203 fmt.Fprintf(g.w, " agg \"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation\"\n") 204 fmt.Fprintf(g.w, ")\n\n") 205 } 206 207 func (g *FrameGen) genFrameType(def frameDef) { 208 fmt.Fprintf(g.w, "type %sFrame struct {\n", def.Name()) 209 fmt.Fprintf(g.w, " windowFrameBase\n") 210 fmt.Fprintf(g.w, "}\n\n") 211 212 fmt.Fprintf(g.w, "var _ sql.WindowFrame = (*%sFrame)(nil)\n\n", def.Name()) 213 214 } 215 216 func (g *FrameGen) genNewFrame(def frameDef) { 217 fmt.Fprintf(g.w, "func New%sFrame(%s) *%sFrame {\n", def.Name(), def.SigArgs(), def.Name()) 218 fmt.Fprintf(g.w, " return &%sFrame{\n", def.Name()) 219 fmt.Fprintf(g.w, " windowFrameBase{\n") 220 switch def.unit { 221 case rows: 222 fmt.Fprintf(g.w, " isRows: true,\n") 223 case rang: 224 fmt.Fprintf(g.w, " isRange: true,\n") 225 } 226 227 for _, a := range def.Args() { 228 switch a.argType() { 229 case "sql.Expression": 230 fmt.Fprintf(g.w, " %s: %s,\n", a, a) 231 case "bool": 232 fmt.Fprintf(g.w, " %s: true,\n", a) 233 } 234 } 235 236 fmt.Fprintf(g.w, " },\n") 237 fmt.Fprintf(g.w, " }\n") 238 fmt.Fprintf(g.w, "}\n\n") 239 } 240 241 func (g *FrameGen) genFrameAccessors(def frameDef) { 242 for _, e := range frameExtents { 243 fmt.Fprintf(g.w, "func (f *%sFrame) %s() %s {\n", def.Name(), strings.Title(e.String()), e.argType()) 244 fmt.Fprintf(g.w, " return f.%s\n", e) 245 fmt.Fprintf(g.w, "}\n\n") 246 } 247 } 248 249 func (g *FrameGen) genNewFramer(def frameDef) { 250 fmt.Fprintf(g.w, "func (f *%sFrame) NewFramer(w *sql.WindowDefinition) (sql.WindowFramer, error) {\n", def.Name()) 251 fmt.Fprintf(g.w, " return agg.New%sFramer(f, w)\n", def.Name()) 252 fmt.Fprintf(g.w, "}\n\n") 253 }