github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/schema_compatibility.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  )
     8  
     9  type recursionError struct{}
    10  
    11  func (e recursionError) Error() string {
    12  	return ""
    13  }
    14  
    15  type compatKey struct {
    16  	reader [32]byte
    17  	writer [32]byte
    18  }
    19  
    20  // SchemaCompatibility determines the compatibility of schemas.
    21  type SchemaCompatibility struct {
    22  	cache sync.Map // map[compatKey]error
    23  }
    24  
    25  // NewSchemaCompatibility creates a new schema compatibility instance.
    26  func NewSchemaCompatibility() *SchemaCompatibility {
    27  	return &SchemaCompatibility{}
    28  }
    29  
    30  // Compatible determines the compatibility if the reader and writer schemas.
    31  func (c *SchemaCompatibility) Compatible(reader, writer Schema) error {
    32  	return c.compatible(reader, writer)
    33  }
    34  
    35  func (c *SchemaCompatibility) compatible(reader, writer Schema) error {
    36  	key := compatKey{reader: reader.Fingerprint(), writer: writer.Fingerprint()}
    37  	if err, ok := c.cache.Load(key); ok {
    38  		if _, ok := err.(recursionError); ok {
    39  			// Break the recursion here.
    40  			return nil
    41  		}
    42  
    43  		if err == nil {
    44  			return nil
    45  		}
    46  
    47  		return err.(error)
    48  	}
    49  
    50  	c.cache.Store(key, recursionError{})
    51  	err := c.match(reader, writer)
    52  	if err != nil {
    53  		// We dont want to pay the cost of fmt.Errorf every time
    54  		err = errors.New(err.Error())
    55  	}
    56  	c.cache.Store(key, err)
    57  	return err
    58  }
    59  
    60  func (c *SchemaCompatibility) match(reader, writer Schema) error {
    61  	// If the schema is a reference, get the actual schema
    62  	if reader.Type() == Ref {
    63  		reader = reader.(*RefSchema).Schema()
    64  	}
    65  	if writer.Type() == Ref {
    66  		writer = writer.(*RefSchema).Schema()
    67  	}
    68  
    69  	if reader.Type() != writer.Type() {
    70  		if writer.Type() == Union {
    71  			// Reader must be compatible with all types in writer
    72  			for _, schema := range writer.(*UnionSchema).Types() {
    73  				if err := c.compatible(reader, schema); err != nil {
    74  					return err
    75  				}
    76  			}
    77  
    78  			return nil
    79  		}
    80  
    81  		if reader.Type() == Union {
    82  			// Writer must be compatible with at least one reader schema
    83  			var err error
    84  			for _, schema := range reader.(*UnionSchema).Types() {
    85  				err = c.compatible(schema, writer)
    86  				if err == nil {
    87  					return nil
    88  				}
    89  			}
    90  
    91  			return fmt.Errorf("reader union lacking writer schema %s", writer.Type())
    92  		}
    93  
    94  		switch writer.Type() {
    95  		case Int:
    96  			if reader.Type() == Long || reader.Type() == Float || reader.Type() == Double {
    97  				return nil
    98  			}
    99  
   100  		case Long:
   101  			if reader.Type() == Float || reader.Type() == Double {
   102  				return nil
   103  			}
   104  
   105  		case Float:
   106  			if reader.Type() == Double {
   107  				return nil
   108  			}
   109  
   110  		case String:
   111  			if reader.Type() == Bytes {
   112  				return nil
   113  			}
   114  
   115  		case Bytes:
   116  			if reader.Type() == String {
   117  				return nil
   118  			}
   119  		}
   120  
   121  		return fmt.Errorf("reader schema %s not compatible with writer schema %s", reader.Type(), writer.Type())
   122  	}
   123  
   124  	switch reader.Type() {
   125  	case Array:
   126  		return c.compatible(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
   127  
   128  	case Map:
   129  		return c.compatible(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
   130  
   131  	case Fixed:
   132  		r := reader.(*FixedSchema)
   133  		w := writer.(*FixedSchema)
   134  
   135  		if err := c.checkSchemaName(r, w); err != nil {
   136  			return err
   137  		}
   138  
   139  		if err := c.checkFixedSize(r, w); err != nil {
   140  			return err
   141  		}
   142  
   143  	case Enum:
   144  		r := reader.(*EnumSchema)
   145  		w := writer.(*EnumSchema)
   146  
   147  		if err := c.checkSchemaName(r, w); err != nil {
   148  			return err
   149  		}
   150  
   151  		if err := c.checkEnumSymbols(r, w); err != nil {
   152  			if r.HasDefault() {
   153  				return nil
   154  			}
   155  			return err
   156  		}
   157  
   158  	case Record:
   159  		r := reader.(*RecordSchema)
   160  		w := writer.(*RecordSchema)
   161  
   162  		if err := c.checkSchemaName(r, w); err != nil {
   163  			return err
   164  		}
   165  
   166  		if err := c.checkRecordFields(r, w); err != nil {
   167  			return err
   168  		}
   169  
   170  	case Union:
   171  		for _, schema := range writer.(*UnionSchema).Types() {
   172  			if err := c.compatible(reader, schema); err != nil {
   173  				return err
   174  			}
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error {
   182  	if reader.Name() != writer.Name() {
   183  		if c.contains(reader.Aliases(), writer.FullName()) {
   184  			return nil
   185  		}
   186  		return fmt.Errorf("reader schema %s and writer schema %s  names do not match", reader.FullName(), writer.FullName())
   187  	}
   188  
   189  	return nil
   190  }
   191  
   192  func (c *SchemaCompatibility) checkFixedSize(reader, writer *FixedSchema) error {
   193  	if reader.Size() != writer.Size() {
   194  		return fmt.Errorf("%s reader and writer fixed sizes do not match", reader.FullName())
   195  	}
   196  
   197  	return nil
   198  }
   199  
   200  func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error {
   201  	for _, symbol := range writer.Symbols() {
   202  		if !c.contains(reader.Symbols(), symbol) {
   203  			return fmt.Errorf("reader %s is missing symbol %s", reader.FullName(), symbol)
   204  		}
   205  	}
   206  
   207  	return nil
   208  }
   209  
   210  func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error {
   211  	for _, field := range reader.Fields() {
   212  		f, ok := c.getField(writer.Fields(), field, func(gfo *getFieldOptions) {
   213  			gfo.fieldAlias = true
   214  		})
   215  		if !ok {
   216  			if field.HasDefault() {
   217  				continue
   218  			}
   219  
   220  			return fmt.Errorf("reader field %s is missing in writer schema and has no default", field.Name())
   221  		}
   222  
   223  		if err := c.compatible(field.Type(), f.Type()); err != nil {
   224  			return err
   225  		}
   226  	}
   227  
   228  	return nil
   229  }
   230  
   231  func (c *SchemaCompatibility) contains(a []string, s string) bool {
   232  	for _, str := range a {
   233  		if str == s {
   234  			return true
   235  		}
   236  	}
   237  
   238  	return false
   239  }
   240  
   241  type getFieldOptions struct {
   242  	fieldAlias bool
   243  	elemAlias  bool
   244  }
   245  
   246  func (c *SchemaCompatibility) getField(a []*Field, f *Field, optFns ...func(*getFieldOptions)) (*Field, bool) {
   247  	opt := getFieldOptions{}
   248  	for _, fn := range optFns {
   249  		fn(&opt)
   250  	}
   251  	for _, field := range a {
   252  		if field.Name() == f.Name() {
   253  			return field, true
   254  		}
   255  		if opt.fieldAlias {
   256  			if c.contains(f.Aliases(), field.Name()) {
   257  				return field, true
   258  			}
   259  		}
   260  		if opt.elemAlias {
   261  			if c.contains(field.Aliases(), f.Name()) {
   262  				return field, true
   263  			}
   264  		}
   265  	}
   266  
   267  	return nil, false
   268  }
   269  
   270  // Resolve returns a composite schema that allows decoding data written by the writer schema,
   271  // and makes necessary adjustments to support the reader schema.
   272  //
   273  // It fails if the writer and reader schemas are not compatible.
   274  func (c *SchemaCompatibility) Resolve(reader, writer Schema) (Schema, error) {
   275  	if err := c.compatible(reader, writer); err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	schema, _, err := c.resolve(reader, writer)
   280  	return schema, err
   281  }
   282  
   283  // resolve requires the reader's schema to be already compatible with the writer's.
   284  func (c *SchemaCompatibility) resolve(reader, writer Schema) (schema Schema, resolved bool, err error) {
   285  	if reader.Type() == Ref {
   286  		reader = reader.(*RefSchema).Schema()
   287  	}
   288  	if writer.Type() == Ref {
   289  		writer = writer.(*RefSchema).Schema()
   290  	}
   291  
   292  	if writer.Type() != reader.Type() {
   293  		if reader.Type() == Union {
   294  			for _, schema := range reader.(*UnionSchema).Types() {
   295  				// Compatibility is not guaranteed for every Union reader schema.
   296  				// Therefore, we need to check compatibility in every iteration.
   297  				if err := c.compatible(schema, writer); err != nil {
   298  					continue
   299  				}
   300  				sch, _, err := c.resolve(schema, writer)
   301  				if err != nil {
   302  					continue
   303  				}
   304  				return sch, true, nil
   305  			}
   306  
   307  			return nil, false, fmt.Errorf("reader union lacking writer schema %s", writer.Type())
   308  		}
   309  
   310  		if writer.Type() == Union {
   311  			schemas := make([]Schema, 0)
   312  			for _, schema := range writer.(*UnionSchema).Types() {
   313  				sch, _, err := c.resolve(reader, schema)
   314  				if err != nil {
   315  					return nil, false, err
   316  				}
   317  				schemas = append(schemas, sch)
   318  			}
   319  			s, err := NewUnionSchema(schemas, withWriterFingerprint(writer.Fingerprint()))
   320  			return s, true, err
   321  		}
   322  
   323  		if isPromotable(writer.Type(), reader.Type()) {
   324  			r := NewPrimitiveSchema(reader.Type(), reader.(*PrimitiveSchema).Logical(),
   325  				withWriterFingerprint(writer.Fingerprint()),
   326  			)
   327  			r.encodedType = writer.Type()
   328  			return r, true, nil
   329  		}
   330  
   331  		return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type())
   332  	}
   333  
   334  	if isNative(writer.Type()) {
   335  		return reader, false, nil
   336  	}
   337  
   338  	if writer.Type() == Enum {
   339  		r := reader.(*EnumSchema)
   340  		w := writer.(*EnumSchema)
   341  		if err = c.checkEnumSymbols(r, w); err != nil {
   342  			if r.HasDefault() {
   343  				enum, _ := NewEnumSchema(r.Name(), r.Namespace(), r.Symbols(),
   344  					WithAliases(r.Aliases()),
   345  					WithDefault(r.Default()),
   346  					withWriterFingerprint(w.Fingerprint()),
   347  				)
   348  				enum.encodedSymbols = w.Symbols()
   349  				return enum, true, nil
   350  			}
   351  
   352  			return nil, false, err
   353  		}
   354  		return reader, false, nil
   355  	}
   356  
   357  	if writer.Type() == Fixed {
   358  		return reader, false, nil
   359  	}
   360  
   361  	if writer.Type() == Union {
   362  		schemas := make([]Schema, 0)
   363  		for _, s := range writer.(*UnionSchema).Types() {
   364  			sch, resolv, err := c.resolve(reader, s)
   365  			if err != nil {
   366  				return nil, false, err
   367  			}
   368  			schemas = append(schemas, sch)
   369  			resolved = resolv || resolved
   370  		}
   371  		s, err := NewUnionSchema(schemas, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved))
   372  		if err != nil {
   373  			return nil, false, err
   374  		}
   375  		return s, resolved, nil
   376  	}
   377  
   378  	if writer.Type() == Array {
   379  		schema, resolved, err = c.resolve(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
   380  		if err != nil {
   381  			return nil, false, err
   382  		}
   383  		return NewArraySchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil
   384  	}
   385  
   386  	if writer.Type() == Map {
   387  		schema, resolved, err = c.resolve(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
   388  		if err != nil {
   389  			return nil, false, err
   390  		}
   391  		return NewMapSchema(schema, withWriterFingerprintIfResolved(writer.Fingerprint(), resolved)), resolved, nil
   392  	}
   393  
   394  	if writer.Type() == Record {
   395  		return c.resolveRecord(reader, writer)
   396  	}
   397  
   398  	return nil, false, fmt.Errorf("failed to resolve composite schema for %s and %s", reader.Type(), writer.Type())
   399  }
   400  
   401  func (c *SchemaCompatibility) resolveRecord(reader, writer Schema) (Schema, bool, error) {
   402  	w := writer.(*RecordSchema)
   403  	r := reader.(*RecordSchema)
   404  
   405  	fields := make([]*Field, 0)
   406  	seen := make(map[string]struct{})
   407  
   408  	var resolved bool
   409  	for _, wf := range w.Fields() {
   410  		rf, ok := c.getField(r.Fields(), wf, func(gfo *getFieldOptions) {
   411  			gfo.elemAlias = true
   412  		})
   413  		if !ok {
   414  			// The field was not found in the reader schema, it should be ignored.
   415  			f, _ := NewField(wf.Name(), wf.Type(), WithAliases(wf.aliases), WithOrder(wf.order))
   416  			f.def = wf.def
   417  			f.hasDef = wf.hasDef
   418  			f.action = FieldIgnore
   419  			fields = append(fields, f)
   420  
   421  			resolved = true
   422  			continue
   423  		}
   424  
   425  		ft, resolv, err := c.resolve(rf.Type(), wf.Type())
   426  		if err != nil {
   427  			return nil, false, err
   428  		}
   429  		f, _ := NewField(rf.Name(), ft, WithAliases(rf.aliases), WithOrder(rf.order))
   430  		f.def = rf.def
   431  		f.hasDef = rf.hasDef
   432  		fields = append(fields, f)
   433  		resolved = resolv || resolved
   434  
   435  		seen[rf.Name()] = struct{}{}
   436  	}
   437  
   438  	for _, rf := range r.Fields() {
   439  		if _, ok := seen[rf.Name()]; ok {
   440  			// This field has already been seen.
   441  			continue
   442  		}
   443  
   444  		// The schemas are already known to be compatible, so there must be a default on
   445  		// the field in the writer. Use the default.
   446  
   447  		f, _ := NewField(rf.Name(), rf.Type(), WithAliases(rf.aliases), WithOrder(rf.order))
   448  		f.def = rf.def
   449  		f.hasDef = rf.hasDef
   450  		f.action = FieldSetDefault
   451  		fields = append(fields, f)
   452  
   453  		resolved = true
   454  	}
   455  
   456  	schema, err := NewRecordSchema(r.Name(), r.Namespace(), fields,
   457  		WithAliases(r.Aliases()),
   458  		withWriterFingerprintIfResolved(writer.Fingerprint(), resolved),
   459  	)
   460  	return schema, resolved, err
   461  }
   462  
   463  func isNative(typ Type) bool {
   464  	switch typ {
   465  	case Null, Boolean, Int, Long, Float, Double, Bytes, String:
   466  		return true
   467  	default:
   468  		return false
   469  	}
   470  }
   471  
   472  func isPromotable(writerTyp, readerType Type) bool {
   473  	switch writerTyp {
   474  	case Int:
   475  		return readerType == Long || readerType == Float || readerType == Double
   476  	case Long:
   477  		return readerType == Float || readerType == Double
   478  	case Float:
   479  		return readerType == Double
   480  	case String:
   481  		return readerType == Bytes
   482  	case Bytes:
   483  		return readerType == String
   484  	default:
   485  		return false
   486  	}
   487  }