github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/schema_parse.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  
    11  	jsoniter "github.com/json-iterator/go"
    12  	"github.com/mitchellh/mapstructure"
    13  )
    14  
    15  // DefaultSchemaCache is the default cache for schemas.
    16  var DefaultSchemaCache = &SchemaCache{}
    17  
    18  // Parse parses a schema string.
    19  func Parse(schema string) (Schema, error) {
    20  	return ParseBytes([]byte(schema))
    21  }
    22  
    23  // ParseWithCache parses a schema string using the given namespace and schema cache.
    24  func ParseWithCache(schema, namespace string, cache *SchemaCache) (Schema, error) {
    25  	return ParseBytesWithCache([]byte(schema), namespace, cache)
    26  }
    27  
    28  // MustParse parses a schema string, panicing if there is an error.
    29  func MustParse(schema string) Schema {
    30  	parsed, err := Parse(schema)
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  
    35  	return parsed
    36  }
    37  
    38  // ParseFiles parses the schemas in the files, in the order they appear, returning the last schema.
    39  //
    40  // This is useful when your schemas rely on other schemas.
    41  func ParseFiles(paths ...string) (Schema, error) {
    42  	var schema Schema
    43  	for _, path := range paths {
    44  		s, err := os.ReadFile(filepath.Clean(path))
    45  		if err != nil {
    46  			return nil, err
    47  		}
    48  
    49  		schema, err = Parse(string(s))
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  	}
    54  
    55  	return schema, nil
    56  }
    57  
    58  // ParseBytes parses a schema byte slice.
    59  func ParseBytes(schema []byte) (Schema, error) {
    60  	return ParseBytesWithCache(schema, "", DefaultSchemaCache)
    61  }
    62  
    63  // ParseBytesWithCache parses a schema byte slice using the given namespace and schema cache.
    64  func ParseBytesWithCache(schema []byte, namespace string, cache *SchemaCache) (Schema, error) {
    65  	var json any
    66  	if err := jsoniter.Unmarshal(schema, &json); err != nil {
    67  		json = string(schema)
    68  	}
    69  
    70  	seen := seenCache{}
    71  	s, err := parseType(namespace, json, seen, cache)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	return derefSchema(s), nil
    76  }
    77  
    78  func parseType(namespace string, v any, seen seenCache, cache *SchemaCache) (Schema, error) {
    79  	switch val := v.(type) {
    80  	case nil:
    81  		return &NullSchema{}, nil
    82  
    83  	case string:
    84  		return parsePrimitiveType(namespace, val, cache)
    85  
    86  	case map[string]any:
    87  		return parseComplexType(namespace, val, seen, cache)
    88  
    89  	case []any:
    90  		return parseUnion(namespace, val, seen, cache)
    91  	}
    92  
    93  	return nil, fmt.Errorf("avro: unknown type: %v", v)
    94  }
    95  
    96  func parsePrimitiveType(namespace, s string, cache *SchemaCache) (Schema, error) {
    97  	typ := Type(s)
    98  	switch typ {
    99  	case Null:
   100  		return &NullSchema{}, nil
   101  
   102  	case String, Bytes, Int, Long, Float, Double, Boolean:
   103  		return parsePrimitive(typ, nil)
   104  
   105  	default:
   106  		schema := cache.Get(fullName(namespace, s))
   107  		if schema != nil {
   108  			return schema, nil
   109  		}
   110  
   111  		return nil, fmt.Errorf("avro: unknown type: %s", s)
   112  	}
   113  }
   114  
   115  func parseComplexType(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   116  	if val, ok := m["type"].([]any); ok {
   117  		return parseUnion(namespace, val, seen, cache)
   118  	}
   119  
   120  	str, ok := m["type"].(string)
   121  	if !ok {
   122  		return nil, fmt.Errorf("avro: unknown type: %+v", m)
   123  	}
   124  	typ := Type(str)
   125  
   126  	switch typ {
   127  	case Null:
   128  		return &NullSchema{}, nil
   129  
   130  	case String, Bytes, Int, Long, Float, Double, Boolean:
   131  		return parsePrimitive(typ, m)
   132  
   133  	case Record, Error:
   134  		return parseRecord(typ, namespace, m, seen, cache)
   135  
   136  	case Enum:
   137  		return parseEnum(namespace, m, seen, cache)
   138  
   139  	case Array:
   140  		return parseArray(namespace, m, seen, cache)
   141  
   142  	case Map:
   143  		return parseMap(namespace, m, seen, cache)
   144  
   145  	case Fixed:
   146  		return parseFixed(namespace, m, seen, cache)
   147  
   148  	default:
   149  		return parseType(namespace, string(typ), seen, cache)
   150  	}
   151  }
   152  
