github.com/aacfactory/avro@v1.2.12/internal/base/schema_compatibility.go (about)

     1  package base
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  )
     8  
     9  type recursionError struct{}
    10  
    11  func (e recursionError) Error() string {
    12  	return ""
    13  }
    14  
    15  type compatKey struct {
    16  	reader [32]byte
    17  	writer [32]byte
    18  }
    19  
    20  // SchemaCompatibility determines the compatibility of schemas.
    21  type SchemaCompatibility struct {
    22  	cache sync.Map // map[compatKey]error
    23  }
    24  
    25  // NewSchemaCompatibility creates a new schema compatibility instance.
    26  func NewSchemaCompatibility() *SchemaCompatibility {
    27  	return &SchemaCompatibility{}
    28  }
    29  
    30  // Compatible determines the compatibility if the reader and writer schemas.
    31  func (c *SchemaCompatibility) Compatible(reader, writer Schema) error {
    32  	return c.compatible(reader, writer)
    33  }
    34  
    35  func (c *SchemaCompatibility) compatible(reader, writer Schema) error {
    36  	key := compatKey{reader: reader.Fingerprint(), writer: writer.Fingerprint()}
    37  	if err, ok := c.cache.Load(key); ok {
    38  		if _, ok := err.(recursionError); ok {
    39  			// Break the recursion here.
    40  			return nil
    41  		}
    42  
    43  		if err == nil {
    44  			return nil
    45  		}
    46  
    47  		return err.(error)
    48  	}
    49  
    50  	c.cache.Store(key, recursionError{})
    51  	err := c.match(reader, writer)
    52  	if err != nil {
    53  		// We dont want to pay the cost of fmt.Errorf every time
    54  		err = errors.New(err.Error())
    55  	}
    56  	c.cache.Store(key, err)
    57  	return err
    58  }
    59  
    60  func (c *SchemaCompatibility) match(reader, writer Schema) error {
    61  	// If the schema is a reference, get the actual schema
    62  	if reader.Type() == Ref {
    63  		reader = reader.(*RefSchema).Schema()
    64  	}
    65  	if writer.Type() == Ref {
    66  		writer = writer.(*RefSchema).Schema()
    67  	}
    68  
    69  	if reader.Type() != writer.Type() {
    70  		if writer.Type() == Union {
    71  			// Reader must be compatible with all types in writer
    72  			for _, schema := range writer.(*UnionSchema).Types() {
    73  				if err := c.compatible(reader, schema); err != nil {
    74  					return err
    75  				}
    76  			}
    77  
    78  			return nil
    79  		}
    80  
    81  		if reader.Type() == Union {
    82  			// Writer must be compatible with at least one reader schema
    83  			var err error
    84  			for _, schema := range reader.(*UnionSchema).Types() {
    85  				err = c.compatible(schema, writer)
    86  				if err == nil {
    87  					return nil
    88  				}
    89  			}
    90  
    91  			return fmt.Errorf("reader union lacking writer schema %s", writer.Type())
    92  		}
    93  
    94  		switch writer.Type() {
    95  		case Int:
    96  			if reader.Type() == Long || reader.Type() == Float || reader.Type() == Double {
    97  				return nil
    98  			}
    99  
   100  		case Long:
   101  			if reader.Type() == Float || reader.Type() == Double {
   102  				return nil
   103  			}
   104  
   105  		case Float:
   106  			if reader.Type() == Double {
   107  				return nil
   108  			}
   109  
   110  		case String:
   111  			if reader.Type() == Bytes {
   112  				return nil
   113  			}
   114  
   115  		case Bytes:
   116  			if reader.Type() == String {
   117  				return nil
   118  			}
   119  		}
   120  
   121  		return fmt.Errorf("reader schema %s not compatible with writer schema %s", reader.Type(), writer.Type())
   122  	}
   123  
   124  	switch reader.Type() {
   125  	case Array:
   126  		return c.compatible(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
   127  
   128  	case Map:
   129  		return c.compatible(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
   130  
   131  	case Fixed:
   132  		r := reader.(*FixedSchema)
   133  		w := writer.(*FixedSchema)
   134  
   135  		if err := c.checkSchemaName(r, w); err != nil {
   136  			return err
   137  		}
   138  
   139  		if err := c.checkFixedSize(r, w); err != nil {
   140  			return err
   141  		}
   142  
   143  	case Enum:
   144  		r := reader.(*EnumSchema)
   145  		w := writer.(*EnumSchema)
   146  
   147  		if err := c.checkSchemaName(r, w); err != nil {
   148  			return err
   149  		}
   150  
   151  		if err := c.checkEnumSymbols(r, w); err != nil {
   152  			return err
   153  		}
   154  
   155  	case Record:
   156  		r := reader.(*RecordSchema)
   157  		w := writer.(*RecordSchema)
   158  
   159  		if err := c.checkSchemaName(r, w); err != nil {
   160  			return err
   161  		}
   162  
   163  		if err := c.checkRecordFields(r, w); err != nil {
   164  			return err
   165  		}
   166  
   167  	case Union:
   168  		for _, schema := range writer.(*UnionSchema).Types() {
   169  			if err := c.compatible(reader, schema); err != nil {
   170  				return err
   171  			}
   172  		}
   173  	}
   174  
   175  	return nil
   176  }
   177  
   178  func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error {
   179  	if reader.FullName() != writer.FullName() {
   180  		return fmt.Errorf("reader schema %s and writer schema %s  names do not match", reader.FullName(), writer.FullName())
   181  	}
   182  
   183  	return nil
   184  }
   185  
   186  func (c *SchemaCompatibility) checkFixedSize(reader, writer *FixedSchema) error {
   187  	if reader.Size() != writer.Size() {
   188  		return fmt.Errorf("%s reader and writer fixed sizes do not match", reader.FullName())
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error {
   195  	for _, symbol := range writer.Symbols() {
   196  		if !c.contains(reader.Symbols(), symbol) {
   197  			return fmt.Errorf("reader %s is missing symbol %s", reader.FullName(), symbol)
   198  		}
   199  	}
   200  
   201  	return nil
   202  }
   203  
   204  func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error {
   205  	for _, field := range reader.Fields() {
   206  		f, ok := c.getField(writer.Fields(), field)
   207  		if !ok {
   208  			if field.HasDefault() {
   209  				continue
   210  			}
   211  
   212  			return fmt.Errorf("reader field %s is missing in writer schema and has no default", field.Name())
   213  		}
   214  
   215  		if err := c.compatible(field.Type(), f.Type()); err != nil {
   216  			return err
   217  		}
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  func (c *SchemaCompatibility) contains(a []string, s string) bool {
   224  	for _, str := range a {
   225  		if str == s {
   226  			return true
   227  		}
   228  	}
   229  
   230  	return false
   231  }
   232  
   233  func (c *SchemaCompatibility) getField(a []*Field, f *Field) (*Field, bool) {
   234  	for _, field := range a {
   235  		if field.Name() == f.Name() {
   236  			return field, true
   237  		}
   238  	}
   239  
   240  	return nil, false
   241  }