github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/frame_gen.go (about)

     1  package support
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math"
     7  	"strings"
     8  )
     9  
//go:generate stringer -type=frameExtent

// frameExtent identifies one endpoint of a window frame. Its String method is
// produced by the stringer directive above, so each value renders as its
// identifier (e.g. "startNPreceding"). Those strings are used verbatim as
// field names in generated code and, title-cased, in generated identifiers.
type frameExtent int

// The declaration order is significant: frameDef.valid accepts only extents
// declared before unknown as a start bound and only extents declared after
// unknown as an end bound.
const (
	unboundedPreceding frameExtent = iota
	startNPreceding
	startCurrentRow
	startNFollowing
	unknown // separator between start-only and end-only extents; never valid
	endNPreceding
	endCurrentRow
	endNFollowing
	unboundedFollowing
)
    25  
// frameExtents lists every usable frame extent, i.e. all values except the
// unknown separator. getDefs iterates it for both the start and the end bound,
// so its order determines the order (and op numbers) of the generated frames.
// genFrameAccessors emits one accessor per entry.
var frameExtents = []frameExtent{
	unboundedPreceding,
	startNPreceding,
	startCurrentRow,
	startNFollowing,
	endNPreceding,
	endCurrentRow,
	endNFollowing,
	unboundedFollowing,
}
    36  
    37  func (e frameExtent) argType() string {
    38  	switch e {
    39  	case unboundedPreceding, startCurrentRow, endCurrentRow, unboundedFollowing:
    40  		return "bool"
    41  	case startNPreceding, startNFollowing, endNPreceding, endNFollowing:
    42  		return "sql.Expression"
    43  	}
    44  	panic(fmt.Sprintf("invalid frameExtent: %v", e))
    45  }
    46  
    47  func (e frameExtent) Arg() map[string]string {
    48  	return map[string]string{e.String(): e.argType()}
    49  }
    50  
    51  func (e frameExtent) cond() string {
    52  	switch e {
    53  	case unboundedPreceding, startCurrentRow, endCurrentRow, unboundedFollowing:
    54  		return fmt.Sprintf("%s", e.String())
    55  	case startNPreceding, startNFollowing, endNPreceding, endNFollowing:
    56  		return fmt.Sprintf("%s != nil", e.String())
    57  	}
    58  	panic(fmt.Sprintf("invalid frameExtent: %v", e))
    59  }
    60  
// frameUnit is the unit a window frame is measured in; its String method
// renders the two units as "Rows" and "Range".
type frameUnit int

const (
	rows frameUnit = iota
	rang // "range" is a Go keyword, hence the truncated name
)

// frameUnits lists every frame unit. getDefs generates one full set of frame
// definitions per unit, in this order.
var frameUnits = []frameUnit{rows, rang}
    69  
    70  func (b frameUnit) String() string {
    71  	switch b {
    72  	case rows:
    73  		return "Rows"
    74  	case rang:
    75  		return "Range"
    76  	}
    77  	return ""
    78  }
    79  
// frameBound names the two sides of a window frame. Its String values
// ("Start" and "End") match the prefixes that frameDef.Name strips from
// title-cased extent names.
type frameBound int

const (
	startBound frameBound = iota
	endBound
)

// implicitRightBound lists both frame bounds.
// NOTE(review): this variable is not referenced anywhere in this file; confirm
// whether it is still used elsewhere before relying on or removing it.
var implicitRightBound = []frameBound{startBound, endBound}
    88  
    89  func (b frameBound) String() string {
    90  	switch b {
    91  	case startBound:
    92  		return "Start"
    93  	case endBound:
    94  		return "End"
    95  	}
    96  	return ""
    97  }
    98  
// frameDef describes one generated window-frame variant: a unit plus a start
// and an end extent.
type frameDef struct {
	start frameExtent // first bound; valid only if declared before unknown
	end   frameExtent // second bound; valid only if declared after unknown
	unit  frameUnit   // rows or rang
	op    int         // index among the valid definitions built by getDefs
}
   105  
   106  func (d *frameDef) Name() string {
   107  	start := strings.ReplaceAll(strings.Title(d.start.String()), startBound.String(), "")
   108  	end := strings.ReplaceAll(strings.Title(d.end.String()), endBound.String(), "")
   109  	return fmt.Sprintf("%s%sTo%s", d.unit, start, end)
   110  }
   111  
   112  func (d *frameDef) OpName() string {
   113  	return fmt.Sprintf("%sTo%s", d.start.String(), d.end.String())
   114  }
   115  
   116  func (d *frameDef) valid() bool {
   117  	switch {
   118  	case d.end == unknown || d.start == unknown:
   119  		return false
   120  	case d.end < d.start:
   121  		return false
   122  	case d.end < unknown:
   123  		return false
   124  	case d.start > unknown:
   125  		return false
   126  	}
   127  	return true
   128  }
   129  
   130  func (d *frameDef) Args() []frameExtent {
   131  	return []frameExtent{d.start, d.end}
   132  }
   133  
   134  func (d *frameDef) CondArgs() string {
   135  	return fmt.Sprintf("is%s && %s && %s", d.unit, d.start.cond(), d.end.cond())
   136  }
   137  
   138  func (d *frameDef) SigArgs() string {
   139  	sb := strings.Builder{}
   140  	i := 0
   141  	for _, a := range d.Args() {
   142  		if a.argType() == "bool" {
   143  			continue
   144  		}
   145  		if i > 0 {
   146  			sb.WriteString(", ")
   147  		}
   148  		sb.WriteString(fmt.Sprintf("%s %s", a, a.argType()))
   149  		i++
   150  	}
   151  	return sb.String()
   152  }
   153  
