github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/opt/optgen/cmd/langgen/main.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package main
    12  
    13  import (
    14  	"bytes"
    15  	"flag"
    16  	"fmt"
    17  	"go/format"
    18  	"io"
    19  	"os"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/sql/opt/optgen/lang"
    22  	"github.com/cockroachdb/errors"
    23  )
    24  
    25  type genFunc func(compiled *lang.CompiledExpr, w io.Writer)
    26  
    27  var (
    28  	errInvalidArgCount     = errors.New("invalid number of arguments")
    29  	errUnrecognizedCommand = errors.New("unrecognized command")
    30  )
    31  
    32  var (
    33  	out = flag.String("out", "", "output file name of generated code")
    34  )
    35  
    36  // useGoFmt controls whether generated code is formatted by "go fmt" before
    37  // being output.
    38  const useGoFmt = true
    39  
    40  func main() {
    41  	flag.Usage = usage
    42  	flag.Parse()
    43  
    44  	args := flag.Args()
    45  	if len(args) < 2 {
    46  		flag.Usage()
    47  		fatal(errInvalidArgCount)
    48  	}
    49  
    50  	cmd := args[0]
    51  	switch cmd {
    52  	case "exprs":
    53  	case "ops":
    54  
    55  	default:
    56  		flag.Usage()
    57  		fatal(errUnrecognizedCommand)
    58  	}
    59  
    60  	sources := flag.Args()[1:]
    61  	compiler := lang.NewCompiler(sources...)
    62  	compiled := compiler.Compile()
    63  	if compiled == nil {
    64  		for i, err := range compiler.Errors() {
    65  			if i >= 10 {
    66  				fmt.Fprintf(os.Stderr, "... too many errors (%d more)\n", len(compiler.Errors()))
    67  				break
    68  			}
    69  
    70  			fmt.Fprintf(os.Stderr, "%v\n", err)
    71  		}
    72  		os.Exit(2)
    73  	}
    74  
    75  	var err error
    76  	switch cmd {
    77  	case "exprs":
    78  		var gen exprsGen
    79  		err = generate(compiled, *out, gen.generate)
    80  
    81  	case "ops":
    82  		err = generate(compiled, *out, generateOps)
    83  	}
    84  
    85  	if err != nil {
    86  		fatal(err)
    87  	}
    88  }
    89  
    90  // usage is a replacement usage function for the flags package.
    91  func usage() {
    92  	fmt.Fprintf(os.Stderr, "LangGen generates the AST for the Optgen language.\n\n")
    93  
    94  	fmt.Fprintf(os.Stderr, "LangGen uses the Optgen definition language to generate its own AST.\n")
    95  
    96  	fmt.Fprintf(os.Stderr, "Usage:\n")
    97  
    98  	fmt.Fprintf(os.Stderr, "\tlanggen [flags] command sources...\n\n")
    99  
   100  	fmt.Fprintf(os.Stderr, "The commands are:\n\n")
   101  	fmt.Fprintf(os.Stderr, "\texprs  generate expression definitions and functions\n")
   102  	fmt.Fprintf(os.Stderr, "\tops    generate operator definitions and functions\n")
   103  	fmt.Fprintf(os.Stderr, "\n")
   104  
   105  	fmt.Fprintf(os.Stderr, "Flags:\n")
   106  
   107  	flag.PrintDefaults()
   108  
   109  	fmt.Fprintf(os.Stderr, "\n")
   110  }
   111  
   112  func fatal(err error) {
   113  	fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)
   114  	os.Exit(2)
   115  }
   116  
   117  func generate(compiled *lang.CompiledExpr, out string, genFunc genFunc) error {
   118  	var buf bytes.Buffer
   119  
   120  	buf.WriteString("// Code generated by langgen; DO NOT EDIT.\n\n")
   121  	fmt.Fprintf(&buf, "package lang\n\n")
   122  
   123  	genFunc(compiled, &buf)
   124  
   125  	var b []byte
   126  	var err error
   127  
   128  	if useGoFmt {
   129  		b, err = format.Source(buf.Bytes())
   130  		if err != nil {
   131  			// Write out incorrect source for easier debugging.
   132  			b = buf.Bytes()
   133  			err = fmt.Errorf("code formatting failed with Go parse error\n%s:%s", out, err)
   134  		}
   135  	} else {
   136  		b = buf.Bytes()
   137  	}
   138  
   139  	var writer io.Writer
   140  	if out != "" {
   141  		file, err := os.Create(out)
   142  		if err != nil {
   143  			fatal(err)
   144  		}
   145  
   146  		defer file.Close()
   147  		writer = file
   148  	} else {
   149  		writer = os.Stderr
   150  	}
   151  
   152  	if err != nil {
   153  		// Ignore any write error if another error already occurred.
   154  		_, _ = writer.Write(b)
   155  	} else {
   156  		_, err = writer.Write(b)
   157  	}
   158  
   159  	return err
   160  }