// primitiveSchema holds the attributes of a primitive type written in object
// form, e.g. {"type": "bytes", "logicalType": "decimal", "precision": 4}.
// Precision and Scale are only passed on for the decimal logical type.
// Keys without a dedicated field (which includes "type", as there is no Type
// field here) are collected in Props by the ",remain" tag.
// NOTE(review): presumably WithProps drops reserved keys such as "type" — confirm.
type primitiveSchema struct {
	LogicalType string         `mapstructure:"logicalType"`
	Precision   int            `mapstructure:"precision"`
	Scale       int            `mapstructure:"scale"`
	Props       map[string]any `mapstructure:",remain"`
}
   159  
   160  func parsePrimitive(typ Type, m map[string]any) (Schema, error) {
   161  	if m == nil {
   162  		return NewPrimitiveSchema(typ, nil), nil
   163  	}
   164  
   165  	var (
   166  		p    primitiveSchema
   167  		meta mapstructure.Metadata
   168  	)
   169  	if err := decodeMap(m, &p, &meta); err != nil {
   170  		return nil, fmt.Errorf("avro: error decoding primitive: %w", err)
   171  	}
   172  
   173  	var logical LogicalSchema
   174  	if p.LogicalType != "" {
   175  		logical = parsePrimitiveLogicalType(typ, p.LogicalType, p.Precision, p.Scale)
   176  	}
   177  
   178  	return NewPrimitiveSchema(typ, logical, WithProps(p.Props)), nil
   179  }
   180  
   181  func parsePrimitiveLogicalType(typ Type, lt string, prec, scale int) LogicalSchema {
   182  	ltyp := LogicalType(lt)
   183  	if (typ == String && ltyp == UUID) ||
   184  		(typ == Int && ltyp == Date) ||
   185  		(typ == Int && ltyp == TimeMillis) ||
   186  		(typ == Long && ltyp == TimeMicros) ||
   187  		(typ == Long && ltyp == TimestampMillis) ||
   188  		(typ == Long && ltyp == TimestampMicros) ||
   189  		(typ == Long && ltyp == LocalTimestampMillis) ||
   190  		(typ == Long && ltyp == LocalTimestampMicros) {
   191  		return NewPrimitiveLogicalSchema(ltyp)
   192  	}
   193  
   194  	if typ == Bytes && ltyp == Decimal {
   195  		return parseDecimalLogicalType(-1, prec, scale)
   196  	}
   197  
   198  	return nil
   199  }
   200  
// recordSchema holds the keys of a record (or error) schema object.
// Fields is kept as raw maps; parseRecord decodes each entry with parseField
// after the record itself has been registered. Keys without a dedicated field
// are collected in Props by the ",remain" tag.
type recordSchema struct {
	Type      string           `mapstructure:"type"`
	Name      string           `mapstructure:"name"`
	Namespace string           `mapstructure:"namespace"`
	Aliases   []string         `mapstructure:"aliases"`
	Doc       string           `mapstructure:"doc"`
	Fields    []map[string]any `mapstructure:"fields"`
	Props     map[string]any   `mapstructure:",remain"`
}
   210  
