github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optgen/cmd/langgen/exprs_gen.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package main 12 13 import ( 14 "fmt" 15 "io" 16 "unicode" 17 "unicode/utf8" 18 19 "github.com/cockroachdb/cockroach/pkg/sql/opt/optgen/lang" 20 ) 21 22 // exprsGen generates the AST expression structs for the Optgen language, as 23 // well as the Expr interface functions. It generates code using the AST that a 24 // previous version of itself generated (compiler bootstrapping). 25 type exprsGen struct { 26 compiled *lang.CompiledExpr 27 w io.Writer 28 } 29 30 func (g *exprsGen) generate(compiled *lang.CompiledExpr, w io.Writer) { 31 g.compiled = compiled 32 g.w = w 33 34 fmt.Fprintf(g.w, "import (\n") 35 fmt.Fprintf(g.w, " \"bytes\"\n") 36 fmt.Fprintf(g.w, " \"fmt\"\n") 37 fmt.Fprintf(g.w, ")\n\n") 38 39 for _, define := range g.compiled.Defines { 40 g.genExprType(define) 41 g.genOpFunc(define) 42 g.genChildCountFunc(define) 43 g.genChildFunc(define) 44 g.genChildNameFunc(define) 45 g.genValueFunc(define) 46 g.genVisitFunc(define) 47 g.genSourceFunc(define) 48 g.genInferredType(define) 49 g.genStringFunc(define) 50 g.genFormatFunc(define) 51 } 52 } 53 54 // type SomeExpr struct { 55 // FieldName FieldType 56 // } 57 func (g *exprsGen) genExprType(define *lang.DefineExpr) { 58 exprType := fmt.Sprintf("%sExpr", define.Name) 59 60 // Generate the expression type. 61 if isValueType(define) { 62 fmt.Fprintf(g.w, "type %s %s\n", exprType, g.translateType(valueType(define))) 63 } else if isSliceType(define) { 64 fmt.Fprintf(g.w, "type %s []%s\n", exprType, g.translateType(sliceElementType(define))) 65 } else { 66 fmt.Fprintf(g.w, "type %s struct {\n", exprType) 67 68 for _, field := range define.Fields { 69 fmt.Fprintf(g.w, " %s %s\n", field.Name, g.translateType(string(field.Type))) 70 } 71 72 if hasSourceField(define) { 73 fmt.Fprintf(g.w, " Src *SourceLoc\n") 74 } 75 if define.Tags.Contains("HasType") { 76 fmt.Fprintf(g.w, " Typ DataType") 77 } 78 fmt.Fprintf(g.w, "}\n\n") 79 } 80 } 81 82 // func (e *SomeExpr) Op() Operator { 83 // return SomeOp 84 // } 85 func (g *exprsGen) genOpFunc(define *lang.DefineExpr) { 86 exprType := fmt.Sprintf("%sExpr", define.Name) 87 opType := fmt.Sprintf("%sOp", define.Name) 88 89 fmt.Fprintf(g.w, "func (e *%s) Op() Operator {\n", exprType) 90 fmt.Fprintf(g.w, " return %s\n", opType) 91 fmt.Fprintf(g.w, "}\n\n") 92 } 93 94 // func (e *SomeExpr) ChildCount() int { 95 // return 1 96 // } 97 func (g *exprsGen) genChildCountFunc(define *lang.DefineExpr) { 98 exprType := fmt.Sprintf("%sExpr", define.Name) 99 100 // ChildCount method. 101 fmt.Fprintf(g.w, "func (e *%s) ChildCount() int {\n", exprType) 102 if isSliceType(define) { 103 fmt.Fprintf(g.w, " return len(*e)\n") 104 } else if isValueType(define) { 105 fmt.Fprintf(g.w, " return 0\n") 106 } else { 107 fmt.Fprintf(g.w, " return %d\n", len(define.Fields)) 108 } 109 fmt.Fprintf(g.w, "}\n\n") 110 } 111 112 // func (e *SomeExpr) Child(nth int) Expr { 113 // switch nth { 114 // case 0: 115 // return e.FieldName 116 // } 117 // panic(fmt.Sprintf("child index %d is out of range", nth)) 118 // } 119 func (g *exprsGen) genChildFunc(define *lang.DefineExpr) { 120 exprType := fmt.Sprintf("%sExpr", define.Name) 121 122 fmt.Fprintf(g.w, "func (e *%s) Child(nth int) Expr {\n", exprType) 123 124 if isSliceType(define) { 125 if g.isValueOrSliceType(sliceElementType(define)) { 126 fmt.Fprintf(g.w, " return &(*e)[nth]\n") 127 } else { 128 fmt.Fprintf(g.w, " return (*e)[nth]\n") 129 } 130 } else if isValueType(define) { 131 fmt.Fprintf(g.w, " panic(fmt.Sprintf(\"child index %%d is out of range\", nth))\n") 132 } else { 133 if len(define.Fields) != 0 { 134 fmt.Fprintf(g.w, " switch nth {\n") 135 for i, field := range define.Fields { 136 fmt.Fprintf(g.w, " case %d:\n", i) 137 138 if g.isValueOrSliceType(string(field.Type)) { 139 fmt.Fprintf(g.w, " return &e.%s\n", field.Name) 140 } else { 141 fmt.Fprintf(g.w, " return e.%s\n", field.Name) 142 } 143 } 144 fmt.Fprintf(g.w, " }\n") 145 } 146 147 fmt.Fprintf(g.w, " panic(fmt.Sprintf(\"child index %%d is out of range\", nth))\n") 148 } 149 150 fmt.Fprintf(g.w, "}\n\n") 151 } 152 153 // func (e *SomeExpr) ChildName(nth int) string { 154 // switch nth { 155 // case 0: 156 // return "FieldName" 157 // } 158 // panic(fmt.Sprintf("child index %d is out of range", nth)) 159 // } 160 func (g *exprsGen) genChildNameFunc(define *lang.DefineExpr) { 161 exprType := fmt.Sprintf("%sExpr", define.Name) 162 163 fmt.Fprintf(g.w, "func (e *%s) ChildName(nth int) string {\n", exprType) 164 165 if !isSliceType(define) && !isValueType(define) && len(define.Fields) != 0 { 166 fmt.Fprintf(g.w, " switch (nth) {\n") 167 for i, field := range define.Fields { 168 fmt.Fprintf(g.w, " case %d:\n", i) 169 170 fmt.Fprintf(g.w, " return \"%s\"\n", field.Name) 171 } 172 fmt.Fprintf(g.w, " }\n") 173 } 174 175 fmt.Fprintf(g.w, " return \"\"\n") 176 fmt.Fprintf(g.w, "}\n\n") 177 } 178 179 // func (e *SomeExpr) Value() interface{} { 180 // return string(*e) 181 // } 182 func (g *exprsGen) genValueFunc(define *lang.DefineExpr) { 183 exprType := fmt.Sprintf("%sExpr", define.Name) 184 185 fmt.Fprintf(g.w, "func (e *%s) Value() interface{} {\n", exprType) 186 if isValueType(define) { 187 fmt.Fprintf(g.w, " return %s(*e)\n", valueType(define)) 188 } else { 189 fmt.Fprintf(g.w, " return nil\n") 190 } 191 fmt.Fprintf(g.w, "}\n\n") 192 } 193 194 // func (e *SomeExpr) Visit(visit VisitFunc) Expr { 195 // children := visitChildren(e, visit) 196 // if children != nil { 197 // return &SomeExpr{FieldName: children[0].(*FieldType)} 198 // } 199 // return e 200 // } 201 func (g *exprsGen) genVisitFunc(define *lang.DefineExpr) { 202 exprType := fmt.Sprintf("%sExpr", define.Name) 203 204 fmt.Fprintf(g.w, "func (e *%s) Visit(visit VisitFunc) Expr {\n", exprType) 205 206 // Value type definition has no children. 207 if !isValueType(define) && len(define.Fields) != 0 { 208 fmt.Fprintf(g.w, " children := visitChildren(e, visit)\n") 209 fmt.Fprintf(g.w, " if children != nil {\n") 210 211 if isSliceType(define) { 212 elemType := g.translateType(sliceElementType(define)) 213 if elemType != "Expr" { 214 fmt.Fprintf(g.w, " typedChildren := make(%s, len(children))\n", exprType) 215 fmt.Fprintf(g.w, " for i := 0; i < len(children); i++ {\n") 216 if g.isValueOrSliceType(sliceElementType(define)) { 217 fmt.Fprintf(g.w, " typedChildren[i] = *children[i].(*%s)\n", elemType) 218 } else { 219 fmt.Fprintf(g.w, " typedChildren[i] = children[i].(%s)\n", elemType) 220 } 221 fmt.Fprintf(g.w, " }\n") 222 fmt.Fprintf(g.w, " return &typedChildren\n") 223 } else { 224 fmt.Fprintf(g.w, " typedChildren := %s(children)\n", exprType) 225 fmt.Fprintf(g.w, " return &typedChildren\n") 226 } 227 } else { 228 fmt.Fprintf(g.w, " return &%s{", exprType) 229 230 for i, field := range define.Fields { 231 fieldType := g.translateType(string(field.Type)) 232 233 if i != 0 { 234 fmt.Fprintf(g.w, ", ") 235 } 236 237 if g.isValueOrSliceType(string(field.Type)) { 238 fmt.Fprintf(g.w, "%s: *children[%d].(*%s)", field.Name, i, fieldType) 239 } else if field.Type == "Expr" { 240 fmt.Fprintf(g.w, "%s: children[%d]", field.Name, i) 241 } else { 242 fmt.Fprintf(g.w, "%s: children[%d].(%s)", field.Name, i, fieldType) 243 } 244 } 245 246 // Propagate source file, line, pos. 247 if hasSourceField(define) { 248 fmt.Fprintf(g.w, ", Src: e.Source()") 249 } 250 251 fmt.Fprintf(g.w, "}\n") 252 } 253 254 fmt.Fprintf(g.w, " }\n") 255 } 256 257 fmt.Fprintf(g.w, " return e\n") 258 fmt.Fprintf(g.w, "}\n\n") 259 } 260 261 // func (e *SomeExpr) Source() *SourceLoc { 262 // return e.Src 263 // } 264 func (g *exprsGen) genSourceFunc(define *lang.DefineExpr) { 265 exprType := fmt.Sprintf("%sExpr", define.Name) 266 267 fmt.Fprintf(g.w, "func (e *%s) Source() *SourceLoc {\n", exprType) 268 if hasSourceField(define) { 269 fmt.Fprintf(g.w, " return e.Src\n") 270 } else { 271 fmt.Fprintf(g.w, " return nil\n") 272 } 273 fmt.Fprintf(g.w, "}\n\n") 274 } 275 276 // func (e *SomeExpr) InferredType() DataType { 277 // return e.Typ 278 // } 279 func (g *exprsGen) genInferredType(define *lang.DefineExpr) { 280 exprType := fmt.Sprintf("%sExpr", define.Name) 281 282 fmt.Fprintf(g.w, "func (e *%s) InferredType() DataType {\n", exprType) 283 if define.Tags.Contains("HasType") { 284 fmt.Fprintf(g.w, " return e.Typ\n") 285 } else if isValueType(define) { 286 fmt.Fprintf(g.w, " return %sDataType\n", title(g.translateType(valueType(define)))) 287 } else { 288 fmt.Fprintf(g.w, " return AnyDataType\n") 289 } 290 fmt.Fprintf(g.w, "}\n\n") 291 } 292 293 // func (e *SomeExpr) String() string { 294 // var buf bytes.Buffer 295 // e.Format(&buf, 0) 296 // return buf.String() 297 // } 298 func (g *exprsGen) genStringFunc(define *lang.DefineExpr) { 299 exprType := fmt.Sprintf("%sExpr", define.Name) 300 301 fmt.Fprintf(g.w, "func (e *%s) String() string {\n", exprType) 302 fmt.Fprintf(g.w, " var buf bytes.Buffer\n") 303 fmt.Fprintf(g.w, " e.Format(&buf, 0)\n") 304 fmt.Fprintf(g.w, " return buf.String()\n") 305 fmt.Fprintf(g.w, "}\n\n") 306 } 307 308 // func (e *SomeExpr) Format(buf *bytes.Buffer, level int) { 309 // formatExpr(e, buf, level) 310 // } 311 func (g *exprsGen) genFormatFunc(define *lang.DefineExpr) { 312 exprType := fmt.Sprintf("%sExpr", define.Name) 313 314 fmt.Fprintf(g.w, "func (e *%s) Format(buf *bytes.Buffer, level int) {\n", exprType) 315 fmt.Fprintf(g.w, " formatExpr(e, buf, level)\n") 316 fmt.Fprintf(g.w, "}\n\n") 317 } 318 319 func (g *exprsGen) translateType(typ string) string { 320 switch typ { 321 case "Expr": 322 return typ 323 324 case "string": 325 return typ 326 327 case "int64": 328 return typ 329 } 330 331 if g.isValueOrSliceType(typ) { 332 return fmt.Sprintf("%sExpr", typ) 333 } 334 335 return fmt.Sprintf("*%sExpr", typ) 336 } 337 338 func (g *exprsGen) isValueOrSliceType(typ string) bool { 339 // Expr is built-in type, without explicit definition. 340 if typ == "Expr" { 341 return false 342 } 343 344 // Pass slices and value types by value. 345 define := g.compiled.LookupDefine(typ) 346 if define == nil { 347 panic(fmt.Sprintf("could not find define for type %s", typ)) 348 } 349 return isValueType(define) || isSliceType(define) 350 } 351 352 // isValueType returns true if the define statement is defining a value 353 // expression, which is a leaf expression that is equivalent to a primitive 354 // type like string or int. These types return non-nil for the Value function. 355 func isValueType(d *lang.DefineExpr) bool { 356 return d.Tags.Contains("Value") 357 } 358 359 // isSliceType returns true if the define statement is defining a slice 360 // expression, which is an expression that stores a slice of expressions of 361 // some other type, like []*RuleExpr or []TagExpr. 362 func isSliceType(d *lang.DefineExpr) bool { 363 return d.Tags.Contains("Slice") 364 } 365 366 // valueType returns the name of the primitive type which the defined type 367 // is equivalent to, like string or int. 368 func valueType(d *lang.DefineExpr) string { 369 if d.Fields[0].Name != "Value" { 370 panic(fmt.Sprintf("expected 'Value' field name, found %s", d.Fields[0].Name)) 371 } 372 return string(d.Fields[0].Type) 373 } 374 375 // sliceElementType returns the type of elements in the slice expression. 376 func sliceElementType(d *lang.DefineExpr) string { 377 if d.Fields[0].Name != "Element" { 378 panic(fmt.Sprintf("expected 'Element' field name, found %s", d.Fields[0].Name)) 379 } 380 return string(d.Fields[0].Type) 381 } 382 383 // hasSourceField returns true if the defined expression has a Src field that 384 // stores the original source information (file, line, pos). 385 func hasSourceField(d *lang.DefineExpr) bool { 386 return !isValueType(d) && !isSliceType(d) 387 } 388 389 // title returns the given string with its first letter capitalized. 390 func title(name string) string { 391 rune, size := utf8.DecodeRuneInString(name) 392 return fmt.Sprintf("%c%s", unicode.ToUpper(rune), name[size:]) 393 }