github.com/hamba/avro@v1.8.0/schema_compatibility.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"github.com/modern-go/concurrent"
     8  )
     9  
    10  type recursionError struct{}
    11  
    12  func (e recursionError) Error() string {
    13  	return ""
    14  }
    15  
    16  type compatKey struct {
    17  	reader [32]byte
    18  	writer [32]byte
    19  }
    20  
    21  // SchemaCompatibility determines the compatibility of schemas.
    22  type SchemaCompatibility struct {
    23  	cache *concurrent.Map // map[compatKey]error
    24  }
    25  
    26  // NewSchemaCompatibility creates a new schema compatibility instance.
    27  func NewSchemaCompatibility() *SchemaCompatibility {
    28  	return &SchemaCompatibility{
    29  		cache: concurrent.NewMap(),
    30  	}
    31  }
    32  
    33  // Compatible determines the compatibility if the reader and writer schemas.
    34  func (c *SchemaCompatibility) Compatible(reader, writer Schema) error {
    35  	return c.compatible(reader, writer)
    36  }
    37  
    38  func (c *SchemaCompatibility) compatible(reader, writer Schema) error {
    39  	key := compatKey{reader: reader.Fingerprint(), writer: writer.Fingerprint()}
    40  	if err, ok := c.cache.Load(key); ok {
    41  		if _, ok := err.(recursionError); ok {
    42  			// Break the recursion here.
    43  			return nil
    44  		}
    45  
    46  		if err == nil {
    47  			return nil
    48  		}
    49  
    50  		return err.(error)
    51  	}
    52  
    53  	c.cache.Store(key, recursionError{})
    54  	err := c.match(reader, writer)
    55  	if err != nil {
    56  		// We dont want to pay the cost of fmt.Errorf every time
    57  		err = errors.New(err.Error())
    58  	}
    59  	c.cache.Store(key, err)
    60  	return err
    61  }
    62  
    63  func (c *SchemaCompatibility) match(reader, writer Schema) error {
    64  	// If the schema is a reference, get the actual schema
    65  	if reader.Type() == Ref {
    66  		reader = reader.(*RefSchema).Schema()
    67  	}
    68  	if writer.Type() == Ref {
    69  		writer = writer.(*RefSchema).Schema()
    70  	}
    71  
    72  	if reader.Type() != writer.Type() {
    73  		if writer.Type() == Union {
    74  			// Reader must be compatible with all types in writer
    75  			for _, schema := range writer.(*UnionSchema).Types() {
    76  				if err := c.compatible(reader, schema); err != nil {
    77  					return err
    78  				}
    79  			}
    80  
    81  			return nil
    82  		}
    83  
    84  		if reader.Type() == Union {
    85  			// Writer must be compatible with at least one reader schema
    86  			var err error
    87  			for _, schema := range reader.(*UnionSchema).Types() {
    88  				err = c.compatible(schema, writer)
    89  				if err == nil {
    90  					return nil
    91  				}
    92  			}
    93  
    94  			return fmt.Errorf("reader union lacking writer schema %s", writer.Type())
    95  		}
    96  
    97  		switch writer.Type() {
    98  		case Int:
    99  			if reader.Type() == Long || reader.Type() == Float || reader.Type() == Double {
   100  				return nil
   101  			}
   102  
   103  		case Long:
   104  			if reader.Type() == Float || reader.Type() == Double {
   105  				return nil
   106  			}
   107  
   108  		case Float:
   109  			if reader.Type() == Double {
   110  				return nil
   111  			}
   112  
   113  		case String:
   114  			if reader.Type() == Bytes {
   115  				return nil
   116  			}
   117  
   118  		case Bytes:
   119  			if reader.Type() == String {
   120  				return nil
   121  			}
   122  		}
   123  
   124  		return fmt.Errorf("reader schema %s not compatible with writer schema %s", reader.Type(), writer.Type())
   125  	}
   126  
   127  	switch reader.Type() {
   128  	case Array:
   129  		return c.compatible(reader.(*ArraySchema).Items(), writer.(*ArraySchema).Items())
   130  
   131  	case Map:
   132  		return c.compatible(reader.(*MapSchema).Values(), writer.(*MapSchema).Values())
   133  
   134  	case Fixed:
   135  		r := reader.(*FixedSchema)
   136  		w := writer.(*FixedSchema)
   137  
   138  		if err := c.checkSchemaName(r, w); err != nil {
   139  			return err
   140  		}
   141  
   142  		if err := c.checkFixedSize(r, w); err != nil {
   143  			return err
   144  		}
   145  
   146  	case Enum:
   147  		r := reader.(*EnumSchema)
   148  		w := writer.(*EnumSchema)
   149  
   150  		if err := c.checkSchemaName(r, w); err != nil {
   151  			return err
   152  		}
   153  
   154  		if err := c.checkEnumSymbols(r, w); err != nil {
   155  			return err
   156  		}
   157  
   158  	case Record:
   159  		r := reader.(*RecordSchema)
   160  		w := writer.(*RecordSchema)
   161  
   162  		if err := c.checkSchemaName(r, w); err != nil {
   163  			return err
   164  		}
   165  
   166  		if err := c.checkRecordFields(r, w); err != nil {
   167  			return err
   168  		}
   169  
   170  	case Union:
   171  		for _, schema := range writer.(*UnionSchema).Types() {
   172  			if err := c.compatible(reader, schema); err != nil {
   173  				return err
   174  			}
   175  		}
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  func (c *SchemaCompatibility) checkSchemaName(reader, writer NamedSchema) error {
   182  	if reader.FullName() != writer.FullName() {
   183  		return fmt.Errorf("reader schema %s and writer schema %s  names do match", reader.FullName(), writer.FullName())
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  func (c *SchemaCompatibility) checkFixedSize(reader, writer *FixedSchema) error {
   190  	if reader.Size() != writer.Size() {
   191  		return fmt.Errorf("%s reader and writer fixed sizes do not match", reader.FullName())
   192  	}
   193  
   194  	return nil
   195  }
   196  
   197  func (c *SchemaCompatibility) checkEnumSymbols(reader, writer *EnumSchema) error {
   198  	for _, symbol := range writer.Symbols() {
   199  		if !c.contains(reader.Symbols(), symbol) {
   200  			return fmt.Errorf("reader %s is missing symbol %s", reader.FullName(), symbol)
   201  		}
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  func (c *SchemaCompatibility) checkRecordFields(reader, writer *RecordSchema) error {
   208  	for _, field := range reader.Fields() {
   209  		f, ok := c.getField(writer.Fields(), field)
   210  		if !ok {
   211  			if field.HasDefault() {
   212  				continue
   213  			}
   214  
   215  			return fmt.Errorf("reader field %s is missing in writer schema and has no default", field.Name())
   216  		}
   217  
   218  		if err := c.compatible(field.Type(), f.Type()); err != nil {
   219  			return err
   220  		}
   221  	}
   222  
   223  	return nil
   224  }
   225  
   226  func (c *SchemaCompatibility) contains(a []string, s string) bool {
   227  	for _, str := range a {
   228  		if str == s {
   229  			return true
   230  		}
   231  	}
   232  
   233  	return false
   234  }
   235  
   236  func (c *SchemaCompatibility) getField(a []*Field, f *Field) (*Field, bool) {
   237  	for _, field := range a {
   238  		if field.Name() == f.Name() {
   239  			return field, true
   240  		}
   241  	}
   242  
   243  	return nil, false
   244  }