github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/gen/gen.go (about) 1 // Package gen allows generating Go structs from avro schemas. 2 package gen 3 4 import ( 5 "bytes" 6 _ "embed" 7 "errors" 8 "fmt" 9 "io" 10 "maps" 11 "strings" 12 "text/template" 13 14 "github.com/ettle/strcase" 15 "github.com/hamba/avro/v2" 16 "golang.org/x/tools/imports" 17 ) 18 19 // Config configures the code generation. 20 type Config struct { 21 PackageName string 22 Tags map[string]TagStyle 23 FullName bool 24 Encoders bool 25 StrictTypes bool 26 Initialisms []string 27 } 28 29 // TagStyle defines the styling for a tag. 30 type TagStyle string 31 32 const ( 33 // Original is a style like whAtEVer_IS_InthEInpuT. 34 Original TagStyle = "original" 35 // Snake is a style like im_written_in_snake_case. 36 Snake TagStyle = "snake" 37 // Camel is a style like imWrittenInCamelCase. 38 Camel TagStyle = "camel" 39 // Kebab is a style like im-written-in-kebab-case. 40 Kebab TagStyle = "kebab" 41 // UpperCamel is a style like ImWrittenInUpperCamel. 42 UpperCamel TagStyle = "upper-camel" 43 ) 44 45 //go:embed output_template.tmpl 46 var outputTemplate string 47 48 var ( 49 primitiveMappings = map[avro.Type]string{ 50 "string": "string", 51 "bytes": "[]byte", 52 "int": "int", 53 "long": "int64", 54 "float": "float32", 55 "double": "float64", 56 "boolean": "bool", 57 } 58 strictTypeMappings = map[string]string{ 59 "int": "int32", 60 } 61 ) 62 63 // Struct generates Go structs based on the schema and writes them to w. 64 func Struct(s string, w io.Writer, cfg Config) error { 65 schema, err := avro.Parse(s) 66 if err != nil { 67 return err 68 } 69 return StructFromSchema(schema, w, cfg) 70 } 71 72 // StructFromSchema generates Go structs based on the schema and writes them to w. 73 func StructFromSchema(schema avro.Schema, w io.Writer, cfg Config) error { 74 rec, ok := schema.(*avro.RecordSchema) 75 if !ok { 76 return errors.New("can only generate Go code from Record Schemas") 77 } 78 79 opts := []OptsFunc{ 80 WithFullName(cfg.FullName), 81 WithEncoders(cfg.Encoders), 82 WithInitialisms(cfg.Initialisms), 83 WithStrictTypes(cfg.StrictTypes), 84 } 85 g := NewGenerator(strcase.ToSnake(cfg.PackageName), cfg.Tags, opts...) 86 g.Parse(rec) 87 88 buf := &bytes.Buffer{} 89 if err := g.Write(buf); err != nil { 90 return err 91 } 92 93 formatted, err := imports.Process("", buf.Bytes(), nil) 94 if err != nil { 95 _, _ = w.Write(buf.Bytes()) 96 return fmt.Errorf("generated code could not be formatted: %w", err) 97 } 98 99 _, err = w.Write(formatted) 100 return err 101 } 102 103 // OptsFunc is a function that configures a generator. 104 type OptsFunc func(*Generator) 105 106 // WithFullName configures the generator to use the full name of a record 107 // when creating the struct name. 108 func WithFullName(b bool) OptsFunc { 109 return func(g *Generator) { 110 g.fullName = b 111 } 112 } 113 114 // WithEncoders configures the generator to generate schema and encoders on 115 // all objects. 116 func WithEncoders(b bool) OptsFunc { 117 return func(g *Generator) { 118 g.encoders = b 119 if b { 120 g.thirdPartyImports = append(g.thirdPartyImports, "github.com/hamba/avro/v2") 121 } 122 } 123 } 124 125 // WithInitialisms configures the generator to use additional custom initialisms 126 // when styling struct and field names. 127 func WithInitialisms(ss []string) OptsFunc { 128 return func(g *Generator) { 129 g.initialisms = ss 130 } 131 } 132 133 // WithTemplate configures the generator to use a custom template provided by the user. 134 func WithTemplate(template string) OptsFunc { 135 return func(g *Generator) { 136 if template == "" { 137 return 138 } 139 g.template = template 140 } 141 } 142 143 // WithStrictTypes configures the generator to use strict type sizes. 144 func WithStrictTypes(b bool) OptsFunc { 145 return func(g *Generator) { 146 g.strictTypes = b 147 } 148 } 149 150 // Generator generates Go structs from schemas. 151 type Generator struct { 152 template string 153 pkg string 154 tags map[string]TagStyle 155 fullName bool 156 encoders bool 157 strictTypes bool 158 initialisms []string 159 160 imports []string 161 thirdPartyImports []string 162 typedefs []typedef 163 164 nameCaser *strcase.Caser 165 } 166 167 // NewGenerator returns a generator. 168 func NewGenerator(pkg string, tags map[string]TagStyle, opts ...OptsFunc) *Generator { 169 clonedTags := maps.Clone(tags) 170 delete(clonedTags, "avro") 171 172 g := &Generator{ 173 template: outputTemplate, 174 pkg: pkg, 175 tags: clonedTags, 176 } 177 178 for _, opt := range opts { 179 opt(g) 180 } 181 182 initialisms := map[string]bool{} 183 for _, v := range g.initialisms { 184 initialisms[v] = true 185 } 186 187 g.nameCaser = strcase.NewCaser( 188 true, // use standard Golint's initialisms 189 initialisms, 190 nil, // use default word split function 191 ) 192 193 return g 194 } 195 196 // Reset reset the generator. 197 func (g *Generator) Reset() { 198 g.imports = g.imports[:0] 199 g.thirdPartyImports = g.thirdPartyImports[:0] 200 g.typedefs = g.typedefs[:0] 201 } 202 203 // Parse parses an avro schema into Go types. 204 func (g *Generator) Parse(schema avro.Schema) { 205 _ = g.generate(schema) 206 } 207 208 func (g *Generator) generate(schema avro.Schema) string { 209 switch s := schema.(type) { 210 case *avro.RefSchema: 211 return g.resolveRefSchema(s) 212 case *avro.RecordSchema: 213 return g.resolveRecordSchema(s) 214 case *avro.PrimitiveSchema: 215 typ := primitiveMappings[s.Type()] 216 if ls := s.Logical(); ls != nil { 217 typ = g.resolveLogicalSchema(ls.Type()) 218 } 219 if g.strictTypes { 220 if newTyp, ok := strictTypeMappings[typ]; ok { 221 typ = newTyp 222 } 223 } 224 return typ 225 case *avro.ArraySchema: 226 return "[]" + g.generate(s.Items()) 227 case *avro.EnumSchema: 228 return "string" 229 case *avro.FixedSchema: 230 typ := fmt.Sprintf("[%d]byte", s.Size()) 231 if ls := s.Logical(); ls != nil { 232 typ = g.resolveLogicalSchema(ls.Type()) 233 } 234 return typ 235 case *avro.MapSchema: 236 return "map[string]" + g.generate(s.Values()) 237 case *avro.UnionSchema: 238 return g.resolveUnionTypes(s) 239 default: 240 return "" 241 } 242 } 243 244 func (g *Generator) resolveTypeName(s avro.NamedSchema) string { 245 if g.fullName { 246 return g.nameCaser.ToPascal(s.FullName()) 247 } 248 return g.nameCaser.ToPascal(s.Name()) 249 } 250 251 func (g *Generator) resolveRecordSchema(schema *avro.RecordSchema) string { 252 fields := make([]field, len(schema.Fields())) 253 for i, f := range schema.Fields() { 254 typ := g.generate(f.Type()) 255 fields[i] = g.newField(g.nameCaser.ToPascal(f.Name()), typ, f.Doc(), f.Name()) 256 } 257 258 typeName := g.resolveTypeName(schema) 259 if !g.hasTypeDef(typeName) { 260 g.typedefs = append(g.typedefs, newType(typeName, fields, schema.String())) 261 } 262 return typeName 263 } 264 265 func (g *Generator) hasTypeDef(name string) bool { 266 for _, def := range g.typedefs { 267 if def.Name != name { 268 continue 269 } 270 return true 271 } 272 return false 273 } 274 275 func (g *Generator) resolveRefSchema(s *avro.RefSchema) string { 276 if sx, ok := s.Schema().(*avro.RecordSchema); ok { 277 return g.resolveTypeName(sx) 278 } 279 return g.generate(s.Schema()) 280 } 281 282 func (g *Generator) resolveUnionTypes(s *avro.UnionSchema) string { 283 types := make([]string, 0, len(s.Types())) 284 for _, elem := range s.Types() { 285 if _, ok := elem.(*avro.NullSchema); ok { 286 continue 287 } 288 types = append(types, g.generate(elem)) 289 } 290 if s.Nullable() { 291 return "*" + types[0] 292 } 293 return "any" 294 } 295 296 func (g *Generator) resolveLogicalSchema(logicalType avro.LogicalType) string { 297 var typ string 298 switch logicalType { 299 case "date", "timestamp-millis", "timestamp-micros": 300 typ = "time.Time" 301 case "time-millis", "time-micros": 302 typ = "time.Duration" 303 case "decimal": 304 typ = "*big.Rat" 305 case "duration": 306 typ = "avro.LogicalDuration" 307 case "uuid": 308 typ = "string" 309 } 310 if strings.Contains(typ, "time") { 311 g.addImport("time") 312 } 313 if strings.Contains(typ, "big") { 314 g.addImport("math/big") 315 } 316 if strings.Contains(typ, "avro") { 317 g.addThirdPartyImport("github.com/hamba/avro/v2") 318 } 319 return typ 320 } 321 322 func (g *Generator) newField(name, typ, avroFieldDoc, avroFieldName string) field { 323 return field{ 324 Name: name, 325 Type: typ, 326 AvroFieldName: avroFieldName, 327 AvroFieldDoc: avroFieldDoc, 328 Tags: g.tags, 329 } 330 } 331 332 func (g *Generator) addImport(pkg string) { 333 for _, p := range g.imports { 334 if p == pkg { 335 return 336 } 337 } 338 g.imports = append(g.imports, pkg) 339 } 340 341 func (g *Generator) addThirdPartyImport(pkg string) { 342 for _, p := range g.thirdPartyImports { 343 if p == pkg { 344 return 345 } 346 } 347 g.thirdPartyImports = append(g.thirdPartyImports, pkg) 348 } 349 350 // Write writes Go code from the parsed schemas. 351 func (g *Generator) Write(w io.Writer) error { 352 parsed, err := template.New("out"). 353 Funcs(template.FuncMap{ 354 "kebab": strcase.ToKebab, 355 "upperCamel": strcase.ToPascal, 356 "camel": strcase.ToCamel, 357 "snake": strcase.ToSnake, 358 }). 359 Parse(g.template) 360 if err != nil { 361 return err 362 } 363 364 data := struct { 365 WithEncoders bool 366 PackageName string 367 Imports []string 368 ThirdPartyImports []string 369 Typedefs []typedef 370 }{ 371 WithEncoders: g.encoders, 372 PackageName: g.pkg, 373 Imports: append(g.imports, g.thirdPartyImports...), 374 Typedefs: g.typedefs, 375 } 376 return parsed.Execute(w, data) 377 } 378 379 type typedef struct { 380 Name string 381 Fields []field 382 Schema string 383 } 384 385 func newType(name string, fields []field, schema string) typedef { 386 return typedef{ 387 Name: name, 388 Fields: fields, 389 Schema: schema, 390 } 391 } 392 393 type field struct { 394 Name string 395 Type string 396 AvroFieldName string 397 AvroFieldDoc string 398 Tags map[string]TagStyle 399 }