// parseRecord parses a record or error schema (selected by typ) from its JSON
// object form m. namespace is the enclosing namespace and is used when m has no
// "namespace" key of its own.
//
// The record is recorded in seen and registered in cache before its fields are
// parsed. It is registered under its full name and each alias, all wrapped in
// the same RefSchema, so field types may refer back to the record by name.
func parseRecord(typ Type, namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
	var (
		r    recordSchema
		meta mapstructure.Metadata
	)
	if err := decodeMap(m, &r, &meta); err != nil {
		return nil, fmt.Errorf("avro: error decoding record: %w", err)
	}

	// An explicit but empty "namespace" is rejected; an absent one inherits.
	if err := checkParsedName(r.Name, r.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
		return nil, err
	}
	if r.Namespace == "" {
		r.Namespace = namespace
	}

	if !hasKey(meta.Keys, "fields") {
		return nil, errors.New("avro: record must have an array of fields")
	}
	// The slice is handed to the record constructor while its entries are still
	// nil, and the entries are filled in below, after the record is in the cache.
	// NOTE(review): this relies on NewRecordSchema/NewErrorRecordSchema keeping
	// this slice rather than copying it — confirm before changing.
	fields := make([]*Field, len(r.Fields))

	var (
		rec *RecordSchema
		err error
	)
	// typ is expected to be Record or Error; any other value would leave rec nil.
	switch typ {
	case Record:
		rec, err = NewRecordSchema(r.Name, r.Namespace, fields,
			WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props),
		)
	case Error:
		rec, err = NewErrorRecordSchema(r.Name, r.Namespace, fields,
			WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props),
		)
	}
	if err != nil {
		return nil, err
	}

	if err = seen.Add(rec.FullName()); err != nil {
		return nil, err
	}

	// Register before parsing fields so recursive references resolve.
	ref := NewRefSchema(rec)
	cache.Add(rec.FullName(), ref)
	for _, alias := range rec.Aliases() {
		cache.Add(alias, ref)
	}

	// Fields are parsed in the record's own namespace (rec.namespace).
	for i, f := range r.Fields {
		field, err := parseField(rec.namespace, f, seen, cache)
		if err != nil {
			return nil, err
		}
		fields[i] = field
	}

	return rec, nil
}
   270  
// fieldSchema holds the keys of one entry in a record's "fields" array.
// Type is kept as the raw decoded JSON value and resolved later by parseType.
// A nil Default cannot tell an absent "default" key from an explicit null, so
// parseField checks the decode metadata for that. Keys without a dedicated
// field are collected in Props by the ",remain" tag.
type fieldSchema struct {
	Name    string         `mapstructure:"name"`
	Aliases []string       `mapstructure:"aliases"`
	Type    any            `mapstructure:"type"`
	Doc     string         `mapstructure:"doc"`
	Default any            `mapstructure:"default"`
	Order   Order          `mapstructure:"order"`
	Props   map[string]any `mapstructure:",remain"`
}
   280  
   281  func parseField(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Field, error) {
   282  	var (
   283  		f    fieldSchema
   284  		meta mapstructure.Metadata
   285  	)
   286  	if err := decodeMap(m, &f, &meta); err != nil {
   287  		return nil, fmt.Errorf("avro: error decoding field: %w", err)
   288  	}
   289  
   290  	if err := checkParsedName(f.Name, "", false); err != nil {
   291  		return nil, err
   292  	}
   293  
   294  	if !hasKey(meta.Keys, "type") {
   295  		return nil, errors.New("avro: field requires a type")
   296  	}
   297  	typ, err := parseType(namespace, f.Type, seen, cache)
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  
   302  	if !hasKey(meta.Keys, "default") {
   303  		f.Default = NoDefault
   304  	}
   305  
   306  	field, err := NewField(f.Name, typ,
   307  		WithDefault(f.Default), WithAliases(f.Aliases), WithDoc(f.Doc), WithOrder(f.Order), WithProps(f.Props),
   308  	)
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  
   313  	return field, nil
   314  }
   315  
// enumSchema holds the keys of an enum schema object. Default is the symbol
// passed to NewEnumSchema via WithDefault; an absent "default" decodes as "".
// Keys without a dedicated field are collected in Props by the ",remain" tag.
type enumSchema struct {
	Name      string         `mapstructure:"name"`
	Namespace string         `mapstructure:"namespace"`
	Aliases   []string       `mapstructure:"aliases"`
	Type      string         `mapstructure:"type"`
	Doc       string         `mapstructure:"doc"`
	Symbols   []string       `mapstructure:"symbols"`
	Default   string         `mapstructure:"default"`
	Props     map[string]any `mapstructure:",remain"`
}
   326  
   327  func parseEnum(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   328  	var (
   329  		e    enumSchema
   330  		meta mapstructure.Metadata
   331  	)
   332  	if err := decodeMap(m, &e, &meta); err != nil {
   333  		return nil, fmt.Errorf("avro: error decoding enum: %w", err)
   334  	}
   335  
   336  	if err := checkParsedName(e.Name, e.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   337  		return nil, err
   338  	}
   339  	if e.Namespace == "" {
   340  		e.Namespace = namespace
   341  	}
   342  
   343  	enum, err := NewEnumSchema(e.Name, e.Namespace, e.Symbols,
   344  		WithDefault(e.Default), WithAliases(e.Aliases), WithDoc(e.Doc), WithProps(e.Props),
   345  	)
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	if err = seen.Add(enum.FullName()); err != nil {
   351  		return nil, err
   352  	}
   353  
   354  	ref := NewRefSchema(enum)
   355  	cache.Add(enum.FullName(), ref)
   356  	for _, alias := range enum.Aliases() {
   357  		cache.Add(alias, enum)
   358  	}
   359  
   360  	return enum, nil
   361  }
   362  