// FrameGen emits Go source for the window-frame variants enumerated by
// getDefs: a type, constructor, accessors and NewFramer method per variant.
type FrameGen struct {
	w     io.Writer  // destination for the generated source; set by Generate
	defs  []frameDef // definitions to emit; filled in by Generate
	limit int        // max definitions to emit; 0 is replaced by math.MaxInt32
}
   159  
   160  func (g *FrameGen) Generate(defines GenDefs, w io.Writer) {
   161  	g.w = w
   162  	if g.limit == 0 {
   163  		g.limit = math.MaxInt32
   164  	}
   165  	g.defs = getDefs(g.limit)
   166  	g.generate()
   167  }
   168  
   169  func getDefs(limit int) []frameDef {
   170  	i := 0
   171  	defs := make([]frameDef, 0)
   172  	for _, unit := range frameUnits {
   173  		for _, start := range frameExtents {
   174  			for _, end := range frameExtents {
   175  				def := frameDef{unit: unit, start: start, end: end, op: i}
   176  				if !def.valid() {
   177  					continue
   178  				}
   179  				if i >= limit {
   180  					return defs
   181  				}
   182  				defs = append(defs, def)
   183  				i++
   184  			}
   185  		}
   186  	}
   187  	return defs
   188  }
   189  
   190  func (g *FrameGen) generate() {
   191  	g.genImports()
   192  	for _, def := range g.defs {
   193  		g.genFrameType(def)
   194  		g.genNewFrame(def)
   195  		g.genFrameAccessors(def)
   196  		g.genNewFramer(def)
   197  	}
   198  }
   199  
   200  func (g *FrameGen) genImports() {
   201  	fmt.Fprintf(g.w, "import (\n")
   202  	fmt.Fprintf(g.w, "  \"github.com/dolthub/go-mysql-server/sql\"\n")
   203  	fmt.Fprintf(g.w, "  agg \"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation\"\n")
   204  	fmt.Fprintf(g.w, ")\n\n")
   205  }
   206  
   207  func (g *FrameGen) genFrameType(def frameDef) {
   208  	fmt.Fprintf(g.w, "type %sFrame struct {\n", def.Name())
   209  	fmt.Fprintf(g.w, "    windowFrameBase\n")
   210  	fmt.Fprintf(g.w, "}\n\n")
   211  
   212  	fmt.Fprintf(g.w, "var _ sql.WindowFrame = (*%sFrame)(nil)\n\n", def.Name())
   213  
   214  }
   215  
   216  func (g *FrameGen) genNewFrame(def frameDef) {
   217  	fmt.Fprintf(g.w, "func New%sFrame(%s) *%sFrame {\n", def.Name(), def.SigArgs(), def.Name())
   218  	fmt.Fprintf(g.w, "  return &%sFrame{\n", def.Name())
   219  	fmt.Fprintf(g.w, "    windowFrameBase{\n")
   220  	switch def.unit {
   221  	case rows:
   222  		fmt.Fprintf(g.w, "      isRows: true,\n")
   223  	case rang:
   224  		fmt.Fprintf(g.w, "      isRange: true,\n")
   225  	}
   226  
   227  	for _, a := range def.Args() {
   228  		switch a.argType() {
   229  		case "sql.Expression":
   230  			fmt.Fprintf(g.w, "      %s: %s,\n", a, a)
   231  		case "bool":
   232  			fmt.Fprintf(g.w, "      %s: true,\n", a)
   233  		}
   234  	}
   235  
   236  	fmt.Fprintf(g.w, "    },\n")
   237  	fmt.Fprintf(g.w, "  }\n")
   238  	fmt.Fprintf(g.w, "}\n\n")
   239  }
   240  
   241  func (g *FrameGen) genFrameAccessors(def frameDef) {
   242  	for _, e := range frameExtents {
   243  		fmt.Fprintf(g.w, "func (f *%sFrame) %s() %s {\n", def.Name(), strings.Title(e.String()), e.argType())
   244  		fmt.Fprintf(g.w, "  return f.%s\n", e)
   245  		fmt.Fprintf(g.w, "}\n\n")
   246  	}
   247  }
   248  
   249  func (g *FrameGen) genNewFramer(def frameDef) {
   250  	fmt.Fprintf(g.w, "func (f *%sFrame) NewFramer(w *sql.WindowDefinition) (sql.WindowFramer, error) {\n", def.Name())
   251  	fmt.Fprintf(g.w, "    return agg.New%sFramer(f, w)\n", def.Name())
   252  	fmt.Fprintf(g.w, "}\n\n")
   253  }