github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/gen/gen.go (about)

     1  // Package gen allows generating Go structs from avro schemas.
     2  package gen
     3  
     4  import (
     5  	"bytes"
     6  	_ "embed"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"maps"
    11  	"strings"
    12  	"text/template"
    13  
    14  	"github.com/ettle/strcase"
    15  	"github.com/hamba/avro/v2"
    16  	"golang.org/x/tools/imports"
    17  )
    18  
// Config configures the code generation performed by Struct and
// StructFromSchema.
type Config struct {
	// PackageName is the name of the generated Go package. StructFromSchema
	// converts it to snake case before use.
	PackageName string

	// Tags maps an extra struct tag key to the style used for its value.
	// An entry for the "avro" key is dropped by NewGenerator.
	Tags map[string]TagStyle

	// FullName names generated structs after the record's full name
	// instead of its short name.
	FullName bool

	// Encoders enables generation of schema and encoder code on every
	// generated type; it also adds the github.com/hamba/avro/v2 import.
	Encoders bool

	// StrictTypes maps Avro int to Go int32 instead of int.
	StrictTypes bool

	// Initialisms lists extra initialisms, in addition to golint's standard
	// set, to respect when styling struct and field names.
	Initialisms []string
}
    28  
// TagStyle defines the styling for the value of an extra struct tag.
// NOTE(review): presumably the output template applies it to the Avro field
// name — confirm against output_template.tmpl.
type TagStyle string

const (
	// Original keeps the name exactly as written in the input, e.g.
	// whAtEVer_IS_InthEInpuT.
	Original TagStyle = "original"
	// Snake is a style like im_written_in_snake_case.
	Snake TagStyle = "snake"
	// Camel is a style like imWrittenInCamelCase.
	Camel TagStyle = "camel"
	// Kebab is a style like im-written-in-kebab-case.
	Kebab TagStyle = "kebab"
	// UpperCamel is a style like ImWrittenInUpperCamel.
	UpperCamel TagStyle = "upper-camel"
)
    44  
// outputTemplate is the default text/template source that Generator.Write
// renders. NewGenerator installs it; WithTemplate can replace it.
//
//go:embed output_template.tmpl
var outputTemplate string
    47  
    48  var (
    49  	primitiveMappings = map[avro.Type]string{
    50  		"string":  "string",
    51  		"bytes":   "[]byte",
    52  		"int":     "int",
    53  		"long":    "int64",
    54  		"float":   "float32",
    55  		"double":  "float64",
    56  		"boolean": "bool",
    57  	}
    58  	strictTypeMappings = map[string]string{
    59  		"int": "int32",
    60  	}
    61  )
    62  
    63  // Struct generates Go structs based on the schema and writes them to w.
    64  func Struct(s string, w io.Writer, cfg Config) error {
    65  	schema, err := avro.Parse(s)
    66  	if err != nil {
    67  		return err
    68  	}
    69  	return StructFromSchema(schema, w, cfg)
    70  }
    71  
    72  // StructFromSchema generates Go structs based on the schema and writes them to w.
    73  func StructFromSchema(schema avro.Schema, w io.Writer, cfg Config) error {
    74  	rec, ok := schema.(*avro.RecordSchema)
    75  	if !ok {
    76  		return errors.New("can only generate Go code from Record Schemas")
    77  	}
    78  
    79  	opts := []OptsFunc{
    80  		WithFullName(cfg.FullName),
    81  		WithEncoders(cfg.Encoders),
    82  		WithInitialisms(cfg.Initialisms),
    83  		WithStrictTypes(cfg.StrictTypes),
    84  	}
    85  	g := NewGenerator(strcase.ToSnake(cfg.PackageName), cfg.Tags, opts...)
    86  	g.Parse(rec)
    87  
    88  	buf := &bytes.Buffer{}
    89  	if err := g.Write(buf); err != nil {
    90  		return err
    91  	}
    92  
    93  	formatted, err := imports.Process("", buf.Bytes(), nil)
    94  	if err != nil {
    95  		_, _ = w.Write(buf.Bytes())
    96  		return fmt.Errorf("generated code could not be formatted: %w", err)
    97  	}
    98  
    99  	_, err = w.Write(formatted)
   100  	return err
   101  }
   102  