// arraySchema holds the keys of an array schema object. Items is kept as the
// raw decoded JSON value and resolved by parseType. Keys without a dedicated
// field (which includes "type") are collected in Props by the ",remain" tag.
type arraySchema struct {
	Items any            `mapstructure:"items"`
	Props map[string]any `mapstructure:",remain"`
}
   367  
   368  func parseArray(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   369  	var (
   370  		a    arraySchema
   371  		meta mapstructure.Metadata
   372  	)
   373  	if err := decodeMap(m, &a, &meta); err != nil {
   374  		return nil, fmt.Errorf("avro: error decoding array: %w", err)
   375  	}
   376  
   377  	if !hasKey(meta.Keys, "items") {
   378  		return nil, errors.New("avro: array must have an items key")
   379  	}
   380  	schema, err := parseType(namespace, a.Items, seen, cache)
   381  	if err != nil {
   382  		return nil, err
   383  	}
   384  
   385  	return NewArraySchema(schema, WithProps(a.Props)), nil
   386  }
   387  
// mapSchema holds the keys of a map schema object. Values is kept as the raw
// decoded JSON value and resolved by parseType. Keys without a dedicated field
// (which includes "type") are collected in Props by the ",remain" tag.
type mapSchema struct {
	Values any            `mapstructure:"values"`
	Props  map[string]any `mapstructure:",remain"`
}
   392  
   393  func parseMap(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   394  	var (
   395  		ms   mapSchema
   396  		meta mapstructure.Metadata
   397  	)
   398  	if err := decodeMap(m, &ms, &meta); err != nil {
   399  		return nil, fmt.Errorf("avro: error decoding map: %w", err)
   400  	}
   401  
   402  	if !hasKey(meta.Keys, "values") {
   403  		return nil, errors.New("avro: map must have an values key")
   404  	}
   405  	schema, err := parseType(namespace, ms.Values, seen, cache)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  
   410  	return NewMapSchema(schema, WithProps(ms.Props)), nil
   411  }
   412  
   413  func parseUnion(namespace string, v []any, seen seenCache, cache *SchemaCache) (Schema, error) {
   414  	var err error
   415  	types := make([]Schema, len(v))
   416  	for i := range v {
   417  		types[i], err = parseType(namespace, v[i], seen, cache)
   418  		if err != nil {
   419  			return nil, err
   420  		}
   421  	}
   422  
   423  	return NewUnionSchema(types)
   424  }
   425  
