github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/memo_gen.go (about)

     1  package support
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"strings"
     9  
    10  	"gopkg.in/yaml.v3"
    11  )
    12  
    13  //go:generate go run ../optgen/main.go -out ../../../sql/memo/memo.og.go -pkg memo memo
    14  
// MemoExprs is the top-level document of a memo definition YAML file: a list
// of expression definitions stored under the "exprs" key.
type MemoExprs struct {
	Exprs []ExprDef `yaml:"exprs"`
}
    18  
// ExprDef describes one memo expression type to generate.
//
// How MemoGen uses each field:
//   - Name: the expression name; used for the generated Go type and its methods.
//   - SourceType: when non-empty, the type embeds *sourceBase, gets a Table
//     field of this Go type, and receives the source-relation methods.
//   - Join: the type embeds *JoinBase and receives JoinPrivate.
//   - Attrs: extra struct fields as (name, Go type) pairs; the name is
//     title-cased when emitted.
//   - Unary: the type embeds *relBase and gets a single Child group.
//   - SkipExec: the type is left out of the generated buildRelExpr switch.
//   - Binary: the type embeds *relBase and gets Left and Right groups.
//   - SkipName: the generated Name method returns the empty string.
//   - SkipTableId: the generated TableIdNode method returns nil.
type ExprDef struct {
	Name        string      `yaml:"name"`
	SourceType  string      `yaml:"sourceType"`
	Join        bool        `yaml:"join"`
	Attrs       [][2]string `yaml:"attrs"`
	Unary       bool        `yaml:"unary"`
	SkipExec    bool        `yaml:"skipExec"`
	Binary      bool        `yaml:"binary"`
	SkipName    bool        `yaml:"skipName"`
	SkipTableId bool        `yaml:"skipTableId"`
}
    30  
    31  func DecodeMemoExprs(path string) (MemoExprs, error) {
    32  	contents, err := os.ReadFile(path)
    33  	if err != nil {
    34  		return MemoExprs{}, err
    35  	}
    36  	dec := yaml.NewDecoder(bytes.NewReader(contents))
    37  	dec.KnownFields(true)
    38  	var res MemoExprs
    39  	return res, dec.Decode(&res)
    40  }
    41  
// Compile-time check that *MemoExprs satisfies GenDefs.
var _ GenDefs = (*MemoExprs)(nil)
    43  