// OptsFunc is a function that configures a generator. NewGenerator applies
// options in the order they are given, after setting its defaults.
type OptsFunc func(*Generator)
   105  
   106  // WithFullName configures the generator to use the full name of a record
   107  // when creating the struct name.
   108  func WithFullName(b bool) OptsFunc {
   109  	return func(g *Generator) {
   110  		g.fullName = b
   111  	}
   112  }
   113  
   114  // WithEncoders configures the generator to generate schema and encoders on
   115  // all objects.
   116  func WithEncoders(b bool) OptsFunc {
   117  	return func(g *Generator) {
   118  		g.encoders = b
   119  		if b {
   120  			g.thirdPartyImports = append(g.thirdPartyImports, "github.com/hamba/avro/v2")
   121  		}
   122  	}
   123  }
   124  
   125  // WithInitialisms configures the generator to use additional custom initialisms
   126  // when styling struct and field names.
   127  func WithInitialisms(ss []string) OptsFunc {
   128  	return func(g *Generator) {
   129  		g.initialisms = ss
   130  	}
   131  }
   132  
   133  // WithTemplate configures the generator to use a custom template provided by the user.
   134  func WithTemplate(template string) OptsFunc {
   135  	return func(g *Generator) {
   136  		if template == "" {
   137  			return
   138  		}
   139  		g.template = template
   140  	}
   141  }
   142  
   143  // WithStrictTypes configures the generator to use strict type sizes.
   144  func WithStrictTypes(b bool) OptsFunc {
   145  	return func(g *Generator) {
   146  		g.strictTypes = b
   147  	}
   148  }
   149  
