github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/cmd/avrogen/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  
    13  	"github.com/hamba/avro/v2"
    14  	"github.com/hamba/avro/v2/gen"
    15  	"golang.org/x/tools/imports"
    16  )
    17  
// config holds the raw command-line options of avrogen as populated by the
// flag set in realMain.
type config struct {
	// TemplateFileName is the path of a template file that replaces the
	// generator's output template (-template-filename); empty means not set.
	TemplateFileName string

	Pkg         string // package name of the generated file (-pkg); required
	Out         string // output file path (-o); empty means write to stdout
	Tags        string // raw -tags value, parsed by parseTags
	FullName    bool   // -fullname: use the Record schema's full name for struct names
	Encoders    bool   // -encoders: also generate encoders for the structs
	StrictTypes bool   // -strict-types: use strict type sizes such as int32
	Initialisms string // raw -initialisms value, parsed by parseInitialisms
}
    29  
    30  func main() {
    31  	os.Exit(realMain(os.Args, os.Stdout, os.Stderr))
    32  }
    33  
    34  func realMain(args []string, stdout, stderr io.Writer) int {
    35  	var cfg config
    36  	flgs := flag.NewFlagSet("avrogen", flag.ExitOnError)
    37  	flgs.SetOutput(stderr)
    38  	flgs.StringVar(&cfg.Pkg, "pkg", "", "The package name of the output file.")
    39  	flgs.StringVar(&cfg.Out, "o", "", "The output file path to write to instead of stdout.")
    40  	flgs.StringVar(&cfg.Tags, "tags", "", "The additional field tags <tag-name>:{snake|camel|upper-camel|kebab}>[,...]")
    41  	flgs.BoolVar(&cfg.FullName, "fullname", false, "Use the full name of the Record schema to create the struct name.")
    42  	flgs.BoolVar(&cfg.Encoders, "encoders", false, "Generate encoders for the structs.")
    43  	flgs.BoolVar(&cfg.StrictTypes, "strict-types", false, "Use strict type sizes (e.g. int32) during generation.")
    44  	flgs.StringVar(&cfg.Initialisms, "initialisms", "", "Custom initialisms <VAL>[,...] for struct and field names.")
    45  	flgs.StringVar(&cfg.TemplateFileName, "template-filename", "", "Override output template with one loaded from file.")
    46  	flgs.Usage = func() {
    47  		_, _ = fmt.Fprintln(stderr, "Usage: avrogen [options] schemas")
    48  		_, _ = fmt.Fprintln(stderr, "Options:")
    49  		flgs.PrintDefaults()
    50  	}
    51  	if err := flgs.Parse(args[1:]); err != nil {
    52  		return 1
    53  	}
    54  
    55  	if err := validateOpts(flgs.NArg(), cfg); err != nil {
    56  		_, _ = fmt.Fprintln(stderr, "Error: "+err.Error())
    57  		return 1
    58  	}
    59  
    60  	tags, err := parseTags(cfg.Tags)
    61  	if err != nil {
    62  		_, _ = fmt.Fprintln(stderr, "Error: "+err.Error())
    63  		return 1
    64  	}
    65  
    66  	initialisms, err := parseInitialisms(cfg.Initialisms)
    67  	if err != nil {
    68  		_, _ = fmt.Fprintln(stderr, "Error: "+err.Error())
    69  		return 1
    70  	}
    71  
    72  	template, err := loadTemplate(cfg.TemplateFileName)
    73  	if err != nil {
    74  		_, _ = fmt.Fprintln(stderr, "Error: "+err.Error())
    75  		return 1
    76  	}
    77  
    78  	opts := []gen.OptsFunc{
    79  		gen.WithFullName(cfg.FullName),
    80  		gen.WithEncoders(cfg.Encoders),
    81  		gen.WithInitialisms(initialisms),
    82  		gen.WithTemplate(string(template)),
    83  		gen.WithStrictTypes(cfg.StrictTypes),
    84  	}
    85  	g := gen.NewGenerator(cfg.Pkg, tags, opts...)
    86  	for _, file := range flgs.Args() {
    87  		schema, err := avro.ParseFiles(filepath.Clean(file))
    88  		if err != nil {
    89  			_, _ = fmt.Fprintf(stderr, "Error: %v\n", err)
    90  			return 2
    91  		}
    92  		g.Parse(schema)
    93  	}
    94  
    95  	var buf bytes.Buffer
    96  	if err = g.Write(&buf); err != nil {
    97  		_, _ = fmt.Fprintf(stderr, "Error: could not generate code: %v\n", err)
    98  		return 3
    99  	}
   100  	formatted, err := imports.Process("", buf.Bytes(), nil)
   101  	if err != nil {
   102  		_ = writeOut(cfg.Out, stdout, buf.Bytes())
   103  		_, _ = fmt.Fprintf(stderr, "Error: generated code could not be formatted: %v\n", err)
   104  		return 3
   105  	}
   106  
   107  	err = writeOut(cfg.Out, stdout, formatted)
   108  	if err != nil {
   109  		_, _ = fmt.Fprintf(stderr, "Error: %v\n", err)
   110  		return 4
   111  	}
   112  	return 0
   113  }
   114  
   115  func writeOut(filename string, stdout io.Writer, bytes []byte) error {
   116  	writer := stdout
   117  	if filename != "" {
   118  		file, err := os.Create(filepath.Clean(filename))
   119  		if err != nil {
   120  			return fmt.Errorf("could not create output file: %w", err)
   121  		}
   122  		defer func() { _ = file.Close() }()
   123  
   124  		writer = file
   125  	}
   126  
   127  	if _, err := writer.Write(bytes); err != nil {
   128  		return fmt.Errorf("could not write code: %w", err)
   129  	}
   130  	return nil
   131  }
   132  
   133  func validateOpts(nargs int, cfg config) error {
   134  	if nargs < 1 {
   135  		return errors.New("at least one schema is required")
   136  	}
   137  
   138  	if cfg.Pkg == "" {
   139  		return errors.New("a package is required")
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  func parseTags(raw string) (map[string]gen.TagStyle, error) {
   146  	if raw == "" {
   147  		return map[string]gen.TagStyle{}, nil
   148  	}
   149  
   150  	result := map[string]gen.TagStyle{}
   151  	for _, tag := range strings.Split(raw, ",") {
   152  		parts := strings.Split(tag, ":")
   153  		switch {
   154  		case len(parts) != 2:
   155  			return nil, fmt.Errorf("%q is not a valid tag, should be in the formet \"tag:style\"", tag)
   156  		case parts[0] == "":
   157  			return nil, fmt.Errorf("tag name is required in %q", tag)
   158  		}
   159  
   160  		var style gen.TagStyle
   161  		switch strings.ToLower(parts[1]) {
   162  		case string(gen.UpperCamel):
   163  			style = gen.UpperCamel
   164  		case string(gen.Camel):
   165  			style = gen.Camel
   166  		case string(gen.Kebab):
   167  			style = gen.Kebab
   168  		case string(gen.Snake):
   169  			style = gen.Snake
   170  		case string(gen.Original):
   171  			style = gen.Original
   172  		default:
   173  			return nil, fmt.Errorf("style %q is invalid in %q", parts[1], tag)
   174  		}
   175  		result[parts[0]] = style
   176  	}
   177  	return result, nil
   178  }
   179  
   180  func parseInitialisms(raw string) ([]string, error) {
   181  	if raw == "" {
   182  		return []string{}, nil
   183  	}
   184  
   185  	result := []string{}
   186  	for _, initialism := range strings.Split(raw, ",") {
   187  		if initialism != strings.ToUpper(initialism) {
   188  			return nil, fmt.Errorf("initialism %q must be fully in upper case", initialism)
   189  		}
   190  		result = append(result, initialism)
   191  	}
   192  
   193  	return result, nil
   194  }
   195  
   196  func loadTemplate(templateFileName string) ([]byte, error) {
   197  	if templateFileName == "" {
   198  		return nil, nil
   199  	}
   200  	return os.ReadFile(filepath.Clean(templateFileName))
   201  }