github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/framer_gen.go (about)

     1  package support
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math"
     7  	"strings"
     8  )
     9  
    10  type FramerGen struct {
    11  	w     io.Writer
    12  	defs  []frameDef
    13  	limit int
    14  }
    15  
    16  func (g *FramerGen) Generate(defines GenDefs, w io.Writer) {
    17  	g.w = w
    18  	if g.limit == 0 {
    19  		g.limit = math.MaxInt32
    20  	}
    21  	g.defs = getDefs(g.limit)
    22  	g.generate()
    23  }
    24  
    25  func (g *FramerGen) generate() {
    26  	g.genImports()
    27  	for _, def := range g.defs {
    28  		g.genFramerType(def)
    29  		g.genNewFramer(def)
    30  	}
    31  }
    32  
    33  func (g *FramerGen) genImports() {
    34  	fmt.Fprintf(g.w, "import (\n")
    35  	fmt.Fprintf(g.w, "  \"github.com/dolthub/go-mysql-server/sql\"\n")
    36  	fmt.Fprintf(g.w, "  \"github.com/dolthub/go-mysql-server/sql/expression\"\n")
    37  	fmt.Fprintf(g.w, ")\n\n")
    38  }
    39  
    40  func (g *FramerGen) genFramerType(def frameDef) {
    41  	fmt.Fprintf(g.w, "type %sFramer struct {\n", def.Name())
    42  	switch def.unit {
    43  	case rows:
    44  		fmt.Fprintf(g.w, "  rowFramerBase\n")
    45  	case rang:
    46  		fmt.Fprintf(g.w, "  rangeFramerBase\n")
    47  	}
    48  	fmt.Fprintf(g.w, "}\n\n")
    49  
    50  	fmt.Fprintf(g.w, "var _ sql.WindowFramer = (*%sFramer)(nil)\n\n", def.Name())
    51  
    52  }
    53  
    54  func (g *FramerGen) genNewFramer(def frameDef) {
    55  	framerName := fmt.Sprintf("%sFramer", def.Name())
    56  	fmt.Fprintf(g.w, "func New%sFramer(frame sql.WindowFrame, window *sql.WindowDefinition) (sql.WindowFramer, error) {\n", def.Name())
    57  
    58  	for _, a := range def.Args() {
    59  		switch a.argType() {
    60  		case "sql.Expression":
    61  			switch def.unit {
    62  			case rows:
    63  				fmt.Fprintf(g.w, "  %s, err := expression.LiteralToInt(frame.%s())\n", a, strings.Title(a.String()))
    64  				fmt.Fprintf(g.w, "  if err != nil {\n")
    65  				fmt.Fprintf(g.w, "    return nil, err\n")
    66  				fmt.Fprintf(g.w, "  }\n")
    67  			case rang:
    68  				fmt.Fprintf(g.w, "  %s := frame.%s()\n", a, strings.Title(a.String()))
    69  			}
    70  		case "bool":
    71  			fmt.Fprintf(g.w, "  %s := true\n", a)
    72  		}
    73  	}
    74  
    75  	orderByRequired := def.unit == rang &&
    76  		((def.start != unboundedPreceding && def.start != startCurrentRow) ||
    77  			(def.end != unboundedFollowing && def.end != endCurrentRow))
    78  
    79  	if orderByRequired {
    80  		fmt.Fprintf(g.w, "  if len(window.OrderBy) != 1 {\n")
    81  		fmt.Fprintf(g.w, "    return nil, ErrRangeInvalidOrderBy.New(len(window.OrderBy.ToExpressions()))\n")
    82  		fmt.Fprintf(g.w, "  }\n")
    83  	}
    84  
    85  	if def.unit == rang {
    86  		fmt.Fprintf(g.w, "  var orderBy sql.Expression\n")
    87  		fmt.Fprintf(g.w, "  if len(window.OrderBy) > 0 {\n")
    88  		fmt.Fprintf(g.w, "    orderBy = window.OrderBy.ToExpressions()[0]\n")
    89  		fmt.Fprintf(g.w, "  }\n")
    90  	}
    91  
    92  	fmt.Fprintf(g.w, "  return &%s{\n", framerName)
    93  	switch def.unit {
    94  	case rows:
    95  		fmt.Fprintf(g.w, "    rowFramerBase{\n")
    96  	case rang:
    97  		fmt.Fprintf(g.w, "    rangeFramerBase{\n")
    98  		fmt.Fprintf(g.w, "      orderBy: orderBy,\n")
    99  	}
   100  
   101  	for _, a := range def.Args() {
   102  		fmt.Fprintf(g.w, "      %s: %s,\n", a, a)
   103  	}
   104  
   105  	fmt.Fprintf(g.w, "    },\n")
   106  	fmt.Fprintf(g.w, "  }, nil\n")
   107  	fmt.Fprintf(g.w, "}\n\n")
   108  }