// MemoGen writes Go source for memo expression types described by a list of
// ExprDef values. Both fields are populated by Generate.
type MemoGen struct {
	defines []ExprDef // definitions currently being generated
	w       io.Writer // destination for the generated source
}
    48  
    49  func (g *MemoGen) Generate(defines GenDefs, w io.Writer) {
    50  	g.defines = defines.(MemoExprs).Exprs
    51  
    52  	g.w = w
    53  
    54  	g.genImport()
    55  	for _, define := range g.defines {
    56  		g.genType(define)
    57  		g.genRelInterfaces(define)
    58  
    59  		g.genStringer(define)
    60  		if define.SourceType != "" {
    61  			g.genSourceRelInterface(define)
    62  		}
    63  		if define.Join {
    64  			g.genJoinRelInterface(define)
    65  		} else if define.Binary {
    66  			g.genBinaryGroupInterface(define)
    67  		} else if define.Unary {
    68  			g.genUnaryGroupInterface(define)
    69  		} else {
    70  			g.genChildlessGroupInterface(define)
    71  		}
    72  	}
    73  	g.genFormatters(g.defines)
    74  
    75  }
    76  
    77  func (g *MemoGen) genImport() {
    78  	fmt.Fprintf(g.w, "import (\n")
    79  	fmt.Fprintf(g.w, "  \"fmt\"\n")
    80  	fmt.Fprintf(g.w, "  \"strings\"\n")
    81  	fmt.Fprintf(g.w, "  \"github.com/dolthub/go-mysql-server/sql\"\n")
    82  	fmt.Fprintf(g.w, "  \"github.com/dolthub/go-mysql-server/sql/plan\"\n")
    83  	fmt.Fprintf(g.w, ")\n\n")
    84  }
    85  
    86  func (g *MemoGen) genType(define ExprDef) {
    87  	fmt.Fprintf(g.w, "type %s struct {\n", strings.Title(define.Name))
    88  	if define.SourceType != "" {
    89  		fmt.Fprintf(g.w, "  *sourceBase\n")
    90  		fmt.Fprintf(g.w, "  Table %s\n", define.SourceType)
    91  	} else if define.Join {
    92  		fmt.Fprintf(g.w, "  *JoinBase\n")
    93  	} else if define.Unary {
    94  		fmt.Fprintf(g.w, "  *relBase\n")
    95  		fmt.Fprintf(g.w, "  Child *ExprGroup\n")
    96  	} else if define.Binary {
    97  		fmt.Fprintf(g.w, "  *relBase\n")
    98  		fmt.Fprintf(g.w, "  Left *ExprGroup\n")
    99  		fmt.Fprintf(g.w, "  Right *ExprGroup\n")
   100  	}
   101  	for _, attr := range define.Attrs {
   102  		fmt.Fprintf(g.w, "  %s %s\n", strings.Title(attr[0]), attr[1])
   103  	}
   104  
   105  	fmt.Fprintf(g.w, "}\n\n")
   106  }
   107  
   108  func (g *MemoGen) genRelInterfaces(define ExprDef) {
   109  	fmt.Fprintf(g.w, "var _ RelExpr = (*%s)(nil)\n", define.Name)
   110  	if define.SourceType != "" {
   111  		fmt.Fprintf(g.w, "var _ SourceRel = (*%s)(nil)\n", define.Name)
   112  	} else if define.Join {
   113  		fmt.Fprintf(g.w, "var _ JoinRel = (*%s)(nil)\n", define.Name)
   114  	} else if define.Unary || define.Binary {
   115  	} else {
   116  		panic("unreachable")
   117  	}
   118  	fmt.Fprintf(g.w, "\n")
   119  }
   120  
   121  func (g *MemoGen) genScalarInterfaces(define ExprDef) {
   122  	fmt.Fprintf(g.w, "var _ ScalarExpr = (*%s)(nil)\n", define.Name)
   123  
   124  	fmt.Fprintf(g.w, "\n")
   125  
   126  	fmt.Fprintf(g.w, "func (r *%s) ExprId() ScalarExprId {\n", define.Name)
   127  	fmt.Fprintf(g.w, "  return ScalarExpr%s\n", strings.Title(define.Name))
   128  	fmt.Fprintf(g.w, "}\n\n")
   129  }
   130  
   131  func (g *MemoGen) genStringer(define ExprDef) {
   132  	fmt.Fprintf(g.w, "func (r *%s) String() string {\n", define.Name)
   133  	fmt.Fprintf(g.w, "  return FormatExpr(r)\n")
   134  	fmt.Fprintf(g.w, "}\n\n")
   135  }
   136  
   137  func (g *MemoGen) genSourceRelInterface(define ExprDef) {
   138  	fmt.Fprintf(g.w, "func (r *%s) Name() string {\n", define.Name)
   139  	if !define.SkipName {
   140  		fmt.Fprintf(g.w, "  return strings.ToLower(r.Table.Name())\n")
   141  	} else {
   142  		fmt.Fprintf(g.w, "  return \"\"\n")
   143  	}
   144  	fmt.Fprintf(g.w, "}\n\n")
   145  
   146  	fmt.Fprintf(g.w, "func (r *%s) TableId() sql.TableId {\n", define.Name)
   147  	fmt.Fprintf(g.w, "  return TableIdForSource(r.g.Id)\n")
   148  	fmt.Fprintf(g.w, "}\n\n")
   149  
   150  	fmt.Fprintf(g.w, "func (r *%s) TableIdNode() plan.TableIdNode {\n", define.Name)
   151  	if define.SkipTableId {
   152  		fmt.Fprintf(g.w, "  return nil\n")
   153  	} else {
   154  		fmt.Fprintf(g.w, "  return r.Table\n")
   155  	}
   156  	fmt.Fprintf(g.w, "}\n\n")
   157  
   158  	fmt.Fprintf(g.w, "func (r *%s) OutputCols() sql.Schema {\n", define.Name)
   159  	fmt.Fprintf(g.w, "  return r.Table.Schema()\n")
   160  	fmt.Fprintf(g.w, "}\n\n")
   161  }
   162  
   163  func (g *MemoGen) genJoinRelInterface(define ExprDef) {
   164  	fmt.Fprintf(g.w, "func (r *%s) JoinPrivate() *JoinBase {\n", define.Name)
   165  	fmt.Fprintf(g.w, "  return r.JoinBase\n")
   166  	fmt.Fprintf(g.w, "}\n\n")
   167  }
   168  
   169  func (g *MemoGen) genBinaryGroupInterface(define ExprDef) {
   170  	fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name)
   171  	fmt.Fprintf(g.w, "  return []*ExprGroup{r.Left, r.Right}\n")
   172  	fmt.Fprintf(g.w, "}\n\n")
   173  }
   174  
   175  func (g *MemoGen) genChildlessGroupInterface(define ExprDef) {
   176  	fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name)
   177  	fmt.Fprintf(g.w, "  return nil\n")
   178  	fmt.Fprintf(g.w, "}\n\n")
   179  }
   180  
   181  func (g *MemoGen) genUnaryGroupInterface(define ExprDef) {
   182  	fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name)
   183  	fmt.Fprintf(g.w, "  return []*ExprGroup{r.Child}\n")
   184  	fmt.Fprintf(g.w, "}\n\n")
   185  
   186  	fmt.Fprintf(g.w, "func (r *%s) outputCols() sql.ColSet {\n", define.Name)
   187  	switch define.Name {
   188  	case "Project":
   189  		fmt.Fprintf(g.w, "  return getProjectColset(r)\n")
   190  
   191  	default:
   192  		fmt.Fprintf(g.w, "  return r.Child.RelProps.OutputCols()\n")
   193  	}
   194  
   195  	fmt.Fprintf(g.w, "}\n\n")
   196  
   197  }
   198  
   199  func (g *MemoGen) genFormatters(defines []ExprDef) {
   200  	// printer
   201  	fmt.Fprintf(g.w, "func FormatExpr(r exprType) string {\n")
   202  	fmt.Fprintf(g.w, "  switch r := r.(type) {\n")
   203  	for _, d := range defines {
   204  		loweredName := strings.ToLower(d.Name)
   205  		fmt.Fprintf(g.w, "  case *%s:\n", d.Name)
   206  		if loweredName == "indexscan" {
   207  			fmt.Fprintf(g.w, "    if r.Alias != \"\" {\n")
   208  			fmt.Fprintf(g.w, "      return fmt.Sprintf(\"%s: %%s\", r.Alias)\n", loweredName)
   209  			fmt.Fprintf(g.w, "    }\n")
   210  		}
   211  		if d.SourceType != "" {
   212  			fmt.Fprintf(g.w, "    return fmt.Sprintf(\"%s: %%s\", r.Name())\n", loweredName)
   213  		} else if d.Join || d.Binary {
   214  			fmt.Fprintf(g.w, "    return fmt.Sprintf(\"%s %%d %%d\", r.Left.Id, r.Right.Id)\n", loweredName)
   215  		} else if d.Unary {
   216  			fmt.Fprintf(g.w, "    return fmt.Sprintf(\"%s: %%d\", r.Child.Id)\n", loweredName)
   217  		} else {
   218  			panic("unreachable")
   219  		}
   220  	}
   221  	fmt.Fprintf(g.w, "  default:\n")
   222  	fmt.Fprintf(g.w, "    panic(fmt.Sprintf(\"unknown RelExpr type: %%T\", r))\n")
   223  	fmt.Fprintf(g.w, "  }\n")
   224  	fmt.Fprintf(g.w, "}\n\n")
   225  
   226  	// to sqlNode
   227  	fmt.Fprintf(g.w, "func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) {\n")
   228  	fmt.Fprintf(g.w, "  var result sql.Node\n")
   229  	fmt.Fprintf(g.w, "  var err error\n\n")
   230  	fmt.Fprintf(g.w, "  switch r := r.(type) {\n")
   231  	for _, d := range defines {
   232  		if d.SkipExec {
   233  			continue
   234  		}
   235  		fmt.Fprintf(g.w, "  case *%s:\n", d.Name)
   236  		fmt.Fprintf(g.w, "  result, err = b.build%s(r, children...)\n", strings.Title(d.Name))
   237  	}
   238  	fmt.Fprintf(g.w, "  default:\n")
   239  	fmt.Fprintf(g.w, "    panic(fmt.Sprintf(\"unknown RelExpr type: %%T\", r))\n")
   240  	fmt.Fprintf(g.w, "  }\n\n")
   241  	fmt.Fprintf(g.w, "  if err != nil {\n")
   242  	fmt.Fprintf(g.w, "    return nil, err\n")
   243  	fmt.Fprintf(g.w, "  }\n\n")
   244  	fmt.Fprintf(g.w, "if withDescribeStats, ok := result.(sql.WithDescribeStats); ok {\n")
   245  	fmt.Fprintf(g.w, "	withDescribeStats.SetDescribeStats(*DescribeStats(r))\n")
   246  	fmt.Fprintf(g.w, "}\n")
   247  	fmt.Fprintf(g.w, "  result, err = r.Group().finalize(result)\n")
   248  	fmt.Fprintf(g.w, "  if err != nil {\n")
   249  	fmt.Fprintf(g.w, "    return nil, err\n")
   250  	fmt.Fprintf(g.w, "  }\n")
   251  	fmt.Fprintf(g.w, "  return result, nil\n")
   252  	fmt.Fprintf(g.w, "}\n\n")
   253  }