github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optgen/cmd/langgen/exprs_gen.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package main
    12  
    13  import (
    14  	"fmt"
    15  	"io"
    16  	"unicode"
    17  	"unicode/utf8"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/sql/opt/optgen/lang"
    20  )
    21  
    22  // exprsGen generates the AST expression structs for the Optgen language, as
    23  // well as the Expr interface functions. It generates code using the AST that a
    24  // previous version of itself generated (compiler bootstrapping).
    25  type exprsGen struct {
    26  	compiled *lang.CompiledExpr
    27  	w        io.Writer
    28  }
    29  
    30  func (g *exprsGen) generate(compiled *lang.CompiledExpr, w io.Writer) {
    31  	g.compiled = compiled
    32  	g.w = w
    33  
    34  	fmt.Fprintf(g.w, "import (\n")
    35  	fmt.Fprintf(g.w, "  \"bytes\"\n")
    36  	fmt.Fprintf(g.w, "  \"fmt\"\n")
    37  	fmt.Fprintf(g.w, ")\n\n")
    38  
    39  	for _, define := range g.compiled.Defines {
    40  		g.genExprType(define)
    41  		g.genOpFunc(define)
    42  		g.genChildCountFunc(define)
    43  		g.genChildFunc(define)
    44  		g.genChildNameFunc(define)
    45  		g.genValueFunc(define)
    46  		g.genVisitFunc(define)
    47  		g.genSourceFunc(define)
    48  		g.genInferredType(define)
    49  		g.genStringFunc(define)
    50  		g.genFormatFunc(define)
    51  	}
    52  }
    53  
    54  // type SomeExpr struct {
    55  //   FieldName FieldType
    56  // }
    57  func (g *exprsGen) genExprType(define *lang.DefineExpr) {
    58  	exprType := fmt.Sprintf("%sExpr", define.Name)
    59  
    60  	// Generate the expression type.
    61  	if isValueType(define) {
    62  		fmt.Fprintf(g.w, "type %s %s\n", exprType, g.translateType(valueType(define)))
    63  	} else if isSliceType(define) {
    64  		fmt.Fprintf(g.w, "type %s []%s\n", exprType, g.translateType(sliceElementType(define)))
    65  	} else {
    66  		fmt.Fprintf(g.w, "type %s struct {\n", exprType)
    67  
    68  		for _, field := range define.Fields {
    69  			fmt.Fprintf(g.w, "  %s %s\n", field.Name, g.translateType(string(field.Type)))
    70  		}
    71  
    72  		if hasSourceField(define) {
    73  			fmt.Fprintf(g.w, "  Src *SourceLoc\n")
    74  		}
    75  		if define.Tags.Contains("HasType") {
    76  			fmt.Fprintf(g.w, "  Typ DataType")
    77  		}
    78  		fmt.Fprintf(g.w, "}\n\n")
    79  	}
    80  }
    81  
    82  // func (e *SomeExpr) Op() Operator {
    83  //   return SomeOp
    84  // }
    85  func (g *exprsGen) genOpFunc(define *lang.DefineExpr) {
    86  	exprType := fmt.Sprintf("%sExpr", define.Name)
    87  	opType := fmt.Sprintf("%sOp", define.Name)
    88  
    89  	fmt.Fprintf(g.w, "func (e *%s) Op() Operator {\n", exprType)
    90  	fmt.Fprintf(g.w, "  return %s\n", opType)
    91  	fmt.Fprintf(g.w, "}\n\n")
    92  }
    93  
    94  // func (e *SomeExpr) ChildCount() int {
    95  //   return 1
    96  // }
    97  func (g *exprsGen) genChildCountFunc(define *lang.DefineExpr) {
    98  	exprType := fmt.Sprintf("%sExpr", define.Name)
    99  
   100  	// ChildCount method.
   101  	fmt.Fprintf(g.w, "func (e *%s) ChildCount() int {\n", exprType)
   102  	if isSliceType(define) {
   103  		fmt.Fprintf(g.w, "  return len(*e)\n")
   104  	} else if isValueType(define) {
   105  		fmt.Fprintf(g.w, "  return 0\n")
   106  	} else {
   107  		fmt.Fprintf(g.w, "  return %d\n", len(define.Fields))
   108  	}
   109  	fmt.Fprintf(g.w, "}\n\n")
   110  }
   111  
   112  // func (e *SomeExpr) Child(nth int) Expr {
   113  //   switch nth {
   114  //   case 0:
   115  //     return e.FieldName
   116  //   }
   117  //   panic(fmt.Sprintf("child index %d is out of range", nth))
   118  // }
   119  func (g *exprsGen) genChildFunc(define *lang.DefineExpr) {
   120  	exprType := fmt.Sprintf("%sExpr", define.Name)
   121  
   122  	fmt.Fprintf(g.w, "func (e *%s) Child(nth int) Expr {\n", exprType)
   123  
   124  	if isSliceType(define) {
   125  		if g.isValueOrSliceType(sliceElementType(define)) {
   126  			fmt.Fprintf(g.w, "  return &(*e)[nth]\n")
   127  		} else {
   128  			fmt.Fprintf(g.w, "  return (*e)[nth]\n")
   129  		}
   130  	} else if isValueType(define) {
   131  		fmt.Fprintf(g.w, "  panic(fmt.Sprintf(\"child index %%d is out of range\", nth))\n")
   132  	} else {
   133  		if len(define.Fields) != 0 {
   134  			fmt.Fprintf(g.w, "  switch nth {\n")
   135  			for i, field := range define.Fields {
   136  				fmt.Fprintf(g.w, "  case %d:\n", i)
   137  
   138  				if g.isValueOrSliceType(string(field.Type)) {
   139  					fmt.Fprintf(g.w, "    return &e.%s\n", field.Name)
   140  				} else {
   141  					fmt.Fprintf(g.w, "    return e.%s\n", field.Name)
   142  				}
   143  			}
   144  			fmt.Fprintf(g.w, "  }\n")
   145  		}
   146  
   147  		fmt.Fprintf(g.w, "  panic(fmt.Sprintf(\"child index %%d is out of range\", nth))\n")
   148  	}
   149  
   150  	fmt.Fprintf(g.w, "}\n\n")
   151  }
   152  
   153  // func (e *SomeExpr) ChildName(nth int) string {
   154  //   switch nth {
   155  //   case 0:
   156  //     return "FieldName"
   157  //   }
   158  //   panic(fmt.Sprintf("child index %d is out of range", nth))
   159  // }
   160  func (g *exprsGen) genChildNameFunc(define *lang.DefineExpr) {
   161  	exprType := fmt.Sprintf("%sExpr", define.Name)
   162  
   163  	fmt.Fprintf(g.w, "func (e *%s) ChildName(nth int) string {\n", exprType)
   164  
   165  	if !isSliceType(define) && !isValueType(define) && len(define.Fields) != 0 {
   166  		fmt.Fprintf(g.w, "  switch (nth) {\n")
   167  		for i, field := range define.Fields {
   168  			fmt.Fprintf(g.w, "  case %d:\n", i)
   169  
   170  			fmt.Fprintf(g.w, "    return \"%s\"\n", field.Name)
   171  		}
   172  		fmt.Fprintf(g.w, "  }\n")
   173  	}
   174  
   175  	fmt.Fprintf(g.w, "  return \"\"\n")
   176  	fmt.Fprintf(g.w, "}\n\n")
   177  }
   178  
   179  // func (e *SomeExpr) Value() interface{} {
   180  //   return string(*e)
   181  // }
   182  func (g *exprsGen) genValueFunc(define *lang.DefineExpr) {
   183  	exprType := fmt.Sprintf("%sExpr", define.Name)
   184  
   185  	fmt.Fprintf(g.w, "func (e *%s) Value() interface{} {\n", exprType)
   186  	if isValueType(define) {
   187  		fmt.Fprintf(g.w, "  return %s(*e)\n", valueType(define))
   188  	} else {
   189  		fmt.Fprintf(g.w, "  return nil\n")
   190  	}
   191  	fmt.Fprintf(g.w, "}\n\n")
   192  }
   193  
   194  // func (e *SomeExpr) Visit(visit VisitFunc) Expr {
   195  //   children := visitChildren(e, visit)
   196  //   if children != nil {
   197  //     return &SomeExpr{FieldName: children[0].(*FieldType)}
   198  //   }
   199  //   return e
   200  // }
   201  func (g *exprsGen) genVisitFunc(define *lang.DefineExpr) {
   202  	exprType := fmt.Sprintf("%sExpr", define.Name)
   203  
   204  	fmt.Fprintf(g.w, "func (e *%s) Visit(visit VisitFunc) Expr {\n", exprType)
   205  
   206  	// Value type definition has no children.
   207  	if !isValueType(define) && len(define.Fields) != 0 {
   208  		fmt.Fprintf(g.w, "  children := visitChildren(e, visit)\n")
   209  		fmt.Fprintf(g.w, "  if children != nil {\n")
   210  
   211  		if isSliceType(define) {
   212  			elemType := g.translateType(sliceElementType(define))
   213  			if elemType != "Expr" {
   214  				fmt.Fprintf(g.w, "    typedChildren := make(%s, len(children))\n", exprType)
   215  				fmt.Fprintf(g.w, "    for i := 0; i < len(children); i++ {\n")
   216  				if g.isValueOrSliceType(sliceElementType(define)) {
   217  					fmt.Fprintf(g.w, "      typedChildren[i] = *children[i].(*%s)\n", elemType)
   218  				} else {
   219  					fmt.Fprintf(g.w, "      typedChildren[i] = children[i].(%s)\n", elemType)
   220  				}
   221  				fmt.Fprintf(g.w, "    }\n")
   222  				fmt.Fprintf(g.w, "    return &typedChildren\n")
   223  			} else {
   224  				fmt.Fprintf(g.w, "    typedChildren := %s(children)\n", exprType)
   225  				fmt.Fprintf(g.w, "    return &typedChildren\n")
   226  			}
   227  		} else {
   228  			fmt.Fprintf(g.w, "    return &%s{", exprType)
   229  
   230  			for i, field := range define.Fields {
   231  				fieldType := g.translateType(string(field.Type))
   232  
   233  				if i != 0 {
   234  					fmt.Fprintf(g.w, ", ")
   235  				}
   236  
   237  				if g.isValueOrSliceType(string(field.Type)) {
   238  					fmt.Fprintf(g.w, "%s: *children[%d].(*%s)", field.Name, i, fieldType)
   239  				} else if field.Type == "Expr" {
   240  					fmt.Fprintf(g.w, "%s: children[%d]", field.Name, i)
   241  				} else {
   242  					fmt.Fprintf(g.w, "%s: children[%d].(%s)", field.Name, i, fieldType)
   243  				}
   244  			}
   245  
   246  			// Propagate source file, line, pos.
   247  			if hasSourceField(define) {
   248  				fmt.Fprintf(g.w, ", Src: e.Source()")
   249  			}
   250  
   251  			fmt.Fprintf(g.w, "}\n")
   252  		}
   253  
   254  		fmt.Fprintf(g.w, "  }\n")
   255  	}
   256  
   257  	fmt.Fprintf(g.w, "  return e\n")
   258  	fmt.Fprintf(g.w, "}\n\n")
   259  }
   260  
   261  // func (e *SomeExpr) Source() *SourceLoc {
   262  //   return e.Src
   263  // }
   264  func (g *exprsGen) genSourceFunc(define *lang.DefineExpr) {
   265  	exprType := fmt.Sprintf("%sExpr", define.Name)
   266  
   267  	fmt.Fprintf(g.w, "func (e *%s) Source() *SourceLoc {\n", exprType)
   268  	if hasSourceField(define) {
   269  		fmt.Fprintf(g.w, "  return e.Src\n")
   270  	} else {
   271  		fmt.Fprintf(g.w, "  return nil\n")
   272  	}
   273  	fmt.Fprintf(g.w, "}\n\n")
   274  }
   275  
   276  // func (e *SomeExpr) InferredType() DataType {
   277  //   return e.Typ
   278  // }
   279  func (g *exprsGen) genInferredType(define *lang.DefineExpr) {
   280  	exprType := fmt.Sprintf("%sExpr", define.Name)
   281  
   282  	fmt.Fprintf(g.w, "func (e *%s) InferredType() DataType {\n", exprType)
   283  	if define.Tags.Contains("HasType") {
   284  		fmt.Fprintf(g.w, "  return e.Typ\n")
   285  	} else if isValueType(define) {
   286  		fmt.Fprintf(g.w, "  return %sDataType\n", title(g.translateType(valueType(define))))
   287  	} else {
   288  		fmt.Fprintf(g.w, "  return AnyDataType\n")
   289  	}
   290  	fmt.Fprintf(g.w, "}\n\n")
   291  }
   292  
   293  // func (e *SomeExpr) String() string {
   294  //   var buf bytes.Buffer
   295  //   e.Format(&buf, 0)
   296  //   return buf.String()
   297  // }
   298  func (g *exprsGen) genStringFunc(define *lang.DefineExpr) {
   299  	exprType := fmt.Sprintf("%sExpr", define.Name)
   300  
   301  	fmt.Fprintf(g.w, "func (e *%s) String() string {\n", exprType)
   302  	fmt.Fprintf(g.w, "  var buf bytes.Buffer\n")
   303  	fmt.Fprintf(g.w, "  e.Format(&buf, 0)\n")
   304  	fmt.Fprintf(g.w, "  return buf.String()\n")
   305  	fmt.Fprintf(g.w, "}\n\n")
   306  }
   307  
   308  // func (e *SomeExpr) Format(buf *bytes.Buffer, level int) {
   309  //   formatExpr(e, buf, level)
   310  // }
   311  func (g *exprsGen) genFormatFunc(define *lang.DefineExpr) {
   312  	exprType := fmt.Sprintf("%sExpr", define.Name)
   313  
   314  	fmt.Fprintf(g.w, "func (e *%s) Format(buf *bytes.Buffer, level int) {\n", exprType)
   315  	fmt.Fprintf(g.w, "  formatExpr(e, buf, level)\n")
   316  	fmt.Fprintf(g.w, "}\n\n")
   317  }
   318  
   319  func (g *exprsGen) translateType(typ string) string {
   320  	switch typ {
   321  	case "Expr":
   322  		return typ
   323  
   324  	case "string":
   325  		return typ
   326  
   327  	case "int64":
   328  		return typ
   329  	}
   330  
   331  	if g.isValueOrSliceType(typ) {
   332  		return fmt.Sprintf("%sExpr", typ)
   333  	}
   334  
   335  	return fmt.Sprintf("*%sExpr", typ)
   336  }
   337  
   338  func (g *exprsGen) isValueOrSliceType(typ string) bool {
   339  	// Expr is built-in type, without explicit definition.
   340  	if typ == "Expr" {
   341  		return false
   342  	}
   343  
   344  	// Pass slices and value types by value.
   345  	define := g.compiled.LookupDefine(typ)
   346  	if define == nil {
   347  		panic(fmt.Sprintf("could not find define for type %s", typ))
   348  	}
   349  	return isValueType(define) || isSliceType(define)
   350  }
   351  
   352  // isValueType returns true if the define statement is defining a value
   353  // expression, which is a leaf expression that is equivalent to a primitive
   354  // type like string or int. These types return non-nil for the Value function.
   355  func isValueType(d *lang.DefineExpr) bool {
   356  	return d.Tags.Contains("Value")
   357  }
   358  
   359  // isSliceType returns true if the define statement is defining a slice
   360  // expression, which is an expression that stores a slice of expressions of
   361  // some other type, like []*RuleExpr or []TagExpr.
   362  func isSliceType(d *lang.DefineExpr) bool {
   363  	return d.Tags.Contains("Slice")
   364  }
   365  
   366  // valueType returns the name of the primitive type which the defined type
   367  // is equivalent to, like string or int.
   368  func valueType(d *lang.DefineExpr) string {
   369  	if d.Fields[0].Name != "Value" {
   370  		panic(fmt.Sprintf("expected 'Value' field name, found %s", d.Fields[0].Name))
   371  	}
   372  	return string(d.Fields[0].Type)
   373  }
   374  
   375  // sliceElementType returns the type of elements in the slice expression.
   376  func sliceElementType(d *lang.DefineExpr) string {
   377  	if d.Fields[0].Name != "Element" {
   378  		panic(fmt.Sprintf("expected 'Element' field name, found %s", d.Fields[0].Name))
   379  	}
   380  	return string(d.Fields[0].Type)
   381  }
   382  
   383  // hasSourceField returns true if the defined expression has a Src field that
   384  // stores the original source information (file, line, pos).
   385  func hasSourceField(d *lang.DefineExpr) bool {
   386  	return !isValueType(d) && !isSliceType(d)
   387  }
   388  
   389  // title returns the given string with its first letter capitalized.
   390  func title(name string) string {
   391  	rune, size := utf8.DecodeRuneInString(name)
   392  	return fmt.Sprintf("%c%s", unicode.ToUpper(rune), name[size:])
   393  }