github.com/aacfactory/avro@v1.2.12/internal/base/schema_parse.go (about)

     1  package base
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  
    11  	jsoniter "github.com/json-iterator/go"
    12  	"github.com/mitchellh/mapstructure"
    13  )
    14  
    15  // DefaultSchemaCache is the default cache for schemas.
    16  var DefaultSchemaCache = &SchemaCache{}
    17  
    18  // Parse parses a schema string.
    19  func Parse(schema string) (Schema, error) {
    20  	return ParseBytes([]byte(schema))
    21  }
    22  
    23  // ParseWithCache parses a schema string using the given namespace and schema cache.
    24  func ParseWithCache(schema, namespace string, cache *SchemaCache) (Schema, error) {
    25  	return ParseBytesWithCache([]byte(schema), namespace, cache)
    26  }
    27  
    28  // MustParse parses a schema string, panicing if there is an error.
    29  func MustParse(schema string) Schema {
    30  	parsed, err := Parse(schema)
    31  	if err != nil {
    32  		panic(err)
    33  	}
    34  
    35  	return parsed
    36  }
    37  
    38  // ParseFiles parses the schemas in the files, in the order they appear, returning the last schema.
    39  //
    40  // This is useful when your schemas rely on other schemas.
    41  func ParseFiles(paths ...string) (Schema, error) {
    42  	var schema Schema
    43  	for _, path := range paths {
    44  		s, err := os.ReadFile(filepath.Clean(path))
    45  		if err != nil {
    46  			return nil, err
    47  		}
    48  
    49  		schema, err = Parse(string(s))
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  	}
    54  
    55  	return schema, nil
    56  }
    57  
    58  // ParseBytes parses a schema byte slice.
    59  func ParseBytes(schema []byte) (Schema, error) {
    60  	return ParseBytesWithCache(schema, "", DefaultSchemaCache)
    61  }
    62  
    63  // ParseBytesWithCache parses a schema byte slice using the given namespace and schema cache.
    64  func ParseBytesWithCache(schema []byte, namespace string, cache *SchemaCache) (Schema, error) {
    65  	var json any
    66  	if err := jsoniter.Unmarshal(schema, &json); err != nil {
    67  		json = string(schema)
    68  	}
    69  
    70  	seen := seenCache{}
    71  	s, err := parseType(namespace, json, seen, cache)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	return derefSchema(s), nil
    76  }
    77  
    78  func parseType(namespace string, v any, seen seenCache, cache *SchemaCache) (Schema, error) {
    79  	switch val := v.(type) {
    80  	case nil:
    81  		return &NullSchema{}, nil
    82  
    83  	case string:
    84  		return parsePrimitiveType(namespace, val, cache)
    85  
    86  	case map[string]any:
    87  		return parseComplexType(namespace, val, seen, cache)
    88  
    89  	case []any:
    90  		return parseUnion(namespace, val, seen, cache)
    91  	}
    92  
    93  	return nil, fmt.Errorf("avro: unknown type: %v", v)
    94  }
    95  
    96  func parsePrimitiveType(namespace, s string, cache *SchemaCache) (Schema, error) {
    97  	typ := Type(s)
    98  	switch typ {
    99  	case Null:
   100  		return &NullSchema{}, nil
   101  
   102  	case String, Bytes, Int, Long, Float, Double, Boolean:
   103  		return parsePrimitive(typ, nil)
   104  
   105  	default:
   106  		schema := cache.Get(fullName(namespace, s))
   107  		if schema != nil {
   108  			return schema, nil
   109  		}
   110  
   111  		return nil, fmt.Errorf("avro: unknown type: %s", s)
   112  	}
   113  }
   114  
   115  func parseComplexType(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   116  	if val, ok := m["type"].([]any); ok {
   117  		return parseUnion(namespace, val, seen, cache)
   118  	}
   119  
   120  	str, ok := m["type"].(string)
   121  	if !ok {
   122  		return nil, fmt.Errorf("avro: unknown type: %+v", m)
   123  	}
   124  	typ := Type(str)
   125  
   126  	switch typ {
   127  	case Null:
   128  		return &NullSchema{}, nil
   129  
   130  	case String, Bytes, Int, Long, Float, Double, Boolean:
   131  		return parsePrimitive(typ, m)
   132  
   133  	case Record, Error:
   134  		return parseRecord(typ, namespace, m, seen, cache)
   135  
   136  	case Enum:
   137  		return parseEnum(namespace, m, seen, cache)
   138  
   139  	case Array:
   140  		return parseArray(namespace, m, seen, cache)
   141  
   142  	case Map:
   143  		return parseMap(namespace, m, seen, cache)
   144  
   145  	case Fixed:
   146  		return parseFixed(namespace, m, seen, cache)
   147  
   148  	default:
   149  		return parseType(namespace, string(typ), seen, cache)
   150  	}
   151  }
   152  
   153  type primitiveSchema struct {
   154  	LogicalType string         `mapstructure:"logicalType"`
   155  	Precision   int            `mapstructure:"precision"`
   156  	Scale       int            `mapstructure:"scale"`
   157  	Props       map[string]any `mapstructure:",remain"`
   158  }
   159  
   160  func parsePrimitive(typ Type, m map[string]any) (Schema, error) {
   161  	if m == nil {
   162  		return NewPrimitiveSchema(typ, nil), nil
   163  	}
   164  
   165  	var (
   166  		p    primitiveSchema
   167  		meta mapstructure.Metadata
   168  	)
   169  	if err := decodeMap(m, &p, &meta); err != nil {
   170  		return nil, fmt.Errorf("avro: error decoding primitive: %w", err)
   171  	}
   172  
   173  	var logical LogicalSchema
   174  	if p.LogicalType != "" {
   175  		logical = parsePrimitiveLogicalType(typ, p.LogicalType, p.Precision, p.Scale)
   176  	}
   177  
   178  	return NewPrimitiveSchema(typ, logical, WithProps(p.Props)), nil
   179  }
   180  
   181  func parsePrimitiveLogicalType(typ Type, lt string, prec, scale int) LogicalSchema {
   182  	ltyp := LogicalType(lt)
   183  	if (typ == String && ltyp == UUID) ||
   184  		(typ == Int && ltyp == Date) ||
   185  		(typ == Int && ltyp == TimeMillis) ||
   186  		(typ == Long && ltyp == TimeMicros) ||
   187  		(typ == Long && ltyp == TimestampMillis) ||
   188  		(typ == Long && ltyp == TimestampMicros) {
   189  		return NewPrimitiveLogicalSchema(ltyp)
   190  	}
   191  
   192  	if typ == Bytes && ltyp == Decimal {
   193  		return parseDecimalLogicalType(-1, prec, scale)
   194  	}
   195  
   196  	return nil
   197  }
   198  
   199  type recordSchema struct {
   200  	Type      string           `mapstructure:"type"`
   201  	Name      string           `mapstructure:"name"`
   202  	Namespace string           `mapstructure:"namespace"`
   203  	Aliases   []string         `mapstructure:"aliases"`
   204  	Doc       string           `mapstructure:"doc"`
   205  	Fields    []map[string]any `mapstructure:"fields"`
   206  	Props     map[string]any   `mapstructure:",remain"`
   207  }
   208  
   209  func parseRecord(typ Type, namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   210  	var (
   211  		r    recordSchema
   212  		meta mapstructure.Metadata
   213  	)
   214  	if err := decodeMap(m, &r, &meta); err != nil {
   215  		return nil, fmt.Errorf("avro: error decoding record: %w", err)
   216  	}
   217  
   218  	if err := checkParsedName(r.Name, r.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   219  		return nil, err
   220  	}
   221  	if r.Namespace == "" {
   222  		r.Namespace = namespace
   223  	}
   224  
   225  	if !hasKey(meta.Keys, "fields") {
   226  		return nil, errors.New("avro: record must have an array of fields")
   227  	}
   228  	fields := make([]*Field, len(r.Fields))
   229  
   230  	var (
   231  		rec *RecordSchema
   232  		err error
   233  	)
   234  	switch typ {
   235  	case Record:
   236  		rec, err = NewRecordSchema(r.Name, r.Namespace, fields,
   237  			WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props),
   238  		)
   239  	case Error:
   240  		rec, err = NewErrorRecordSchema(r.Name, r.Namespace, fields,
   241  			WithAliases(r.Aliases), WithDoc(r.Doc), WithProps(r.Props),
   242  		)
   243  	}
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  
   248  	if err = seen.Add(rec.FullName()); err != nil {
   249  		return nil, err
   250  	}
   251  
   252  	ref := NewRefSchema(rec)
   253  	cache.Add(rec.FullName(), ref)
   254  	for _, alias := range rec.Aliases() {
   255  		cache.Add(alias, ref)
   256  	}
   257  
   258  	for i, f := range r.Fields {
   259  		field, err := parseField(rec.namespace, f, seen, cache)
   260  		if err != nil {
   261  			return nil, err
   262  		}
   263  		fields[i] = field
   264  	}
   265  
   266  	return rec, nil
   267  }
   268  
   269  type fieldSchema struct {
   270  	Name    string         `mapstructure:"name"`
   271  	Aliases []string       `mapstructure:"aliases"`
   272  	Type    any            `mapstructure:"type"`
   273  	Doc     string         `mapstructure:"doc"`
   274  	Default any            `mapstructure:"default"`
   275  	Order   Order          `mapstructure:"order"`
   276  	Props   map[string]any `mapstructure:",remain"`
   277  }
   278  
   279  func parseField(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Field, error) {
   280  	var (
   281  		f    fieldSchema
   282  		meta mapstructure.Metadata
   283  	)
   284  	if err := decodeMap(m, &f, &meta); err != nil {
   285  		return nil, fmt.Errorf("avro: error decoding field: %w", err)
   286  	}
   287  
   288  	if err := checkParsedName(f.Name, "", false); err != nil {
   289  		return nil, err
   290  	}
   291  
   292  	if !hasKey(meta.Keys, "type") {
   293  		return nil, errors.New("avro: field requires a type")
   294  	}
   295  	typ, err := parseType(namespace, f.Type, seen, cache)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	if !hasKey(meta.Keys, "default") {
   301  		f.Default = NoDefault
   302  	}
   303  
   304  	field, err := NewField(f.Name, typ,
   305  		WithDefault(f.Default), WithAliases(f.Aliases), WithDoc(f.Doc), WithOrder(f.Order), WithProps(f.Props),
   306  	)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  
   311  	return field, nil
   312  }
   313  
   314  type enumSchema struct {
   315  	Name      string         `mapstructure:"name"`
   316  	Namespace string         `mapstructure:"namespace"`
   317  	Aliases   []string       `mapstructure:"aliases"`
   318  	Type      string         `mapstructure:"type"`
   319  	Doc       string         `mapstructure:"doc"`
   320  	Symbols   []string       `mapstructure:"symbols"`
   321  	Default   string         `mapstructure:"default"`
   322  	Props     map[string]any `mapstructure:",remain"`
   323  }
   324  
   325  func parseEnum(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   326  	var (
   327  		e    enumSchema
   328  		meta mapstructure.Metadata
   329  	)
   330  	if err := decodeMap(m, &e, &meta); err != nil {
   331  		return nil, fmt.Errorf("avro: error decoding enum: %w", err)
   332  	}
   333  
   334  	if err := checkParsedName(e.Name, e.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   335  		return nil, err
   336  	}
   337  	if e.Namespace == "" {
   338  		e.Namespace = namespace
   339  	}
   340  
   341  	enum, err := NewEnumSchema(e.Name, e.Namespace, e.Symbols,
   342  		WithDefault(e.Default), WithAliases(e.Aliases), WithDoc(e.Doc), WithProps(e.Props),
   343  	)
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	if err = seen.Add(enum.FullName()); err != nil {
   349  		return nil, err
   350  	}
   351  
   352  	ref := NewRefSchema(enum)
   353  	cache.Add(enum.FullName(), ref)
   354  	for _, alias := range enum.Aliases() {
   355  		cache.Add(alias, enum)
   356  	}
   357  
   358  	return enum, nil
   359  }
   360  
   361  type arraySchema struct {
   362  	Items any            `mapstructure:"items"`
   363  	Props map[string]any `mapstructure:",remain"`
   364  }
   365  
   366  func parseArray(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   367  	var (
   368  		a    arraySchema
   369  		meta mapstructure.Metadata
   370  	)
   371  	if err := decodeMap(m, &a, &meta); err != nil {
   372  		return nil, fmt.Errorf("avro: error decoding array: %w", err)
   373  	}
   374  
   375  	if !hasKey(meta.Keys, "items") {
   376  		return nil, errors.New("avro: array must have an items key")
   377  	}
   378  	schema, err := parseType(namespace, a.Items, seen, cache)
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  
   383  	return NewArraySchema(schema, WithProps(a.Props)), nil
   384  }
   385  
   386  type mapSchema struct {
   387  	Values any            `mapstructure:"values"`
   388  	Props  map[string]any `mapstructure:",remain"`
   389  }
   390  
   391  func parseMap(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   392  	var (
   393  		ms   mapSchema
   394  		meta mapstructure.Metadata
   395  	)
   396  	if err := decodeMap(m, &ms, &meta); err != nil {
   397  		return nil, fmt.Errorf("avro: error decoding map: %w", err)
   398  	}
   399  
   400  	if !hasKey(meta.Keys, "values") {
   401  		return nil, errors.New("avro: map must have an values key")
   402  	}
   403  	schema, err := parseType(namespace, ms.Values, seen, cache)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  
   408  	return NewMapSchema(schema, WithProps(ms.Props)), nil
   409  }
   410  
   411  func parseUnion(namespace string, v []any, seen seenCache, cache *SchemaCache) (Schema, error) {
   412  	var err error
   413  	types := make([]Schema, len(v))
   414  	for i := range v {
   415  		types[i], err = parseType(namespace, v[i], seen, cache)
   416  		if err != nil {
   417  			return nil, err
   418  		}
   419  	}
   420  
   421  	return NewUnionSchema(types)
   422  }
   423  
   424  type fixedSchema struct {
   425  	Name        string         `mapstructure:"name"`
   426  	Namespace   string         `mapstructure:"namespace"`
   427  	Aliases     []string       `mapstructure:"aliases"`
   428  	Type        string         `mapstructure:"type"`
   429  	Size        int            `mapstructure:"size"`
   430  	LogicalType string         `mapstructure:"logicalType"`
   431  	Precision   int            `mapstructure:"precision"`
   432  	Scale       int            `mapstructure:"scale"`
   433  	Props       map[string]any `mapstructure:",remain"`
   434  }
   435  
   436  func parseFixed(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (Schema, error) {
   437  	var (
   438  		f    fixedSchema
   439  		meta mapstructure.Metadata
   440  	)
   441  	if err := decodeMap(m, &f, &meta); err != nil {
   442  		return nil, fmt.Errorf("avro: error decoding fixed: %w", err)
   443  	}
   444  
   445  	if err := checkParsedName(f.Name, f.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   446  		return nil, err
   447  	}
   448  	if f.Namespace == "" {
   449  		f.Namespace = namespace
   450  	}
   451  
   452  	if !hasKey(meta.Keys, "size") {
   453  		return nil, errors.New("avro: fixed must have a size")
   454  	}
   455  
   456  	var logical LogicalSchema
   457  	if f.LogicalType != "" {
   458  		logical = parseFixedLogicalType(f.Size, f.LogicalType, f.Precision, f.Scale)
   459  	}
   460  
   461  	fixed, err := NewFixedSchema(f.Name, f.Namespace, f.Size, logical, WithAliases(f.Aliases), WithProps(f.Props))
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  
   466  	if err = seen.Add(fixed.FullName()); err != nil {
   467  		return nil, err
   468  	}
   469  
   470  	ref := NewRefSchema(fixed)
   471  	cache.Add(fixed.FullName(), ref)
   472  	for _, alias := range fixed.Aliases() {
   473  		cache.Add(alias, fixed)
   474  	}
   475  
   476  	return fixed, nil
   477  }
   478  
   479  func parseFixedLogicalType(size int, lt string, prec, scale int) LogicalSchema {
   480  	ltyp := LogicalType(lt)
   481  	switch {
   482  	case ltyp == Duration && size == 12:
   483  		return NewPrimitiveLogicalSchema(Duration)
   484  	case ltyp == Decimal:
   485  		return parseDecimalLogicalType(size, prec, scale)
   486  	}
   487  
   488  	return nil
   489  }
   490  
   491  func parseDecimalLogicalType(size, prec, scale int) LogicalSchema {
   492  	if prec <= 0 {
   493  		return nil
   494  	}
   495  
   496  	if size > 0 {
   497  		maxPrecision := int(math.Round(math.Floor(math.Log10(2) * (8*float64(size) - 1))))
   498  		if prec > maxPrecision {
   499  			return nil
   500  		}
   501  	}
   502  
   503  	if scale < 0 {
   504  		return nil
   505  	}
   506  
   507  	// Scale may not be bigger than precision
   508  	if scale > prec {
   509  		return nil
   510  	}
   511  
   512  	return NewDecimalLogicalSchema(prec, scale)
   513  }
   514  
   515  func fullName(namespace, name string) string {
   516  	if len(namespace) == 0 || strings.ContainsRune(name, '.') {
   517  		return name
   518  	}
   519  
   520  	return namespace + "." + name
   521  }
   522  
   523  func checkParsedName(name, ns string, hasNS bool) error {
   524  	if name == "" {
   525  		return errors.New("avro: non-empty name key required")
   526  	}
   527  	if hasNS && ns == "" {
   528  		return errors.New("avro: namespace key must be non-empty or omitted")
   529  	}
   530  	return nil
   531  }
   532  
   533  func hasKey(keys []string, k string) bool {
   534  	for _, key := range keys {
   535  		if key == k {
   536  			return true
   537  		}
   538  	}
   539  	return false
   540  }
   541  
   542  func decodeMap(in, v any, meta *mapstructure.Metadata) error {
   543  	cfg := &mapstructure.DecoderConfig{
   544  		ZeroFields: true,
   545  		Metadata:   meta,
   546  		Result:     v,
   547  	}
   548  
   549  	decoder, _ := mapstructure.NewDecoder(cfg)
   550  	return decoder.Decode(in)
   551  }
   552  
   553  func derefSchema(schema Schema) Schema {
   554  	seen := map[string]struct{}{}
   555  
   556  	return walkSchema(schema, func(schema Schema) Schema {
   557  		if ns, ok := schema.(NamedSchema); ok {
   558  			seen[ns.FullName()] = struct{}{}
   559  			return schema
   560  		}
   561  
   562  		ref, isRef := schema.(*RefSchema)
   563  		if !isRef {
   564  			return schema
   565  		}
   566  
   567  		if _, haveSeen := seen[ref.Schema().FullName()]; !haveSeen {
   568  			seen[ref.Schema().FullName()] = struct{}{}
   569  			return ref.Schema()
   570  		}
   571  		return schema
   572  	})
   573  }
   574  
   575  type seenCache map[string]struct{}
   576  
   577  func (c seenCache) Add(name string) error {
   578  	if _, ok := c[name]; ok {
   579  		return fmt.Errorf("duplicate name %q", name)
   580  	}
   581  	c[name] = struct{}{}
   582  	return nil
   583  }