// fixedSchema holds the keys of a fixed schema object. Size is the number of
// bytes. LogicalType, Precision and Scale are passed to parseFixedLogicalType
// together with Size. Keys without a dedicated field are collected in Props by
// the ",remain" tag.
type fixedSchema struct {
	Name        string         `mapstructure:"name"`
	Namespace   string         `mapstructure:"namespace"`
	Aliases     []string       `mapstructure:"aliases"`
	Type        string         `mapstructure:"type"`
	Size        int            `mapstructure:"size"`
	LogicalType string         `mapstructure:"logicalType"`
	Precision   int            `mapstructure:"precision"`
	Scale       int            `mapstructure:"scale"`
	Props       map[string]any `mapstructure:",remain"`
}
   437  
   438  func parseFixed(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   439  	var (
   440  		f    fixedSchema
   441  		meta mapstructure.Metadata
   442  	)
   443  	if err := decodeMap(m, &f, &meta); err != nil {
   444  		return nil, fmt.Errorf("avro: error decoding fixed: %w", err)
   445  	}
   446  
   447  	if err := checkParsedName(f.Name, f.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   448  		return nil, err
   449  	}
   450  	if f.Namespace == "" {
   451  		f.Namespace = namespace
   452  	}
   453  
   454  	if !hasKey(meta.Keys, "size") {
   455  		return nil, errors.New("avro: fixed must have a size")
   456  	}
   457  
   458  	var logical LogicalSchema
   459  	if f.LogicalType != "" {
   460  		logical = parseFixedLogicalType(f.Size, f.LogicalType, f.Precision, f.Scale)
   461  	}
   462  
   463  	fixed, err := NewFixedSchema(f.Name, f.Namespace, f.Size, logical, WithAliases(f.Aliases), WithProps(f.Props))
   464  	if err != nil {
   465  		return nil, err
   466  	}
   467  
   468  	if err = seen.Add(fixed.FullName()); err != nil {
   469  		return nil, err
   470  	}
   471  
   472  	ref := NewRefSchema(fixed)
   473  	cache.Add(fixed.FullName(), ref)
   474  	for _, alias := range fixed.Aliases() {
   475  		cache.Add(alias, fixed)
   476  	}
   477  
   478  	return fixed, nil
   479  }
   480  
   481  func parseFixedLogicalType(size int, lt string, prec, scale int) LogicalSchema {
   482  	ltyp := LogicalType(lt)
   483  	switch {
   484  	case ltyp == Duration && size == 12:
   485  		return NewPrimitiveLogicalSchema(Duration)
   486  	case ltyp == Decimal:
   487  		return parseDecimalLogicalType(size, prec, scale)
   488  	}
   489  
   490  	return nil
   491  }
   492  
   493  func parseDecimalLogicalType(size, prec, scale int) LogicalSchema {
   494  	if prec <= 0 {
   495  		return nil
   496  	}
   497  
   498  	if size > 0 {
   499  		maxPrecision := int(math.Round(math.Floor(math.Log10(2) * (8*float64(size) - 1))))
   500  		if prec > maxPrecision {
   501  			return nil
   502  		}
   503  	}
   504  
   505  	if scale < 0 {
   506  		return nil
   507  	}
   508  
   509  	// Scale may not be bigger than precision
   510  	if scale > prec {
   511  		return nil
   512  	}
   513  
   514  	return NewDecimalLogicalSchema(prec, scale)
   515  }
   516  
   517  func fullName(namespace, name string) string {
   518  	if len(namespace) == 0 || strings.ContainsRune(name, '.') {
   519  		return name
   520  	}
   521  
   522  	return namespace + "." + name
   523  }
   524  
   525  func checkParsedName(name, ns string, hasNS bool) error {
   526  	if name == "" {
   527  		return errors.New("avro: non-empty name key required")
   528  	}
   529  	if hasNS && ns == "" {
   530  		return errors.New("avro: namespace key must be non-empty or omitted")
   531  	}
   532  	return nil
   533  }
   534  
   535  func hasKey(keys []string, k string) bool {
   536  	for _, key := range keys {
   537  		if key == k {
   538  			return true
   539  		}
   540  	}
   541  	return false
   542  }
   543  
   544  func decodeMap(in, v any, meta *mapstructure.Metadata) error {
   545  	cfg := &mapstructure.DecoderConfig{
   546  		ZeroFields: true,
   547  		Metadata:   meta,
   548  		Result:     v,
   549  	}
   550  
   551  	decoder, _ := mapstructure.NewDecoder(cfg)
   552  	return decoder.Decode(in)
   553  }
   554  
   555  func derefSchema(schema Schema) Schema {
   556  	seen := map[string]struct{}{}
   557  
   558  	return walkSchema(schema, func(schema Schema) Schema {
   559  		if ns, ok := schema.(NamedSchema); ok {
   560  			seen[ns.FullName()] = struct{}{}
   561  			return schema
   562  		}
   563  
   564  		ref, isRef := schema.(*RefSchema)
   565  		if !isRef {
   566  			return schema
   567  		}
   568  
   569  		if _, haveSeen := seen[ref.Schema().FullName()]; !haveSeen {
   570  			seen[ref.Schema().FullName()] = struct{}{}
   571  			return ref.Schema()
   572  		}
   573  		return schema
   574  	})
   575  }
   576  
   577  type seenCache map[string]struct{}
   578  
   579  func (c seenCache) Add(name string) error {
   580  	if _, ok := c[name]; ok {
   581  		return fmt.Errorf("duplicate name %q", name)
   582  	}
   583  	c[name] = struct{}{}
   584  	return nil
   585  }