github.com/hamba/avro@v1.8.0/codec_record.go (about)

     1  package avro
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"unsafe"
     9  
    10  	"github.com/modern-go/reflect2"
    11  )
    12  
    13  func createDecoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    14  	switch typ.Kind() {
    15  	case reflect.Struct:
    16  		return decoderOfStruct(cfg, schema, typ)
    17  
    18  	case reflect.Map:
    19  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    20  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    21  			break
    22  		}
    23  		return decoderOfRecord(cfg, schema, typ)
    24  
    25  	case reflect.Ptr:
    26  		return decoderOfPtr(cfg, schema, typ)
    27  
    28  	case reflect.Interface:
    29  		if ifaceType, ok := typ.(*reflect2.UnsafeIFaceType); ok {
    30  			return &recordIfaceDecoder{schema: schema, valType: ifaceType}
    31  		}
    32  	}
    33  
    34  	return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    35  }
    36  
    37  func createEncoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
    38  	switch typ.Kind() {
    39  	case reflect.Struct:
    40  		return encoderOfStruct(cfg, schema, typ)
    41  
    42  	case reflect.Map:
    43  		if typ.(reflect2.MapType).Key().Kind() != reflect.String ||
    44  			typ.(reflect2.MapType).Elem().Kind() != reflect.Interface {
    45  			break
    46  		}
    47  		return encoderOfRecord(cfg, schema, typ)
    48  
    49  	case reflect.Ptr:
    50  		return encoderOfPtr(cfg, schema, typ)
    51  	}
    52  
    53  	return &errorEncoder{err: fmt.Errorf("avro: %s is unsupported for avro %s", typ.String(), schema.Type())}
    54  }
    55  
    56  func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
    57  	rec := schema.(*RecordSchema)
    58  	structDesc := describeStruct(cfg.getTagKey(), typ)
    59  
    60  	fields := make([]*structFieldDecoder, 0, len(rec.Fields()))
    61  	for _, field := range rec.Fields() {
    62  		sf := structDesc.Fields.Get(field.Name())
    63  
    64  		// Skip field if it doesnt exist
    65  		if sf == nil {
    66  			fields = append(fields, &structFieldDecoder{
    67  				decoder: createSkipDecoder(field.Type()),
    68  			})
    69  			continue
    70  		}
    71  
    72  		dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type())
    73  		fields = append(fields, &structFieldDecoder{
    74  			field:   sf.Field,
    75  			decoder: dec,
    76  		})
    77  	}
    78  
    79  	return &structDecoder{typ: typ, fields: fields}
    80  }
    81  
    82  type structFieldDecoder struct {
    83  	field   []*reflect2.UnsafeStructField
    84  	decoder ValDecoder
    85  }
    86  
    87  type structDecoder struct {
    88  	typ    reflect2.Type
    89  	fields []*structFieldDecoder
    90  }
    91  
    92  func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
    93  	for _, field := range d.fields {
    94  		// Skip case
    95  		if field.field == nil {
    96  			field.decoder.Decode(nil, r)
    97  			continue
    98  		}
    99  
   100  		fieldPtr := ptr
   101  		for i, f := range field.field {
   102  			fieldPtr = f.UnsafeGet(fieldPtr)
   103  
   104  			if i == len(field.field)-1 {
   105  				break
   106  			}
   107  
   108  			if f.Type().Kind() == reflect.Ptr {
   109  				if *((*unsafe.Pointer)(ptr)) == nil {
   110  					newPtr := f.Type().UnsafeNew()
   111  					*((*unsafe.Pointer)(fieldPtr)) = newPtr
   112  				}
   113  
   114  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   115  			}
   116  		}
   117  		field.decoder.Decode(fieldPtr, r)
   118  
   119  		if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   120  			for _, f := range field.field {
   121  				r.Error = fmt.Errorf("%s: %w", f.Name(), r.Error)
   122  			}
   123  			return
   124  		}
   125  	}
   126  }
   127  
   128  func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   129  	rec := schema.(*RecordSchema)
   130  	structDesc := describeStruct(cfg.getTagKey(), typ)
   131  
   132  	fields := make([]*structFieldEncoder, 0, len(rec.Fields()))
   133  	for _, field := range rec.Fields() {
   134  		sf := structDesc.Fields.Get(field.Name())
   135  
   136  		if sf == nil {
   137  			if !field.HasDefault() {
   138  				// In all other cases, this is a required field
   139  				return &errorEncoder{err: fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name())}
   140  			}
   141  
   142  			def := field.Default()
   143  			if field.Default() == nil {
   144  				if field.Type().Type() == Null {
   145  					// We write nothing in a Null case, just skip it
   146  					continue
   147  				}
   148  
   149  				if field.Type().Type() == Union && field.Type().(*UnionSchema).Nullable() {
   150  					defaultType := reflect2.TypeOf(&def)
   151  					fields = append(fields, &structFieldEncoder{
   152  						defaultPtr: reflect2.PtrOf(&def),
   153  						encoder:    encoderOfPtrUnion(cfg, field.Type(), defaultType),
   154  					})
   155  					continue
   156  				}
   157  			}
   158  
   159  			defaultType := reflect2.TypeOf(def)
   160  			fields = append(fields, &structFieldEncoder{
   161  				defaultPtr: reflect2.PtrOf(def),
   162  				encoder:    encoderOfType(cfg, field.Type(), defaultType),
   163  			})
   164  
   165  			continue
   166  		}
   167  
   168  		fields = append(fields, &structFieldEncoder{
   169  			field:   sf.Field,
   170  			encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()),
   171  		})
   172  	}
   173  
   174  	return &structEncoder{typ: typ, fields: fields}
   175  }
   176  
   177  type structFieldEncoder struct {
   178  	field      []*reflect2.UnsafeStructField
   179  	defaultPtr unsafe.Pointer
   180  	encoder    ValEncoder
   181  }
   182  
   183  type structEncoder struct {
   184  	typ    reflect2.Type
   185  	fields []*structFieldEncoder
   186  }
   187  
   188  func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   189  	for _, field := range e.fields {
   190  		// Default case
   191  		if field.field == nil {
   192  			field.encoder.Encode(field.defaultPtr, w)
   193  			continue
   194  		}
   195  
   196  		fieldPtr := ptr
   197  		for i, f := range field.field {
   198  			fieldPtr = f.UnsafeGet(fieldPtr)
   199  
   200  			if i == len(field.field)-1 {
   201  				break
   202  			}
   203  
   204  			if f.Type().Kind() == reflect.Ptr {
   205  				if *((*unsafe.Pointer)(ptr)) == nil {
   206  					w.Error = fmt.Errorf("embedded field %q is nil", f.Name())
   207  					return
   208  				}
   209  
   210  				fieldPtr = *((*unsafe.Pointer)(fieldPtr))
   211  			}
   212  		}
   213  		field.encoder.Encode(fieldPtr, w)
   214  
   215  		if w.Error != nil && !errors.Is(w.Error, io.EOF) {
   216  			for _, f := range field.field {
   217  				w.Error = fmt.Errorf("%s: %w", f.Name(), w.Error)
   218  			}
   219  		}
   220  	}
   221  }
   222  
   223  func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
   224  	rec := schema.(*RecordSchema)
   225  	mapType := typ.(*reflect2.UnsafeMapType)
   226  
   227  	fields := make([]recordMapDecoderField, len(rec.Fields()))
   228  	for i, field := range rec.Fields() {
   229  		fields[i] = recordMapDecoderField{
   230  			name:    field.Name(),
   231  			decoder: decoderOfType(cfg, field.Type(), mapType.Elem()),
   232  		}
   233  	}
   234  
   235  	return &recordMapDecoder{
   236  		mapType:  mapType,
   237  		elemType: mapType.Elem(),
   238  		fields:   fields,
   239  	}
   240  }
   241  
   242  type recordMapDecoderField struct {
   243  	name    string
   244  	decoder ValDecoder
   245  }
   246  
   247  type recordMapDecoder struct {
   248  	mapType  *reflect2.UnsafeMapType
   249  	elemType reflect2.Type
   250  	fields   []recordMapDecoderField
   251  }
   252  
   253  func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   254  	if d.mapType.UnsafeIsNil(ptr) {
   255  		d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0))
   256  	}
   257  
   258  	for _, field := range d.fields {
   259  		elem := d.elemType.UnsafeNew()
   260  		field.decoder.Decode(elem, r)
   261  
   262  		d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem)
   263  	}
   264  
   265  	if r.Error != nil && !errors.Is(r.Error, io.EOF) {
   266  		r.Error = fmt.Errorf("%v: %w", d.mapType, r.Error)
   267  	}
   268  }
   269  
   270  func encoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
   271  	rec := schema.(*RecordSchema)
   272  	mapType := typ.(*reflect2.UnsafeMapType)
   273  
   274  	fields := make([]mapEncoderField, len(rec.Fields()))
   275  	for i, field := range rec.Fields() {
   276  		fields[i] = mapEncoderField{
   277  			name:    field.Name(),
   278  			hasDef:  field.HasDefault(),
   279  			def:     field.Default(),
   280  			encoder: encoderOfType(cfg, field.Type(), mapType.Elem()),
   281  		}
   282  
   283  		if field.HasDefault() {
   284  			switch {
   285  			case field.Type().Type() == Union:
   286  				union := field.Type().(*UnionSchema)
   287  				fields[i].def = map[string]interface{}{
   288  					string(union.Types()[0].Type()): field.Default(),
   289  				}
   290  			case field.Default() == nil:
   291  				continue
   292  			}
   293  
   294  			defaultType := reflect2.TypeOf(fields[i].def)
   295  			fields[i].defEncoder = encoderOfType(cfg, field.Type(), defaultType)
   296  			if defaultType.LikePtr() {
   297  				fields[i].defEncoder = &onePtrEncoder{fields[i].defEncoder}
   298  			}
   299  		}
   300  	}
   301  
   302  	return &recordMapEncoder{
   303  		mapType: mapType,
   304  		fields:  fields,
   305  	}
   306  }
   307  
   308  type mapEncoderField struct {
   309  	name       string
   310  	hasDef     bool
   311  	def        interface{}
   312  	defEncoder ValEncoder
   313  	encoder    ValEncoder
   314  }
   315  
   316  type recordMapEncoder struct {
   317  	mapType *reflect2.UnsafeMapType
   318  	fields  []mapEncoderField
   319  }
   320  
   321  func (e *recordMapEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
   322  	for _, field := range e.fields {
   323  		valPtr := e.mapType.UnsafeGetIndex(ptr, reflect2.PtrOf(field))
   324  		if valPtr == nil {
   325  			// Missing required field
   326  			if !field.hasDef {
   327  				w.Error = fmt.Errorf("avro: missing required field %s", field.name)
   328  				return
   329  			}
   330  
   331  			// Null default
   332  			if field.def == nil {
   333  				continue
   334  			}
   335  
   336  			defPtr := reflect2.PtrOf(field.def)
   337  			field.defEncoder.Encode(defPtr, w)
   338  			continue
   339  		}
   340  
   341  		field.encoder.Encode(valPtr, w)
   342  	}
   343  }
   344  
   345  type recordIfaceDecoder struct {
   346  	schema  Schema
   347  	valType *reflect2.UnsafeIFaceType
   348  }
   349  
   350  func (d *recordIfaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
   351  	obj := d.valType.UnsafeIndirect(ptr)
   352  	if reflect2.IsNil(obj) {
   353  		r.ReportError("decode non empty interface", "can not unmarshal into nil")
   354  		return
   355  	}
   356  
   357  	r.ReadVal(d.schema, obj)
   358  }
   359  
   360  type structDescriptor struct {
   361  	Type   reflect2.Type
   362  	Fields structFields
   363  }
   364  
   365  type structFields []*structField
   366  
   367  func (sf structFields) Get(name string) *structField {
   368  	for _, f := range sf {
   369  		if f.Name == name {
   370  			return f
   371  		}
   372  	}
   373  
   374  	return nil
   375  }
   376  
   377  type structField struct {
   378  	Name  string
   379  	Field []*reflect2.UnsafeStructField
   380  
   381  	anon *reflect2.UnsafeStructType
   382  }
   383  
   384  func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor {
   385  	structType := typ.(*reflect2.UnsafeStructType)
   386  	fields := structFields{}
   387  
   388  	var curr []structField
   389  	next := []structField{{anon: structType}}
   390  
   391  	visited := map[uintptr]bool{}
   392  
   393  	for len(next) > 0 {
   394  		curr, next = next, curr[:0]
   395  
   396  		for _, f := range curr {
   397  			rtype := f.anon.RType()
   398  			if visited[f.anon.RType()] {
   399  				continue
   400  			}
   401  			visited[rtype] = true
   402  
   403  			for i := 0; i < f.anon.NumField(); i++ {
   404  				field := f.anon.Field(i).(*reflect2.UnsafeStructField)
   405  				isUnexported := field.PkgPath() != ""
   406  
   407  				chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1)
   408  				copy(chain, f.Field)
   409  				chain[len(f.Field)] = field
   410  
   411  				if field.Anonymous() {
   412  					t := field.Type()
   413  					if t.Kind() == reflect.Ptr {
   414  						t = t.(*reflect2.UnsafePtrType).Elem()
   415  					}
   416  					if t.Kind() != reflect.Struct {
   417  						continue
   418  					}
   419  
   420  					next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)})
   421  					continue
   422  				}
   423  
   424  				// Ignore unexported fields.
   425  				if isUnexported {
   426  					continue
   427  				}
   428  
   429  				fieldName := field.Name()
   430  				if tag, ok := field.Tag().Lookup(tagKey); ok {
   431  					fieldName = tag
   432  				}
   433  
   434  				fields = append(fields, &structField{
   435  					Name:  fieldName,
   436  					Field: chain,
   437  				})
   438  			}
   439  		}
   440  	}
   441  
   442  	return &structDescriptor{
   443  		Type:   structType,
   444  		Fields: fields,
   445  	}
   446  }