// Generator generates Go structs from schemas.
//
// Type definitions and imports accumulate across calls to Parse until Reset
// is called; Write renders whatever has been collected.
type Generator struct {
	// Configuration, set by NewGenerator and its OptsFunc options.
	template    string              // text/template source parsed by Write
	pkg         string              // package name passed to the template
	tags        map[string]TagStyle // extra struct tags; any "avro" key removed
	fullName    bool                // name types after the record's full name
	encoders    bool                // emit schema and encoder code
	strictTypes bool                // map Avro int to int32
	initialisms []string            // custom initialisms for nameCaser

	// State collected while parsing; cleared by Reset.
	imports           []string  // first import group (stdlib paths in practice)
	thirdPartyImports []string  // imports listed after imports in Write
	typedefs          []typedef // generated struct definitions, in registration order

	nameCaser *strcase.Caser // converts Avro names to Go identifiers
}
   166  
   167  // NewGenerator returns a generator.
   168  func NewGenerator(pkg string, tags map[string]TagStyle, opts ...OptsFunc) *Generator {
   169  	clonedTags := maps.Clone(tags)
   170  	delete(clonedTags, "avro")
   171  
   172  	g := &Generator{
   173  		template: outputTemplate,
   174  		pkg:      pkg,
   175  		tags:     clonedTags,
   176  	}
   177  
   178  	for _, opt := range opts {
   179  		opt(g)
   180  	}
   181  
   182  	initialisms := map[string]bool{}
   183  	for _, v := range g.initialisms {
   184  		initialisms[v] = true
   185  	}
   186  
   187  	g.nameCaser = strcase.NewCaser(
   188  		true, // use standard Golint's initialisms
   189  		initialisms,
   190  		nil, // use default word split function
   191  	)
   192  
   193  	return g
   194  }
   195  
   196  // Reset reset the generator.
   197  func (g *Generator) Reset() {
   198  	g.imports = g.imports[:0]
   199  	g.thirdPartyImports = g.thirdPartyImports[:0]
   200  	g.typedefs = g.typedefs[:0]
   201  }
   202  
   203  // Parse parses an avro schema into Go types.
   204  func (g *Generator) Parse(schema avro.Schema) {
   205  	_ = g.generate(schema)
   206  }
   207  
   208  func (g *Generator) generate(schema avro.Schema) string {
   209  	switch s := schema.(type) {
   210  	case *avro.RefSchema:
   211  		return g.resolveRefSchema(s)
   212  	case *avro.RecordSchema:
   213  		return g.resolveRecordSchema(s)
   214  	case *avro.PrimitiveSchema:
   215  		typ := primitiveMappings[s.Type()]
   216  		if ls := s.Logical(); ls != nil {
   217  			typ = g.resolveLogicalSchema(ls.Type())
   218  		}
   219  		if g.strictTypes {
   220  			if newTyp, ok := strictTypeMappings[typ]; ok {
   221  				typ = newTyp
   222  			}
   223  		}
   224  		return typ
   225  	case *avro.ArraySchema:
   226  		return "[]" + g.generate(s.Items())
   227  	case *avro.EnumSchema:
   228  		return "string"
   229  	case *avro.FixedSchema:
   230  		typ := fmt.Sprintf("[%d]byte", s.Size())
   231  		if ls := s.Logical(); ls != nil {
   232  			typ = g.resolveLogicalSchema(ls.Type())
   233  		}
   234  		return typ
   235  	case *avro.MapSchema:
   236  		return "map[string]" + g.generate(s.Values())
   237  	case *avro.UnionSchema:
   238  		return g.resolveUnionTypes(s)
   239  	default:
   240  		return ""
   241  	}
   242  }
   243  
   244  func (g *Generator) resolveTypeName(s avro.NamedSchema) string {
   245  	if g.fullName {
   246  		return g.nameCaser.ToPascal(s.FullName())
   247  	}
   248  	return g.nameCaser.ToPascal(s.Name())
   249  }
   250  
   251  func (g *Generator) resolveRecordSchema(schema *avro.RecordSchema) string {
   252  	fields := make([]field, len(schema.Fields()))
   253  	for i, f := range schema.Fields() {
   254  		typ := g.generate(f.Type())
   255  		fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, f.Doc(), f.Name())
   256  	}
   257  
   258  	typeName := g.resolveTypeName(schema)
   259  	if !g.hasTypeDef(typeName) {
   260  		g.typedefs = append(g.typedefs, newType(typeName, fields, schema.String()))
   261  	}
   262  	return typeName
   263  }
   264  
   265  func (g *Generator) hasTypeDef(name string) bool {
   266  	for _, def := range g.typedefs {
   267  		if def.Name != name {
   268  			continue
   269  		}
   270  		return true
   271  	}
   272  	return false
   273  }
   274  
   275  func (g *Generator) resolveRefSchema(s *avro.RefSchema) string {
   276  	if sx, ok := s.Schema().(*avro.RecordSchema); ok {
   277  		return g.resolveTypeName(sx)
   278  	}
   279  	return g.generate(s.Schema())
   280  }
   281  
   282  func (g *Generator) resolveUnionTypes(s *avro.UnionSchema) string {
   283  	types := make([]string, 0, len(s.Types()))
   284  	for _, elem := range s.Types() {
   285  		if _, ok := elem.(*avro.NullSchema); ok {
   286  			continue
   287  		}
   288  		types = append(types, g.generate(elem))
   289  	}
   290  	if s.Nullable() {
   291  		return "*" + types[0]
   292  	}
   293  	return "any"
   294  }
   295  
   296  func (g *Generator) resolveLogicalSchema(logicalType avro.LogicalType) string {
   297  	var typ string
   298  	switch logicalType {
   299  	case "date", "timestamp-millis", "timestamp-micros":
   300  		typ = "time.Time"
   301  	case "time-millis", "time-micros":
   302  		typ = "time.Duration"
   303  	case "decimal":
   304  		typ = "*big.Rat"
   305  	case "duration":
   306  		typ = "avro.LogicalDuration"
   307  	case "uuid":
   308  		typ = "string"
   309  	}
   310  	if strings.Contains(typ, "time") {
   311  		g.addImport("time")
   312  	}
   313  	if strings.Contains(typ, "big") {
   314  		g.addImport("math/big")
   315  	}
   316  	if strings.Contains(typ, "avro") {
   317  		g.addThirdPartyImport("github.com/hamba/avro/v2")
   318  	}
   319  	return typ
   320  }
   321  
   322  func (g *Generator) newField(name, typ, avroFieldDoc, avroFieldName string) field {
   323  	return field{
   324  		Name:          name,
   325  		Type:          typ,
   326  		AvroFieldName: avroFieldName,
   327  		AvroFieldDoc:  avroFieldDoc,
   328  		Tags:          g.tags,
   329  	}
   330  }
   331  
   332  func (g *Generator) addImport(pkg string) {
   333  	for _, p := range g.imports {
   334  		if p == pkg {
   335  			return
   336  		}
   337  	}
   338  	g.imports = append(g.imports, pkg)
   339  }
   340  
   341  func (g *Generator) addThirdPartyImport(pkg string) {
   342  	for _, p := range g.thirdPartyImports {
   343  		if p == pkg {
   344  			return
   345  		}
   346  	}
   347  	g.thirdPartyImports = append(g.thirdPartyImports, pkg)
   348  }
   349  
   350  // Write writes Go code from the parsed schemas.
   351  func (g *Generator) Write(w io.Writer) error {
   352  	parsed, err := template.New("out").
   353  		Funcs(template.FuncMap{
   354  			"kebab":      strcase.ToKebab,
   355  			"upperCamel": strcase.ToPascal,
   356  			"camel":      strcase.ToCamel,
   357  			"snake":      strcase.ToSnake,
   358  		}).
   359  		Parse(g.template)
   360  	if err != nil {
   361  		return err
   362  	}
   363  
   364  	data := struct {
   365  		WithEncoders      bool
   366  		PackageName       string
   367  		Imports           []string
   368  		ThirdPartyImports []string
   369  		Typedefs          []typedef
   370  	}{
   371  		WithEncoders: g.encoders,
   372  		PackageName:  g.pkg,
   373  		Imports:      append(g.imports, g.thirdPartyImports...),
   374  		Typedefs:     g.typedefs,
   375  	}
   376  	return parsed.Execute(w, data)
   377  }
   378  
// typedef describes one generated Go struct, as passed to the template.
type typedef struct {
	Name   string  // Go type name
	Fields []field // struct fields, in Avro field order
	Schema string  // record schema text from RecordSchema.String
}
   384  
   385  func newType(name string, fields []field, schema string) typedef {
   386  	return typedef{
   387  		Name:   name,
   388  		Fields: fields,
   389  		Schema: schema,
   390  	}
   391  }
   392  
// field describes one field of a generated struct, as passed to the template.
type field struct {
	Name          string              // Go field name
	Type          string              // Go type expression
	AvroFieldName string              // field name from the Avro schema
	AvroFieldDoc  string              // doc string from the Avro schema
	Tags          map[string]TagStyle // extra struct tags, shared by all fields
}